📄 util.py
字号:
# -*- coding: utf-8 -*-
#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#
# $Id: util.py,v 1.1.1.1 2006/11/23 09:42:07 mha Exp $
#

"""Utility functions."""

import math
import os
import random
import re
import string

import Numeric
import RandomArray
import MLab
import LinearAlgebra

from bblocks import Helpers, Learner
import blocks


def mse(s, shat):
    """Return the mean squared error between s and shat along the last axis.

    The squared differences are summed over the last axis and divided by
    the length of that axis.
    """
    # float() keeps the division floating point even for integer arrays.
    return Numeric.sum((s - shat)**2, -1) / float(s.shape[-1])


def prunesources(net, label, treshold, updates=5):
    """Repeatedly remove the variable with the smallest cost difference.

    net: network whose variable array ``label`` is pruned.
    label: name passed to net.GetVariableArray().
    treshold: nodes are removed one at a time, always the one with the
        smallest Helpers.CostDifferenceV value, for as long as that
        smallest value is below this threshold.
    updates: number of net.UpdateAll() sweeps after each removal.

    Progress and costs are printed to stdout.
    """
    def costdiff(node):
        # Falsy entries (presumably already-removed nodes) get a huge
        # value so that they are never selected for pruning.
        if node:
            return Helpers.CostDifferenceV(node)
        else:
            return 10**6

    cdiff = [costdiff(node) for node in net.GetVariableArray(label)]
    print("cdiff = %s" % (cdiff,))
    # The emptiness check keeps min() from raising on an empty array.
    while cdiff and min(cdiff) < treshold:
        sn = net.GetVariableArray(label)
        i = Numeric.argmin(cdiff)
        print("----")
        print("s(%d), cdiff = %f" % (i, min(cdiff)))
        print("Cost before = %f" % (net.Cost()))
        sn[i].Die()
        net.CleanUp()
        for j in range(updates):
            net.UpdateAll()
            print("%d : %f" % (j, net.Cost()))
        print("Cost after = %f" % (net.Cost()))
        cdiff = [costdiff(node) for node in net.GetVariableArray(label)]


def grep(rexp, lst):
    """Return the strings in lst that match rexp or have a matching substring.

    Matching uses re.search semantics.  The result is a list.
    """
    pattern = re.compile(rexp)
    return [s for s in lst if pattern.search(s)]


def which_interval(intervals, x):
    """Return the smallest index i such that x < intervals[i].

    intervals is a sequence of upper bounds (e.g. a cumulative
    distribution).

    Raises ValueError if x is not below any of the bounds.
    """
    for i, upper in enumerate(intervals):
        if x < upper:
            return i
    # An explicit exception (instead of ``assert``) is not stripped
    # when running under ``python -O``.
    raise ValueError("value %r is not in any interval" % (x,))


def mixgauss(means, stdevs, weights, shape):
    """Return a random Float array from a mixture-of-Gaussians distribution.

    means, stdevs, weights: per-component parameters; weights are
        normalised to sum to one.
    shape: shape of the result.  Each element independently picks a
        component with probability proportional to its weight and draws
        from random.gauss with that component's mean and deviation.
    """
    total = Numeric.sum(weights)
    intervals = Numeric.cumsum([1.0 * w / total for w in weights])
    if len(intervals):
        # Pin the last boundary to exactly 1.0: rounding in cumsum could
        # otherwise leave some random.random() value outside every interval.
        intervals[-1] = 1.0

    def draw(shape):
        if len(shape) > 0:
            return [draw(shape[1:]) for k in range(shape[0])]
        choice = which_interval(intervals, random.random())
        return random.gauss(means[choice], stdevs[choice])

    return Numeric.array(draw(shape), Numeric.Float)


def laplace(std, shape):
    """Return a random array of the given shape from a Laplace distribution.

    A random sign times an Exponential(1) variate is a unit-scale Laplace
    sample with standard deviation sqrt(2); dividing by sqrt(2) and
    multiplying by std gives zero-mean samples with standard deviation std.
    """
    sign = 2 * (RandomArray.random(shape) < 0.5) - 1
    expd = RandomArray.exponential(1.0, shape)
    # math.sqrt: a bare ``sqrt`` is not imported into this module.
    return std * sign * expd / math.sqrt(2.0)


def mlab_data(file, var):
    """Load the array ``var`` from a MATLAB mat-file as a Float array."""
    return Numeric.array(Helpers.LoadMatlabArray(file, var), Numeric.Float)


def save_array_as_dat(ary, filename):
    """Save a 2D array into a file in ASCII format.

    Each row is written on its own line; every element is converted with
    str() and followed by a single space.

    Raises ValueError if ary is not two-dimensional.
    """
    if len(ary.shape) != 2:
        raise ValueError("not a 2D-array")
    out = open(filename, "w")
    try:
        for i in range(ary.shape[0]):
            cells = [str(ary[i, j]) + " " for j in range(ary.shape[1])]
            out.write("".join(cells) + "\n")
    finally:
        # Close the file even if converting or writing a value fails.
        out.close()


def read_dat(filename, separator=None):
    """Read a Float array from a dat-file (e.g. mat-files in ASCII format).

    The file is split into cells with read_dat_cells() and each cell is
    parsed with convert(cells, float).
    """
    cells = read_dat_cells(filename, separator)
    floats = convert(cells, float)
    return Numeric.array(floats, Numeric.Float)


def read_dat_cells(filename, separator=None):
    """Read a text file into a list of rows, each a list of string cells.

    Every line is split with str.split(separator); with the default None,
    columns are separated by runs of whitespace.
    """
    data = []
    infile = open(filename, "r")
    try:
        for line in infile:
            data.append(line.split(separator))
    finally:
        infile.close()
    return data


def convert(data, convfunc=float):
    """Apply convfunc to every cell of a list of rows.

    Cells for which convfunc raises ValueError become None.  Returns a new
    list of lists.
    """
    def convcell(cell):
        # Closes over the caller's convfunc so non-float converters work.
        try:
            return convfunc(cell)
        except ValueError:
            return None

    return [[convcell(cell) for cell in row] for row in data]


def is_vector(node):
    """Return True if the type name reported by node.GetType() ends in 'V'.

    NOTE(review): presumably 'V' marks vector node types (cf. the
    Helpers.*V functions used in this module) -- confirm.
    """
    return node.GetType()[-1] == 'V'


def is_sequence(obj):
    """Check whether obj is a sequence such as a list or DV.

    The test looks for "__len__" or "tolist" among the names in dir(obj).
    """
    methods = dir(obj)
    return "__len__" in methods or "tolist" in methods


def equalize(arr):
    """Pad a 2D list of lists in place so that every entry is a sequence.

    Non-sequence entries are replaced by a list of zeros as long as the
    first sequence entry found.  This is needed due to the difficulties
    caused by dead nodes.  Returns arr (unchanged if it is not a sequence
    or contains no sequence entries).
    """
    if not is_sequence(arr):
        return arr
    length = -1
    for item in arr:
        if is_sequence(item):
            length = len(item)
            break
    if length != -1:
        for i, item in enumerate(arr):
            if not is_sequence(item):
                arr[i] = [0] * length
    return arr


def getmean(thunk):
    """Return the mean values of one or more nodes as a Float array.

    Works with a single node, a list of nodes, a list of lists of nodes,
    etc.  Means come from Helpers.GetMeanV for vector nodes and
    Helpers.GetMean otherwise; 0/None entries and nodes whose GetDying()
    is 1 contribute 0.
    """
    def collect(item):
        if item == 0 or item is None:
            return 0
        elif is_sequence(item):
            return [collect(sub) for sub in item]
        elif item.GetDying() == 1:
            return 0
        elif is_vector(item):
            return Helpers.GetMeanV(item)
        else:
            return Helpers.GetMean(item)

    arr = collect(thunk)
    # Rows collapsed to a scalar 0 are padded to full length in place.
    equalize(arr)
    return Numeric.array(arr, Numeric.Float)


def replace_parent(node, oldpar, newpar):
    """Replace oldpar with newpar as a parent of node.

    node's pointer is redirected with ReplacePtr, node is added as a child
    of newpar, and oldpar is told via NotifyDeath(node).
    """
    node.ReplacePtr(oldpar, newpar)
    newpar.AddChild(node)
    oldpar.NotifyDeath(node)


def lowpass(s):
    """Three-point moving average of a 1D array.

    Interior points are averaged with both neighbours and the two end
    points with their single neighbour.  Returns a new Float array.
    """
    n = s.shape[0]
    if n < 2:
        # Nothing to average with; the end-point rescaling below would
        # otherwise scale a lone element by 0.75 (or fail when empty).
        return Numeric.array(s, Numeric.Float)
    sleft = Numeric.zeros(n, Numeric.Float)
    sleft[:-1] = s[1:]
    sright = Numeric.zeros(n, Numeric.Float)
    sright[1:] = s[:-1]
    y = (sleft + s + sright) / 3.0
    # Only two terms are non-zero at the ends: turn sum/3 into sum/2.
    y[0] = 1.5 * y[0]
    y[-1] = 1.5 * y[-1]
    return y


def smooth(data, n=1):
    """Lowpass-filter data by applying lowpass() n times.

    Arrays with more than one dimension are filtered along the last axis,
    one 1D slice at a time.  With n <= 0 the data is returned unfiltered
    as a Float array.
    """
    if len(data.shape) > 1:
        return Numeric.array([smooth(row, n) for row in data], Numeric.Float)
    result = Numeric.array(data, Numeric.Float)
    for m in range(n):
        result = lowpass(result)
    return result


def getattrrec(obj, attr):
    """Resolve a dotted attribute path such as "sblk.smap" on obj."""
    for name in attr.split("."):
        obj = getattr(obj, name)
    return obj


def getdata(model, vars):
    """Return a dict mapping each dotted path in vars to the getmean() of
    the object that path resolves to on model."""
    data = {}
    for var in vars:
        data[var] = getmean(getattrrec(model, var))
    return data


def pair(arr1, arr2):
    """Interleave the rows of two equally shaped 2D arrays.

    Row i of arr1 becomes row 2*i of the result and row i of arr2 becomes
    row 2*i+1.  Returns a Float array with twice as many rows.

    Raises ValueError if the shapes differ.
    """
    if arr1.shape != arr2.shape:
        raise ValueError("arrays must have the same shape")
    r, c = arr1.shape
    # Numeric.zeros / Numeric.Float: the bare names are not imported here.
    paired = Numeric.zeros((2 * r, c), Numeric.Float)
    for i in range(r):
        paired[2 * i, :] = arr1[i, :]
        paired[2 * i + 1, :] = arr2[i, :]
    return paired


def normalize(data):
    """Normalise each row of a 2D array to zero mean and unit deviation.

    Mean and deviation are taken along the last axis with MLab.mean and
    MLab.std.  Returns a new array.
    """
    mean = MLab.mean(data, -1)[:, Numeric.NewAxis]
    dev = MLab.std(data, -1)[:, Numeric.NewAxis]
    return (data - mean) / dev


def orthogonalize(basis):
    """Orthonormalise the column vectors of a matrix (classical Gram-Schmidt).

    Returns a Float matrix of the same shape whose columns are orthonormal,
    assuming the input columns are linearly independent.
    """
    def norm(v):
        return Numeric.sqrt(Numeric.sum(v**2))

    q = Numeric.zeros(basis.shape, Numeric.Float)
    q[:, 0] = basis[:, 0] / norm(basis[:, 0])
    for j in range(1, basis.shape[1]):
        vcol = basis[:, j:j+1]
        vrow = basis[:, j]
        # Remove the projections of column j onto all previous q columns.
        u = vrow - Numeric.sum(Numeric.sum(vcol * q[:, :j]) * q[:, :j], -1)
        q[:, j] = u / norm(u)
    return q


def filters(X, S):
    """Return (X X^T)^-1 X S^T.

    This is the least-squares matrix W for which transpose(W) X best
    approximates S.
    """
    mmul = Numeric.matrixmultiply
    inv = LinearAlgebra.inverse
    trans = Numeric.transpose
    return mmul(mmul(inv(mmul(X, trans(X))), X), trans(S))


def filters2(A):
    """Return (A A^T)^-1 A.

    Works if sources are uncorrelated and their variances are equal.
    """
    mmul = Numeric.matrixmultiply
    inv = LinearAlgebra.inverse
    trans = Numeric.transpose
    return mmul(inv(mmul(A, trans(A))), A)


def do_model(config, data):
    """Build, schedule and train an osvar.Model on data.

    config: dict of settings; "xdim" and "tdim" are written into it from
        data.shape.  Keys read here: srclayer, varlayer, iters, prune,
        prunefreq, prunecost, add, addfreq, addcost, rblock, rotater,
        printcost.
    data: 2D array -- presumably (dimensions x time samples); confirm.

    Returns the model with ``final_cost`` (third field of the last
    learning-history entry) and ``history`` attached.
    """
    import osvar
    config["xdim"] = data.shape[0]
    config["tdim"] = data.shape[1]
    model = osvar.Model(config)
    model.clamp(data)

    learner = Learner.Learner(model.net)
    learner.AddCall("addsources", model.addsources, config["srclayer"])
    learner.AddCall("addvarsources", model.addvarsources,
                    config["varlayer"])

    iters = config["iters"]
    learner.AddCall("prunesmap",
                    lambda: model.sblk.smap.prune(config["prunecost"]),
                    range(config["prune"], iters, config["prunefreq"]))
    learner.AddCall("addweightstosmap",
                    lambda: model.sblk.smap.addweights(config["addcost"]),
                    range(config["add"], iters, config["addfreq"]))

    if config["rblock"] is not None:
        # The rblk schedules start 5 iterations later than the sblk ones.
        learner.AddCall("prunermap",
                        lambda: model.rblk.smap.prune(config["prunecost"]),
                        range(config["prune"] + 5, iters,
                              config["prunefreq"]))
        learner.AddCall("addweightstormap",
                        lambda: model.rblk.smap.addweights(config["addcost"]),
                        range(config["add"] + 5, iters, config["addfreq"]))

    if config["rotater"] is not None:
        when = config["rotater"]
        if not is_sequence(when):
            # A single iteration number is wrapped into a 1-tuple.
            assert isinstance(when, int)
            when = (when,)
        learner.AddCall("rotatevarsources",
                        lambda: model.rblk.rotatesources(100), when)

    learner.LearnNet(iters=iters, printcost=config["printcost"])
    model.final_cost = learner.history[-1][2]
    model.history = learner.history
    return model


def dampsum(ary, a=0.99):
    """Exponentially damped cumulative sum along the second axis.

    out[:, 0] = ary[:, 0] and out[:, i] = a * out[:, i-1] + ary[:, i].
    Returns a Float array of the same shape.

    Raises ValueError if ary is not two-dimensional.
    """
    if len(ary.shape) != 2:
        raise ValueError("not a 2D-array")
    acc = Numeric.zeros(ary.shape, Numeric.Float)
    if ary.shape[1] == 0:
        return acc
    acc[:, 0] = ary[:, 0]
    for i in range(1, ary.shape[1]):
        acc[:, i] = acc[:, i - 1] * a + ary[:, i]
    return acc


def get_allvars(net):
    """Collect getmean() of known variable arrays of net into a dict.

    Keys are the short names "r", "B", "s", "A", "us", "x" and "ux".
    Arrays that cannot be fetched are skipped.
    """
    possible = {"rblk_s": "r", "rblk_A_a": "B", "sblk_s": "s",
                "sblk_A_a": "A", "sblk_u": "us", "x": "x", "ux": "ux"}
    found = {}
    for key, short in possible.items():
        try:
            found[short] = getmean(net.GetVariableArray(key))
        except Exception:
            # Deliberate best effort: presumably not every model defines
            # every array.
            pass
    return found


def get_costs(history):
    """Return the cost field (index 2) of "Iteration"/"HookeJeeves" entries."""
    return [ent[2] for ent in history
            if ent[0] in ("Iteration", "HookeJeeves")]


def get_prunes(history):
    """Return the history entries whose name starts with "prune"."""
    return [ent for ent in history if ent[0].startswith("prune")]


def get_adds(history):
    """Return the history entries whose name starts with "add"."""
    return [ent for ent in history if ent[0].startswith("add")]
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -