📄 hnfa.py
字号:
# -*- coding: iso-8859-1 -*-## Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander# Ilin, Tapani Raiko, Harri Valpola and Tomas 謘tman.## This program is free software; you can redistribute it and/or modify# it under the terms of the GNU General Public License as published by# the Free Software Foundation; either version 2, or (at your option)# any later version.## This program is distributed in the hope that it will be useful,# but WITHOUT ANY WARRANTY; without even the implied warranty of# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the# GNU General Public License (included in file License.txt in the# program package) for more details.#from __future__ import nested_scopes #needed for python2.1from bblocks.Label import Label, Unlabelimport bblocks.Net as Netimport bblocks.Helpers as Helpersimport bblocks.PyNet as PyNetimport bblocks.Learner as Learnerimport bblocks.PickleHelpers as PickleHelpersimport mathimport randomtry: import numpy.oldnumeric as Numeric import numpy as MLabexcept: import Numeric import MLabdef prunefunc(x): return apply(PyNet.PyNet.TryPruning, (x, Helpers.GetLabel(apply(PyNet.PyNet.GetVariables, (x, 'A'))),0, 0.01))class HnfaNet(PickleHelpers.Pickleable): def __init__(self, data, dim, seed=None, directconnect = -1, sourcetype="fa"): """ Generates an net. Parameters: TODO... """ self.data = Numeric.array(data) if type(dim) != type(()): self.dim = (len(self.data), 0, dim) elif len(dim) == 2: self.dim = (len(self.data),) + dim else: raise ValueError, "dim must be an integer or a tuple of length 2" if self.dim[0] < self.dim[2]: raise ValueError, "Input dimension can't be larger than output dimension:" + `dim` self.seed = seed self.random = random.Random(self.seed) self.directconnect = directconnect self.sourcetype = sourcetype self.net=PyNet.PyNet(len(self.data[0])) self.fact = PyNet.PyNodeFactory(self.net) const0 = self.fact.GetConst0() self.fact.GetConstant("const_1_2", -1) self.fact.GetConstant("c0", -30) self.fact.GetConstant("c1", -7) self.fact.GetConstant("c2", -4) self.fact.GetConstant("c3", -3) # Are those c[0-3] the correct ones? # Should the const_1_2 be kept? self.net.priorlist = { 'mvs(0)': ('const0', 'c1'), 'vvs(0)': ('const0', 'c2'), 'mvs(1)': ('const0', 'c1'), 'vvs(1)': ('const_1_2', 'c2'), 'mvA_in(1)': ('const0', 'c1'), 'vvA_in(1)': ('const0', 'c2'), 'mvA_out(1)': ('const0', 'c1'), 'vvA_out(1)': ('const0', 'c2'), 'mvA_in(0)': ('const0', 'c1'), 'vvA_in(0)': ('const0', 'c2')} # Using PCA get the most importat direction(s) to use as a priori data. pcomp = Helpers.DoPCA(self.data, self.dim[2]) pcompDV = Helpers.Array2DV(pcomp) s2 = self.MakeSources() for i in range(len(s2)): self.fact.EvidenceVNode(s2[i], mean=pcompDV[i], var=0.001, decay=40.0) dataDV = Helpers.Array2DV(self.data) for i in range(self.dim[0]): prod = [] for j in range(self.dim[2]): prod.append(self.AddOneWeight((2, 0, j ,i), lazy=1)) ms = self.fact.GetGaussian(Label("ms", 0, i), const0, const0) sum = self.fact.BuildSum2VTree(prod + [ms], ("sums", 0, i)) s = self.fact.GetGaussianV(Label("s", 0, i), sum, self.net.GetGaussianNode(Label("vs", 0, i))) s.Clamp(dataDV[i]) self.nexthidden = 0 self.learner = Learner.Learner(self.net, prunefunc=prunefunc) self.learner.HistoryAdd("Begin") self.AddHidden(self.dim[1]) self.net.SortNodes() def MakeSources(self): if self.sourcetype not in ('difa', 'ifa', 'fa', 'dfa'): raise RuntimeError('Unknown sourcetype: ' + self.sourcetype) ms2 = self.fact.GetConst0() if self.sourcetype in ('fa', 'dfa'): vs2 = self.fact.GetConst0() else: mmvs2 = self.fact.GetConst0() vmvs2 = self.fact.GetConst0() mvs2 = self.fact.MakeNodes('Gaussian', 'mvs(2)', self.dim[2], mmvs2, vmvs2) for i in range(len(mvs2)): self.fact.EvidenceNode(mvs2[i], mean=0, var=1, decay=1) mvvs2 = self.fact.GetConst0() vvvs2 = self.fact.GetConst0() vvs2 = self.fact.MakeNodes('Gaussian', 'vvs(2)', self.dim[2], mvvs2, vvvs2) for i in range(len(vvs2)): self.fact.EvidenceNode(vvs2[i], mean=0, var=1, decay=1) vs2 = self.fact.MakeNodes('GaussianV', 'vs(2)', self.dim[2], mvs2, vvs2) for i in range(len(vs2)): self.fact.EvidenceVNode(vs2[i], mean=0, var=1, decay=1) if self.sourcetype in ('ifa', 'fa'): s2 = self.fact.MakeNodes('GaussianV', 's(2)', self.dim[2], ms2, vs2) else: mas2 = self.fact.GetConst0() vas2 = self.fact.GetConst0() as2 = self.fact.MakeNodes('Gaussian', 'as(2)', self.dim[2], mas2, vas2) for i in range(len(as2)): self.fact.EvidenceNode(as2[i], mean=0, var=1, decay=1) m0s2 = self.fact.GetConst0() mv0s2 = self.fact.GetConst0() vv0s2 = self.fact.GetConst0() v0s2 = self.fact.GetGaussian('v0s(2)', mv0s2, vv0s2) self.fact.EvidenceNode(v0s2, mean=0, var=1, decay=1) s2= self.fact.MakeNodes('DelayGaussV', 's(2)', self.dim[2], ms2, vs2, as2, m0s2, v0s2) return s2 def SaveMeanV(self, outfile, nodes='s\(2'): if type(nodes) == type(''): nodes = self.net.GetNodes(nodes) f=open(outfile,'w') for x in Numeric.transpose(Numeric.array(Helpers.GetMeanV(nodes))): f.write(reduce(lambda x,y:x+" "+`y`,x,"")) f.write('\n') f.close() def HistoryAdd(self, record, data = ()): self.learner.HistoryAdd(record, data) def AddHidden(self, num=1, parentprob=1.0, childprob=1.0, scalemean = 0.5, scalestd = 0.5, vsdecay = 1000): if num == 0: return weights=[] childind = [] bias = [] for j in range(num): scale = math.exp(self.random.gauss(scalemean, scalestd)) parentind = [] for i in range(self.dim[2]): if self.random.random() < parentprob: parentind.append(1) else: parentind.append(0) childind.append([]) for i in range(self.dim[0]): if self.random.random() < childprob: childind[j].append(i) weights.append([]) for i in range(len(parentind)): if parentind[1]: weights[j].append(self.random.gauss( 0, scale/math.sqrt(len(parentind)))) else: weights[j].append(None) bias.append(self.random.gauss(0, scale)) self.AddHidden2(weights, bias, childind, vsdecay) def AddHidden2(self, weights, bias, childind=None, vsdecay = 1000): if len(weights) == 0: return if childind is None: childind = [range(self.dim[0])]*len(weights) sumcache = [] Alist = [] for i in range(self.dim[0]): sumcache.append([]) for j in range(len(weights)): vA_in = self.net.GetGaussianNode(Label("vA_in", 1, self.nexthidden))
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -