⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 runhnfa.py

📁 Extension packages to the Bayes Blocks library, reported in "Nonlinear independent factor analysis by hierarchical models"
💻 PY
字号:
#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#

"""Build an hnfa.HnfaNet from a pickled data set, train it and save it.

Command-line options are declared in the ``options`` table below.  Each
option's value is stored as a module-level global named after its last
(long) spelling, e.g. ``--numsources`` -> ``numsources``.  The trained
network is saved to FILE, or to a filename derived from the options.
"""

from bblocks.Label import Label, Unlabel
import bblocks.PickleHelpers as PickleHelpers
import bblocks.Helpers as Helpers
import hnfa
import sys
import os
import getopt
import signal
import re
import math
try:
    import numpy.oldnumeric as Numeric
except ImportError:
    # Fall back to the legacy standalone Numeric package.
    import Numeric

# Directory searched for '<fileprefix>-data.pickle' when the file is not
# found relative to the current directory.
datadir = os.environ.get('DATADIR', 'data')


def usage(out=sys.stdout):
    """Write the command-line help text to ``out``.

    Reads the module-level ``options`` table, so it must only be called
    after that table has been defined.
    """
    out.write('Usage: python ' + os.path.basename(sys.argv[0]) +
              ' [OPTION]... [FILE]' + '\n')
    out.write('Makes a hnfanet using given data and runs it\n\n')
    out.write('  FILE   File to save generated network in,\n')
    out.write('         defaults to filename generated from options\n\n')
    for names, _param, _tofilename, _default, helptext in options:
        out.write("  ")
        oname = ', '.join(names)
        # Long option lists or long help texts go on a line of their own
        # so that the help column stays aligned.
        if len(oname) > 19 or len(helptext) > 57:
            out.write(oname + '\n' + " " * 10)
        else:
            out.write(oname.ljust(22))
        out.write(helptext + '\n')


# Option table.  Each entry is
#   [0] = (names,)   short and/or long spellings; the last one names the
#                    global that receives the value
#   [1] = param      conversion type for the argument, None for a flag
#   [2] = tofilename k > 0: slot k-1 of the default filename gets the
#                    option letter + argument; k < 0: slot |k|-1 gets the
#                    bare argument; 0: not part of the filename
#   [3] = default value
#   [4] = help text
options = [
    (('-t', '--sourcetype',), str, -1, 'fa',
     "type of sources to use, '[d][i]fa' 'fa'(default))"),
    (('--ignorectrlc',),      None, 0, 0,
     'if set runhnfa ignores SIGINT'),
    (('-n', '--numsources',), int, 2, None,
     'number of sources to use'),
    (('-s', '--seed',),       int, 3, 0,
     'random seed to use'),
    (('-d', '--datapoints',), int, 4, None,
     'number of data points to use from the data set'),
    (('-l', '--numpasses',),  int, 6, 9,
     'number of passes to do 0=linear mapping 9=default'),
    (('-p', '--fileprefix',), str, 0, 'nl1',
     'prefix to use in filenames and in data to load'),
    (('-i', '--iterations',), int, 7, 5000,
     'number of iterations to run'),
    (('-x', '--onlyhj',),     int, 8, 1000,
     'number of iterations to use nothing but h-j in at end'),
    (('-r', '--randhidden',), int, 9, 50,
     'How many add randomly at first addition'),
    (('-y', '--doprunehidden',), None, 10, 0,
     'If set, prunes also hidden nodes'),
    (('--hiddentest',),       int, 0, 1000,
     'How many nodes to test at subsecuent addition of hidden nodes'),
    (('--hiddenaccept',),     int, 0, 5,
     'How many nodes to accept at subsecuent addition of hidden nodes'),
    (('--printfilename',),    None, 0, 0,
     'prints filename which will be used for results and exit'),
]

long_opts = ['help']
short_opts = "h"
optionsmap = {}
maxtofilename = 0
# Option values become module globals (see the module docstring).
values = globals()
for option in options:
    for name in option[0]:
        # Sanity checks on the option table itself, not on user input.
        assert len(name) >= 2
        assert name[0] == '-'
        if name[1] == '-':
            if option[1] is None:
                long_opts.append(name[2:])
            else:
                long_opts.append(name[2:] + '=')
        else:
            assert len(name) == 2
            short_opts += name[1]
            if option[1] is not None:
                short_opts += ':'
        optionsmap[name] = option
        maxtofilename = max(maxtofilename, abs(option[2]))
    values[option[0][-1][2:]] = option[3]

filenameopts = [None] * maxtofilename

try:
    opts, args = getopt.getopt(sys.argv[1:], short_opts, long_opts)
    for o, a in opts:
        if o in ('-h', '--help'):
            usage()
            sys.exit()
        try:
            option = optionsmap[o]
        except KeyError:
            raise getopt.GetoptError('Unknown option: ' + o, '')
        varname = option[0][-1][2:]
        if option[1] is None:
            values[varname] = 1
        else:
            try:
                values[varname] = option[1](a)
            except ValueError:
                # option[1] is a type object; use its name, since
                # concatenating the type itself would raise TypeError.
                raise getopt.GetoptError('In option ' + o + ' parameter ' +
                                         a + ' is not of type: ' +
                                         option[1].__name__, '')
        if option[2] > 0:
            # Tag with the option's canonical (first-listed) letter so that
            # e.g. '-s 3' and '--seed=3' give the same component 's3'.
            filenameopts[option[2] - 1] = option[0][0].lstrip('-')[:1] + a
        elif option[2] < 0:
            filenameopts[abs(option[2]) - 1] = a
    if len(args) == 1:
        forcefile = args[0]
    elif len(args) > 1:
        raise getopt.GetoptError('Extra filenames given', '')
    else:
        forcefile = None
except getopt.GetoptError:
    sys.stderr.write(sys.argv[0] + ": " + str(sys.exc_info()[1]) + "\n")
    usage(sys.stderr)
    sys.exit(2)

if ignorectrlc:
    # The previous handler is kept but currently not used again.
    oldsiginthandler = signal.signal(signal.SIGINT, signal.SIG_IGN)
    raisekbd = hnfa.Learner.SIGNORE
else:
    raisekbd = hnfa.Learner.SRAISE

if forcefile is None:
    # Default name: basename of the prefix followed by every option value
    # tagged for the filename, in slot order.
    filename = os.path.split(fileprefix)[1]
    for part in filenameopts:
        if part is not None:
            filename += '_' + part
    filename += '.pickle.gz'
else:
    filename = forcefile

if printfilename:
    print(re.sub(r'\.pickle(\.gz)?$', '', filename))
    sys.exit(0)


def load(filename):
    """Load ``filename`` via PickleHelpers.load_compat if it exists.

    The file is considered present if either ``filename`` or
    ``filename + '.gz'`` exists; otherwise None is returned.
    NOTE(review): load_compat presumably unpickles, so only point this at
    trusted data files.
    """
    if os.path.isfile(filename) or os.path.isfile(filename + '.gz'):
        return PickleHelpers.load_compat(filename)
    return None


# Look for the data first relative to the current directory, then in
# datadir.
data = load(fileprefix + '-data.pickle')
if data is None:
    data = load(os.path.join(datadir, fileprefix + '-data.pickle'))
if data is None:
    sys.stderr.write(sys.argv[0] + ': Datafile not found\n')
    sys.exit(2)

# Columns of data are data points (data[:, :datapoints] is used below);
# rows are presumably the observed dimensions.
if numsources is None:
    if os.path.split(fileprefix)[1] == 'helix':
        numsources = 1
    else:
        numsources = min(8, data.shape[0])
if datapoints is None:
    datapoints = data.shape[1]

print(sys.argv)

# Patterns passed to net.GetVariables (they look like regular expressions,
# note the escaped parentheses) for the node groups reported after pruning.
_NODE_PATTERNS = (r's\(1, ', r'A\(2, 1, ', r'A\(1, 0, ', r'A\(2, 0, ')


def _report_nodes_left(counts):
    """Add remaining node counts to ``counts``, log them and print them."""
    for pattern in _NODE_PATTERNS:
        # e.g. r's\(1, ' -> key 's(1)'
        key = pattern.replace('\\', '')[:-2] + ')'
        counts[key] = len(hnfanet.net.GetVariables(pattern))
        # NOTE(review): this runs once per pattern, so the same (partially
        # filled) dict is logged four times.  Kept as in the original;
        # confirm whether it was meant to come after the loop.
        hnfanet.HistoryAdd("Nodes left", counts)
    print(counts)


def prune():
    """Call hnfanet.TryPruning() and report its result and node counts."""
    _report_nodes_left({'Pruned': hnfanet.TryPruning()})


def prunehidden(cdiff=500, num=1, lazy=0):
    """Call hnfanet.TryPruneHidden() and report its result and node counts.

    ``cdiff``, ``num`` and ``lazy`` are passed through unchanged.
    """
    pruned = hnfanet.TryPruneHidden(cdiff=cdiff, num=num, lazy=lazy)
    _report_nodes_left({'HiddenPruned': pruned})


def iter(iters):
    """Run LearnNet with ``iters=-iters`` (apparently "this many more").

    Note: this shadows the builtin ``iter`` inside this module; the name is
    kept because the rest of the script calls it.
    """
    hnfanet.LearnNet(printcost=10, iters=-iters,
                     printhooke=1, hooke=10, raisekbd=raisekbd)


def addw_10_prune_10(n, prob, doprunehidden=0, cdiff=500, num=1, lazy=0):
    """Run ``n`` rounds of: add weights, learn 10, prune, learn 10.

    Each round calls hnfanet.AddAllWeights(prob) and then multiplies
    ``prob`` by the module-level ``probfactor``.  Pruning uses prune(),
    except that when ``doprunehidden`` is non-zero, rounds ``j`` with
    ``j % doprunehidden == doprunehidden // 2`` use
    prunehidden(cdiff, num, lazy) instead.  A negative ``doprunehidden``
    means "use n".

    Returns the updated ``prob``.
    """
    if doprunehidden < 0:
        doprunehidden = n
    for j in range(n):
        hnfanet.AddAllWeights(prob)
        hnfanet.HistoryAdd('AddAllWeights', prob)
        prob *= probfactor
        iter(10)
        if doprunehidden and j % doprunehidden == doprunehidden // 2:
            prunehidden(cdiff=cdiff, num=num, lazy=lazy)
        else:
            prune()
        iter(10)
    return prob


try:
    hnfanet = hnfa.HnfaNet(data[:, :datapoints], numsources, seed,
                           sourcetype=sourcetype)
    hnfanet.HistoryAdd("params", sys.argv)
    hnfanet.HistoryAdd("times", os.times())
    hnfanet.LearnNet(printcost=10, iters=100, raisekbd=raisekbd)

    prob = .2
    finalprob = .01
    # Iterations (approximately) left for add-weights/prune rounds after
    # the fixed-length phases and the final onlyhj phase.
    withaddweights = (iterations - hnfanet.learner.iter - onlyhj
                      - numpasses * 40 - 20)
    if withaddweights > 0:
        # NOTE(review): each addw_10_prune_10 round spends 20 iterations, so
        # decaying prob to finalprob would need a division by
        # (withaddweights / 20).  This divides by withaddweights * 20, so
        # prob barely decays.  Left unchanged to preserve existing
        # experiment schedules; confirm intent.
        probfactor = math.exp((math.log(finalprob) - math.log(prob)) /
                              (withaddweights) / 20)
    else:
        probfactor = 0

    if numpasses > 0:
        # First pass: add randhidden hidden nodes at random.
        hnfanet.AddHidden(num=randhidden, vsdecay=500)
        iter(50)
        prune()
        iter(10)
        prob = addw_10_prune_10(7, prob)
    for i in range(1, numpasses):
        # Later passes: test hiddentest candidates and accept hiddenaccept.
        hnfanet.AddHiddenBest(num=hiddenaccept, numtest=hiddentest,
                              vsdecay=500)
        iter(30)
        prune()
        iter(10)
        prob = addw_10_prune_10(3, prob, -doprunehidden, lazy=10, num=3)
    prune()
    iter(10)
    # Spend the remaining budget (minus onlyhj) in 20-iteration rounds.
    prob = addw_10_prune_10((iterations - onlyhj - hnfanet.learner.iter) // 20,
                            prob, doprunehidden * 5, cdiff=1000)
    hnfanet.LearnNet(printcost=10, iters=iterations,
                     printhooke=1, hooke=10, raisekbd=raisekbd)
    hnfanet.HistoryAdd("times", os.times())
    hnfanet.SaveWithPickle(filename)
except KeyboardInterrupt:
    # An interrupted run exits without saving the network.
    pass

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -