⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 util.py

📁 Extension packages to Bayes Blocks library, part 1
💻 PY
字号:
# -*- coding: utf-8 -*-
#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#
# $Id: util.py,v 1.1.1.1 2006/11/23 09:42:07 mha Exp $
#

"""Utility functions."""

import random
import re
import string
import os
import math

import Numeric
import RandomArray
import MLab
import LinearAlgebra

from bblocks import Helpers, Learner
import blocks


def mse(s, shat):
    """Return the squared error between s and shat averaged over the last axis.

    The squared differences are summed along axis -1 and divided by the
    length of that axis.
    """
    return Numeric.sum((s - shat)**2, -1) / s.shape[-1]


def prunesources(net, label, treshold):
    """Kill nodes of a variable array while the cheapest one is below treshold.

    The cost difference of every node in ``net.GetVariableArray(label)`` is
    computed with ``Helpers.CostDifferenceV``.  As long as the smallest value
    is below ``treshold`` the corresponding node is told to ``Die()``, the
    net is cleaned up and updated five times, and the cost differences are
    recomputed.  Progress is printed to stdout.

    net      -- network object providing GetVariableArray, Cost, CleanUp
                and UpdateAll
    label    -- label of the variable array to prune
    treshold -- pruning limit for the cost difference (the parameter name
                keeps its historical spelling for keyword callers)
    """
    def costdiff(node):
        # Falsy entries (presumably dead or missing nodes -- confirm against
        # bblocks) get a huge value so that they are never picked.
        if node:
            return Helpers.CostDifferenceV(node)
        return 10**6

    cdiff = [costdiff(node) for node in net.GetVariableArray(label)]
    print("cdiff = %s" % (cdiff,))
    # An empty variable array has nothing to prune; min([]) would raise.
    while cdiff and min(cdiff) < treshold:
        sn = net.GetVariableArray(label)
        i = Numeric.argmin(cdiff)
        print("----")
        print("s(%d), cdiff = %f" % (i, min(cdiff)))
        print("Cost before = %f" % (net.Cost()))
        sn[i].Die()
        net.CleanUp()
        for j in range(5):
            net.UpdateAll()
            print("%d : %f" % (j, net.Cost()))
        print("Cost after = %f" % (net.Cost()))
        cdiff = [costdiff(node) for node in net.GetVariableArray(label)]


def grep(rexp, lst):
    """Returns strings matching rexp or having a substring that matches.

    rexp -- regular expression (pattern string)
    lst  -- iterable of strings

    The result is a list holding the matching items in their original order.
    """
    com = re.compile(rexp)
    return [item for item in lst if com.search(item)]


def which_interval(intervals, x):
    """Returns the index of the interval to which x belongs.

    ``intervals`` holds the upper bounds of consecutive intervals; the
    result is the first index ``i`` with ``x < intervals[i]``.

    Raises AssertionError when x is not below any bound.  The error is
    raised explicitly (rather than with ``assert 0``) so that it is not
    stripped when Python runs with -O.
    """
    for i, upper in enumerate(intervals):
        if x < upper:
            return i
    # in none of the intervals? too bad...
    raise AssertionError("%r is not below any interval bound" % (x,))


def mixgauss(means, stdevs, weights, shape):
    """Returns a random array from a mixture-of-gaussians distribution.

    means, stdevs -- per-component parameters handed to random.gauss
    weights       -- relative component weights (normalised here)
    shape         -- sequence of dimensions of the result

    The result is a Numeric array of type Float.  Samples are drawn in
    row-major order from the ``random`` module, so ``random.seed`` makes
    the output reproducible.
    """
    # Hoisted: the original recomputed the sum once per weight.
    total = float(Numeric.sum(weights))
    intervals = Numeric.cumsum([w / total for w in weights])
    if len(intervals) > 0:
        # The cumulative sum can end marginally below 1.0 because of
        # rounding, and random.random() may then exceed every bound.
        # random.random() is always < 1.0, so pinning the last bound to
        # exactly 1.0 makes the lookup total without changing any draw.
        intervals[-1] = 1.0

    def draw(dims):
        # Build nested lists one dimension at a time; a leaf is one sample.
        if len(dims) > 0:
            return [draw(dims[1:]) for _ in range(dims[0])]
        choice = which_interval(intervals, random.random())
        return random.gauss(means[choice], stdevs[choice])

    return Numeric.array(draw(shape), Numeric.Float)


def laplace(std, shape):
    """Returns a random array of given shape from Laplace distribution.

    A unit-mean exponential variate with a random sign has standard
    deviation sqrt(2); dividing by sqrt(2) and multiplying by ``std``
    gives samples with standard deviation ``std``.
    """
    sign = 2*(RandomArray.random(shape) < 0.5) - 1
    expd = RandomArray.exponential(1.0, shape)
    # The original called a bare, undefined ``sqrt`` here (NameError).
    return std*sign*expd / math.sqrt(2.0)


def mlab_data(file, var):
    """Loads a matlab array from a mat-file.

    file -- mat-file passed on to Helpers.LoadMatlabArray
    var  -- name of the variable to load

    The loaded data is returned as a Numeric array of type Float.
    """
    return Numeric.array(Helpers.LoadMatlabArray(file, var), Numeric.Float)


def save_array_as_dat(ary, filename):
    """Saves array into a file in ascii-format.

    Each row of the 2D array becomes one line; every element is followed
    by a single space (so lines end with a trailing space).
    """
    assert len(ary.shape) == 2, "not a 2D-array"
    out = open(filename, "w")
    try:
        for i in range(ary.shape[0]):
            cells = [str(ary[i, j]) + " " for j in range(ary.shape[1])]
            out.write("".join(cells) + "\n")
    finally:
        # Close the file even if formatting or writing fails.
        out.close()


def read_dat(filename, separator = None):
    """Reads an array from a dat-file (e.g. mat-files in ascii format).

    Cells are read with read_dat_cells and converted with float.  Cells
    that float() rejects become None (see convert), which Numeric.array
    is then unlikely to accept -- the file is expected to be purely
    numeric.
    """
    cells = read_dat_cells(filename, separator)
    floats = convert(cells, float)
    return Numeric.array(floats, Numeric.Float)


def read_dat_cells(filename, separator = None):
    """Reads data from file where columns are separated by whitespaces.

    Returns a list with one entry per line, each entry being the list of
    strings obtained by splitting the line on ``separator`` (None means
    any run of whitespace).
    """
    data = []
    infile = open(filename, "r")
    try:
        for line in infile:
            data.append(line.split(separator))
    finally:
        # Close the file even if reading fails part-way.
        infile.close()
    return data


def convert(data, convfunc = float):
    """Applies convfunc to every cell of a list of rows.

    Cells for which ``convfunc`` raises ValueError are replaced by None.
    Returns a new list of lists; ``data`` is not modified.
    """
    def convcell(cell):
        # Closes over the caller's convfunc.  The original helper had its
        # own ``convfunc = float`` default and was called with one
        # argument, so the requested converter was silently ignored.
        try:
            return convfunc(cell)
        except ValueError:
            return None

    return [[convcell(cell) for cell in row] for row in data]


def is_vector(node):
    """Tells whether the type name of node ends with 'V'.

    The slice (instead of index -1) also copes with an empty type name.
    """
    return node.GetType()[-1:] == 'V'


def is_sequence(obj):
    """Checks if an object is a sequence such as list or DV."""
    # dir() is used on purpose instead of hasattr(): the test looks at the
    # advertised attribute names only.
    methods = dir(obj)
    return "__len__" in methods or "tolist" in methods


def equalize(arr):
    """Equalises a 2D array (list of lists)
    (we need this due to the difficulties caused by dead nodes)

    Every non-sequence entry of ``arr`` is replaced, in place, by a list
    of zeros as long as the first sequence entry found.  Returns ``arr``,
    or None when ``arr`` itself is not a sequence.  Nothing is changed if
    no entry is a sequence.
    """
    if not is_sequence(arr):
        return None
    width = -1
    for i in range(len(arr)):
        if is_sequence(arr[i]):
            width = len(arr[i])
            break
    if width != -1:
        for i in range(len(arr)):
            if not is_sequence(arr[i]):
                arr[i] = [0] * width
    return arr


def getmean(thunk):
    """Gets the mean values and returns them as a numeric array
    works with a single node, list of nodes, list of list of nodes etc.

    Entries equal to 0 or None, and nodes whose GetDying() is 1, count
    as 0.  The top level of the collected structure is passed through
    equalize before conversion to a Float array.
    """
    def mean_of(item):
        if item == 0 or item is None:
            return 0
        if is_sequence(item):
            return [mean_of(sub) for sub in item]
        if item.GetDying() == 1:
            return 0
        if is_vector(item):
            return Helpers.GetMeanV(item)
        return Helpers.GetMean(item)

    arr = mean_of(thunk)
    equalize(arr)
    return Numeric.array(arr, Numeric.Float)


def replace_parent(node, oldpar, newpar):
    """Replaces a nodes parent with a new node."""
    node.ReplacePtr(oldpar, newpar)
    newpar.AddChild(node)
    oldpar.NotifyDeath(node)


def lowpass(s):
    """Three point moving average.

    ``s`` is a one-dimensional array.  Interior points are averaged with
    both neighbours; the two end points are averaged with their single
    neighbour.  Inputs with fewer than two samples are returned as an
    unfiltered Float copy.
    """
    n = s.shape[0]
    if n < 2:
        # No neighbours to average with.  The original code rescaled the
        # single element twice (giving 0.75*s) and failed on empty input.
        return Numeric.array(s, Numeric.Float)
    sleft = Numeric.zeros(n, Numeric.Float)
    sleft[:-1] = s[1:]
    sright = Numeric.zeros(n, Numeric.Float)
    sright[1:] = s[:-1]
    y = (sleft + s + sright) / 3
    # At the ends one neighbour is the zero padding: turn sum/3 into sum/2.
    y[0] = 1.5 * y[0]
    y[-1] = 1.5 * y[-1]
    return y


def smooth(data, n = 1):
    """Lowpass filter.

    Applies lowpass ``n`` times.  Arrays with more than one dimension are
    filtered along their last axis by recursing over the first axis.
    With ``n <= 0`` (or no rows) an unfiltered Float copy is returned.
    """
    if len(data.shape) > 1:
        if data.shape[0] == 0:
            # Nothing to filter; keep the shape instead of returning None.
            return Numeric.array(data, Numeric.Float)
        return Numeric.array([smooth(seq, n) for seq in data],
                             Numeric.Float)
    if n <= 0:
        # The original raised UnboundLocalError for zero passes.
        return Numeric.array(data, Numeric.Float)
    for m in range(n):
        data = lowpass(data)
    return data


def getattrrec(obj, attr):
    """Looks up a dotted attribute path such as "a.b.c" on obj."""
    for name in attr.split("."):
        obj = getattr(obj, name)
    return obj


def getdata(model, vars):
    """Collects the means of several model attributes.

    vars -- iterable of (possibly dotted) attribute names of ``model``

    Returns a dictionary mapping each name to ``getmean`` of the
    attribute found with ``getattrrec``.
    """
    data = {}
    for var in vars:
        data[var] = getmean(getattrrec(model, var))
    return data


def pair(arr1, arr2):
    """Interleaves the rows of two equally shaped 2D arrays.

    Row ``2*i`` of the result is row ``i`` of ``arr1`` and row ``2*i+1``
    is row ``i`` of ``arr2``.  The result is a Float array.
    """
    assert arr1.shape == arr2.shape
    r, c = arr1.shape
    # The original used bare ``zeros`` and ``Float``, which are not
    # imported in this module (NameError).
    paired = Numeric.zeros((2*r, c), Numeric.Float)
    for i in range(r):
        paired[2*i,:] = arr1[i,:]
        paired[2*i+1,:] = arr2[i,:]
    return paired


def normalize(data):
    """Normalises data by making it zero-mean and unit-variance.

    Mean and standard deviation are taken along the last axis with
    MLab.mean and MLab.std; the indexing used assumes a 2D array.
    """
    centered = data - MLab.mean(data, -1)[:,Numeric.NewAxis]
    uzdata = centered / MLab.std(data, -1)[:,Numeric.NewAxis]
    return uzdata


def orthogonalize(basis):
    """Orthogonalizes the column vectors of a matrix.

    Gram-Schmidt: each column has its projections on the already
    processed columns removed and is then scaled to unit norm.  Returns a
    new Float array of the same shape as ``basis``.
    """
    def norm(v):
        return Numeric.sqrt(Numeric.sum(v**2))

    q = Numeric.zeros(basis.shape, Numeric.Float)
    if basis.shape[1] == 0:
        # No columns: nothing to orthogonalize.
        return q
    q[:,0] = basis[:,0] / norm(basis[:,0])
    for j in range(1, basis.shape[1]):
        vcol = basis[:,j:j+1]
        vrow = basis[:,j]
        # Inner sum: dot products of column j with q[:,0..j-1];
        # outer sum: the combined projection onto those columns.
        u = vrow - Numeric.sum(Numeric.sum(vcol * q[:,:j]) * q[:,:j], -1)
        q[:,j] = u / norm(u)
    return q


def filters(X, S):
    """Returns inverse(X X^T) X S^T."""
    mmul = Numeric.matrixmultiply
    inv = LinearAlgebra.inverse
    trans = Numeric.transpose
    return mmul(mmul(inv(mmul(X, trans(X))), X), trans(S))


def filters2(A):
    """Works if sources are uncorrelated and their variances are equal.

    Returns inverse(A A^T) A.
    """
    mmul = Numeric.matrixmultiply
    inv = LinearAlgebra.inverse
    trans = Numeric.transpose
    return mmul(inv(mmul(A, trans(A))), A)


def do_model(config, data):
    """Builds an osvar model for data, runs the learner and returns the model.

    config -- dictionary of settings.  Keys read here: "srclayer",
              "varlayer", "iters", "prune", "prunefreq", "prunecost",
              "add", "addfreq", "addcost", "rblock", "rotater",
              "printcost".  The keys "xdim" and "tdim" are written into
              it from ``data.shape``, i.e. the caller's dictionary is
              modified.
    data   -- 2D array handed to ``model.clamp``

    The returned model carries two extra attributes: ``history`` (the
    learner history) and ``final_cost`` (field 2 of its last entry).
    """
    import osvar

    config["xdim"] = data.shape[0]
    config["tdim"] = data.shape[1]
    model = osvar.Model(config)
    model.clamp(data)

    learner = Learner.Learner(model.net)
    learner.AddCall("addsources", model.addsources, config["srclayer"])
    learner.AddCall("addvarsources", model.addvarsources,
                    config["varlayer"])
    iters = config["iters"]
    learner.AddCall("prunesmap",
                    lambda: model.sblk.smap.prune(config["prunecost"]),
                    range(config["prune"], iters, config["prunefreq"]))
    learner.AddCall("addweightstosmap",
                    lambda: model.sblk.smap.addweights(config["addcost"]),
                    range(config["add"], iters, config["addfreq"]))
    if config["rblock"] is not None:
        # Same schedule as for sblk above, shifted by five iterations.
        learner.AddCall("prunermap",
                        lambda: model.rblk.smap.prune(config["prunecost"]),
                        range(config["prune"] + 5, iters,
                              config["prunefreq"]))
        learner.AddCall("addweightstormap",
                        lambda: model.rblk.smap.addweights(
                            config["addcost"]),
                        range(config["add"] + 5, iters, config["addfreq"]))
    if config["rotater"] is not None:
        when = config["rotater"]
        if not is_sequence(when):
            # A single iteration number is wrapped into a 1-tuple.
            assert isinstance(when, int)
            when = (when,)
        learner.AddCall("rotatevarsources",
                        lambda: model.rblk.rotatesources(100),
                        when)
    learner.LearnNet(iters = iters, printcost = config["printcost"])
    model.final_cost = learner.history[-1][2]
    model.history = learner.history
    return model


def dampsum(ary, a = 0.99):
    """Exponentially damped cumulative sum along the second axis.

    Column ``i`` of the result is ``a`` times column ``i-1`` of the
    result plus column ``i`` of ``ary``; column 0 is copied as is.
    Returns a Float array shaped like the 2D input.
    """
    assert len(ary.shape) == 2
    damped = Numeric.zeros(ary.shape, Numeric.Float)
    if ary.shape[1] == 0:
        # No columns: nothing to accumulate.
        return damped
    damped[:,0] = ary[:,0]
    for i in range(1, ary.shape[1]):
        damped[:,i] = damped[:,i-1]*a + ary[:,i]
    return damped


def get_allvars(net):
    """Collects the means of the known variable arrays of a net.

    Returns a dictionary keyed by short names ("r", "B", "s", "A", "us",
    "x", "ux"); variable arrays that cannot be fetched or converted are
    left out.
    """
    found = {}
    possible = {"rblk_s": "r", "rblk_A_a": "B", "sblk_s": "s", "sblk_A_a": "A",
                "sblk_u": "us", "x": "x", "ux": "ux"}
    for key in possible:
        try:
            ary = getmean(net.GetVariableArray(key))
            found[possible[key]] = ary
        except Exception:
            # Deliberately best-effort: not every net has every array, and
            # the exception type raised by the net for a missing label is
            # not known here.
            pass
    return found


def get_costs(history):
    """Returns field 2 of the "Iteration" and "HookeJeeves" history entries."""
    return [ent[2] for ent in history
            if ent[0] == "Iteration" or ent[0] == "HookeJeeves"]


def get_prunes(history):
    """Returns the history entries whose name starts with "prune"."""
    return [ent for ent in history if ent[0][:5] == "prune"]


def get_adds(history):
    """Returns the history entries whose name starts with "add"."""
    return [ent for ent in history if ent[0][:3] == "add"]

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -