📄 runhnfa.py
字号:
#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#
"""Build a hnfanet (hnfa.HnfaNet) from a data set, train it and save it.

Run with --help for the list of command-line options.  The trained network
is pickled to FILE, or to a filename generated from the options when FILE is
not given.  The script avoids Python-2-only syntax (print statement, the
reduce builtin, integer '/') so it parses under Python 2 and 3.
"""

from bblocks.Label import Label, Unlabel
import bblocks.PickleHelpers as PickleHelpers
import bblocks.Helpers as Helpers
import hnfa
import sys
import os
import getopt
import signal
import re
import math

try:
    import numpy.oldnumeric as Numeric
except ImportError:
    # Fall back to the legacy standalone Numeric package.
    import Numeric

# Directory searched for '<fileprefix>-data.pickle' when the file is not
# found relative to the current directory.
datadir = os.environ.get('DATADIR', 'data')


def usage(out=sys.stdout):
    """Write the command-line help text to *out* (default: stdout)."""
    out.write('Usage: python ' + os.path.basename(sys.argv[0])
              + ' [OPTION]... [FILE]' + '\n')
    out.write('Makes a hnfanet using given data and runs it\n\n')
    out.write(' FILE File to save generated network in,\n')
    out.write(' defaults to filename generated from options\n\n')
    for option in options:
        out.write(" ")
        oname = ', '.join(option[0])
        # Long option names or long help texts put the help on its own line.
        if len(oname) > 19 or len(option[4]) > 57:
            out.write(oname + '\n' + " " * 10)
        else:
            out.write(oname.ljust(22))
        out.write(option[4] + '\n')


# Each entry describes one command-line option:
#   [0] tuple of names, short ('-x') and/or long ('--name'); the last name
#       without its leading '--' is the module-level variable that receives
#       the option's value
#   [1] type used to convert the argument, or None for a flag (set to 1)
#   [2] 1-based position of the option in the generated result filename:
#       >0 adds option letter + value, <0 adds the value only, 0 omits it
#   [3] default value
#   [4] help text
options = [
    (('-t', '--sourcetype',), str, -1, 'fa',
     "type of sources to use, '[d][i]fa' 'fa'(default))"),
    (('--ignorectrlc',), None, 0, 0,
     'if set runhnfa ignores SIGINT'),
    (('-n', '--numsources',), int, 2, None,
     'number of sources to use'),
    (('-s', '--seed',), int, 3, 0,
     'random seed to use'),
    (('-d', '--datapoints',), int, 4, None,
     'number of data points to use from the data set'),
    (('-l', '--numpasses',), int, 6, 9,
     'number of passes to do 0=linear mapping 9=default'),
    (('-p', '--fileprefix',), str, 0, 'nl1',
     'prefix to use in filenames and in data to load'),
    (('-i', '--iterations',), int, 7, 5000,
     'number of iterations to run'),
    (('-x', '--onlyhj',), int, 8, 1000,
     'number of iterations to use nothing but h-j in at end'),
    (('-r', '--randhidden',), int, 9, 50,
     'How many add randomly at first addition'),
    (('-y', '--doprunehidden',), None, 10, 0,
     'If set, prunes also hidden nodes'),
    (('--hiddentest',), int, 0, 1000,
     'How many nodes to test at subsecuent addition of hidden nodes'),
    (('--hiddenaccept',), int, 0, 5,
     'How many nodes to accept at subsecuent addition of hidden nodes'),
    (('--printfilename',), None, 0, 0,
     'prints filename which will be used for results and exit'),
]

long_opts = ['help']
short_opts = "h"
optionsmap = {}
maxtofilename = 0
# Every option's value is stored as a module-level global named after its
# last (long) name, e.g. --numsources -> numsources; the rest of the script
# reads these globals directly.
values = globals()
for option in options:
    for name in option[0]:
        # Sanity checks on the static option table above.
        assert len(name) >= 2
        assert name[0] == '-'
        if name[1] == '-':
            if option[1] is None:
                long_opts.append(name[2:])
            else:
                long_opts.append(name[2:] + '=')
        else:
            assert len(name) == 2
            short_opts += name[1]
            if option[1] is not None:
                short_opts += ':'
        optionsmap[name] = option
    maxtofilename = max(maxtofilename, abs(option[2]))
    values[option[0][-1][2:]] = option[3]

filenameopts = [None] * maxtofilename
try:
    opts, args = getopt.getopt(sys.argv[1:], short_opts, long_opts)
    for o, a in opts:
        if o in ('-h', '--help'):
            usage()
            sys.exit()
        try:
            option = optionsmap[o]
        except KeyError:
            raise getopt.GetoptError('Unknown option: ' + o, '')
        varname = option[0][-1][2:]
        if option[1] is None:
            values[varname] = 1
        else:
            try:
                values[varname] = option[1](a)
            except ValueError:
                raise getopt.GetoptError(
                    'In option ' + o + ' parameter ' + a
                    + ' is not of type: ' + option[1].__name__, '')
        # Tag with the option's first listed name (its short letter for
        # every option that has one) so "-s 3" and "--seed=3" produce the
        # same filename component "s3".
        if option[2] > 0:
            filenameopts[option[2] - 1] = option[0][0].lstrip('-')[:1] + a
        elif option[2] < 0:
            filenameopts[-option[2] - 1] = a
    # At most one positional argument: the file to save the network in.
    if len(args) == 1:
        forcefile = args[0]
    elif len(args) > 1:
        raise getopt.GetoptError('Extra filenames given', '')
    else:
        forcefile = None
except getopt.GetoptError:
    sys.stderr.write(sys.argv[0] + ": " + str(sys.exc_info()[1]) + "\n")
    usage(sys.stderr)
    sys.exit(2)

if ignorectrlc:
    oldsiginthandler = signal.signal(signal.SIGINT, signal.SIG_IGN)
    raisekbd = hnfa.Learner.SIGNORE
else:
    raisekbd = hnfa.Learner.SRAISE

if forcefile is None:
    # Only options actually given on the command line contribute, in the
    # order of their filename positions.
    parts = [os.path.split(fileprefix)[1]]
    parts.extend(x for x in filenameopts if x is not None)
    filename = '_'.join(parts) + '.pickle.gz'
else:
    filename = forcefile

if printfilename:
    print(re.sub(r'\.pickle(\.gz)?$', '', filename))
    sys.exit(0)


def load(filename):
    """Return PickleHelpers.load_compat(filename) if *filename* or
    *filename*.gz exists, otherwise None.

    NOTE: this unpickles the file; only use it with trusted data files.
    """
    if os.path.isfile(filename) or os.path.isfile(filename + '.gz'):
        return PickleHelpers.load_compat(filename)
    return None


# Look for the data set relative to the current directory first, then in
# datadir.
data = load(fileprefix + '-data.pickle')
if data is None:
    data = load(os.path.join(datadir, fileprefix + '-data.pickle'))
if data is None:
    sys.stderr.write(sys.argv[0] + ': Datafile not found\n')
    sys.exit(2)

# Defaults that depend on the data set; data points are the columns of data.
if numsources is None:
    if os.path.split(fileprefix)[1] == 'helix':
        numsources = 1
    else:
        numsources = min(8, data.shape[0])
if datapoints is None:
    datapoints = data.shape[1]

print(sys.argv)

# Regular expressions (passed to net.GetVariables) selecting the node groups
# whose sizes are reported after every pruning step.
NODE_PATTERNS = (r's\(1, ', r'A\(2, 1, ', r'A\(1, 0, ', r'A\(2, 0, ')


def _report_nodes_left(summary):
    """Add the size of each NODE_PATTERNS group to *summary*, then record it
    in the network history under "Nodes left" and print it.

    A pattern such as r's\\(1, ' is reported under the key 's(1)'.
    """
    for pattern in NODE_PATTERNS:
        label = pattern.replace('\\', '')[:-2] + ')'
        summary[label] = len(hnfanet.net.GetVariables(pattern))
    hnfanet.HistoryAdd("Nodes left", summary)
    print(summary)


def prune():
    """Run hnfanet.TryPruning() and report the remaining node counts."""
    pruned = hnfanet.TryPruning()
    _report_nodes_left({'Pruned': pruned})


def prunehidden(cdiff=500, num=1, lazy=0):
    """Run hnfanet.TryPruneHidden(cdiff, num, lazy) and report the remaining
    node counts."""
    pruned = hnfanet.TryPruneHidden(cdiff=cdiff, num=num, lazy=lazy)
    _report_nodes_left({'HiddenPruned': pruned})


def iter(iters):
    """Train hnfanet for *iters* more iterations.

    The count is passed to LearnNet as a negative number, which presumably
    means "relative to the current iteration" -- confirm in hnfa.Learner.
    NOTE: this shadows the builtin iter() for the rest of the script.
    """
    hnfanet.LearnNet(printcost=10, iters=-iters, printhooke=1, hooke=10,
                     raisekbd=raisekbd)


def addw_10_prune_10(n, prob, doprunehidden=0, cdiff=500, num=1, lazy=0):
    """Run *n* rounds of weight addition and pruning; return the new prob.

    Each round calls AddAllWeights(prob), multiplies prob by the global
    probfactor, trains 10 iterations, prunes, and trains 10 more.  With a
    positive *doprunehidden* d, rounds j where j % d == d // 2 prune hidden
    nodes (prunehidden(cdiff, num, lazy)) instead of calling prune(); a
    negative value means d = n, i.e. hidden nodes are pruned once, in round
    n // 2.  0 never prunes hidden nodes.
    """
    if doprunehidden < 0:
        doprunehidden = n
    for j in range(n):
        hnfanet.AddAllWeights(prob)
        hnfanet.HistoryAdd('AddAllWeights', prob)
        prob *= probfactor
        iter(10)
        if doprunehidden and j % doprunehidden == doprunehidden // 2:
            prunehidden(cdiff=cdiff, num=num, lazy=lazy)
        else:
            prune()
        iter(10)
    return prob


try:
    hnfanet = hnfa.HnfaNet(data[:, :datapoints], numsources, seed,
                           sourcetype=sourcetype)
    hnfanet.HistoryAdd("params", sys.argv)
    hnfanet.HistoryAdd("times", os.times())
    hnfanet.LearnNet(printcost=10, iters=100, raisekbd=raisekbd)

    # Weight-addition probability: starts at prob and is multiplied by
    # probfactor after every AddAllWeights call in addw_10_prune_10.
    prob = .2
    finalprob = .01
    # Presumably the iterations left for the weight-addition rounds (each
    # round of addw_10_prune_10 trains 20 iterations).
    withaddweights = (iterations - hnfanet.learner.iter - onlyhj
                      - numpasses * 40 - 20)
    if withaddweights > 0:
        # NOTE(review): this divides by withaddweights * 20, so over roughly
        # withaddweights / 20 rounds prob decays only by a total factor of
        # about (finalprob / prob) ** (1 / 400.) and never nears finalprob;
        # '* 20 / withaddweights' would reach it.  Left unchanged to keep
        # results reproducible -- confirm the intent.
        probfactor = math.exp((math.log(finalprob) - math.log(prob))
                              / withaddweights / 20)
    else:
        probfactor = 0

    if numpasses > 0:
        # First pass: add randhidden hidden nodes at once.
        hnfanet.AddHidden(num=randhidden, vsdecay=500)
        iter(50)
        prune()
        iter(10)
        prob = addw_10_prune_10(7, prob)
        # Later passes: test hiddentest and accept hiddenaccept new hidden
        # nodes each time.
        for _ in range(1, numpasses):
            hnfanet.AddHiddenBest(num=hiddenaccept, numtest=hiddentest,
                                  vsdecay=500)
            iter(30)
            prune()
            iter(10)
            prob = addw_10_prune_10(3, prob, -doprunehidden, lazy=10, num=3)

    prune()
    iter(10)
    # Spend the iterations left before the final onlyhj on weight addition;
    # '//' keeps the round count an integer.
    prob = addw_10_prune_10(
        int((iterations - onlyhj - hnfanet.learner.iter) // 20), prob,
        doprunehidden * 5, cdiff=1000)
    # Final phase up to `iterations`; a positive iters presumably means an
    # absolute target rather than a relative count -- confirm in hnfa.
    hnfanet.LearnNet(printcost=10, iters=iterations, printhooke=1,
                     hooke=10, raisekbd=raisekbd)
    hnfanet.HistoryAdd("times", os.times())
    hnfanet.SaveWithPickle(filename)
except KeyboardInterrupt:
    # A KeyboardInterrupt raised during learning ends the run quietly;
    # the network is not saved in that case.
    pass
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -