⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 learner.py

📁 The library is a C++/Python implementation of the variational building block framework introduced in
💻 PY
字号:
# -*- coding: iso-8859-1 -*-
#
# This file is a part of the Bayes Blocks library
#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#
# $Id: Learner.py 5 2006-10-26 09:44:54Z ah $
#
"""Learning-loop drivers (Learner, LearnerOL) for Bayes Blocks networks.

The module avoids Python-2-only syntax so that it is valid on both
Python 2.6+ and Python 3.
"""

import Helpers
import time
import signal

# Set to 1 by siginthandler() when SIGINT arrives while LearnNet() has
# installed it; the learning loops poll it between iterations so that
# Ctrl-C is acted on between updates rather than in the middle of one.
sigintencountered = 0


def siginthandler(signum, frame):
    """SIGINT handler that only records that Ctrl-C was pressed."""
    global sigintencountered
    sigintencountered = 1


# Values for the ``raisekbd`` argument of the LearnNet() methods.
SIGNORE = -1  # do not install the Ctrl-C handler at all
SSTOP = 0     # on Ctrl-C, stop learning and return 1
SRAISE = 1    # on Ctrl-C, stop learning and raise KeyboardInterrupt

# time.clock() was removed in Python 3.8.  Keep using it where it exists
# (so Python 2 behaviour is unchanged) and fall back to the processor-time
# clock time.process_time() otherwise.
_clock = getattr(time, "clock", None) or time.process_time


class timewrap:
    """Processor clock corrected for wrap-around of the raw counter.

    The correction constant is 2**32 / 1e6 seconds, i.e. the period of a
    32-bit microsecond counter.  Each time the raw clock value is seen to
    decrease, one wrap is assumed and another period is added.
    """

    # math.exp(32*math.log(2)-6*math.log(10)) == 4294.9672959999916
    WRAP_SECONDS = 4294.9672959999916

    def __init__(self):
        self.lastclock = _clock()
        self.wraps = 0

    def clock(self):
        """Return the current clock value plus the accumulated wraps."""
        newtime = _clock()
        if self.lastclock > newtime:
            self.wraps += 1
        self.lastclock = newtime
        return newtime + self.wraps * self.WRAP_SECONDS


class Learner:
    """Batch learning driver for a Bayes Blocks network ``net``.

    Every update appends an entry ``(record, iter, cost, ...)`` to
    ``self.history`` and the time it took to ``self.benchmark``.

    The class is kept as a classic class (no explicit base) so that its
    pickled form does not change.
    """

    def __getstate__(self):
        """Return picklable state; the bound cost method is stored by name."""
        odict = self.__dict__.copy()
        # Bound methods are not picklable, so store the method's name and
        # the object it is bound to instead.
        costfun = odict.pop('costfun', None)
        if costfun is not None:
            odict['costfunname'] = costfun.__func__.__name__
            odict['costfunself'] = costfun.__self__
        return odict

    def __setstate__(self, state):
        """Restore state produced by __getstate__() (or a plain __dict__)."""
        state = dict(state)  # do not mutate the caller's mapping
        if 'costfunname' in state and 'costfunself' in state:
            # getattr() instead of eval() so that a crafted pickle cannot
            # smuggle an arbitrary expression in through the method name.
            costfun = getattr(state['costfunself'], state['costfunname'])
            del state['costfunname']
            del state['costfunself']
            self.__dict__.update(state)
            self.costfun = costfun
        else:
            self.__dict__.update(state)
            self.costfun = self.net.Cost

    def __init__(self, net, prunefunc=None):
        """Create a learner for ``net``.

        prunefunc, if given, is called by TryPruning() as
        prunefunc(net, *args, **kws).
        """
        self.net = net
        self.printhooke = 1
        self.prunefunc = prunefunc
        self.history = []
        self.benchmark = []
        self.iter = 0
        self.time = timewrap()
        self.functions = {}
        self.function_timings = []  # list of ((period, offset), name)
        self.calls = []             # list of (iteration, name, func)
        self.costfun = self.net.Cost
        self.stopnow = False        # set to True to stop LearnNet() early
        # Default so that HookeJeeves() also works outside LearnNet().
        self.exploresteps = 1

    def AddFunction(self, name, func, timing):
        """Register ``func`` to be called periodically by LearnNet().

        timing is either a period p (call when iter % p == 0) or a pair
        (p, offset) (call when iter % p == offset).  func is anything
        CallFunc() accepts; true results are logged under ``name``.
        """
        if not Helpers.IsSequence(timing):
            timing = (timing, 0)
        self.function_timings.append((timing, name))
        self.functions[name] = func

    def AddCall(self, name, func, iter):
        """Schedule ``func`` to be called once when self.iter reaches ``iter``.

        iter may be a single int or a sequence of ints.
        """
        if not Helpers.IsSequence(iter):
            iter = (iter,)
        for i in iter:
            assert isinstance(i, int)
            self.calls.append((i, name, func))

    def SortCalls(self):
        """Sorts self.calls only using first item in each tuple (stable)."""
        self.calls.sort(key=lambda call: call[0])

    def CallFunc(self, func, name=None, verbose=0):
        """Call ``func`` and log a true-valued result.

        func is a callable taking no arguments, or a tuple (f,), (f, args)
        or (f, args, kwargs).  If ``name`` is given and the result is true,
        it is recorded with HistoryAdd(name, result) and printed when
        ``verbose`` is set.

        Raises ValueError for a tuple of any other length.
        """
        if callable(func):
            val = func()
        elif len(func) == 1:
            val = func[0]()
        elif len(func) == 2:
            val = func[0](*func[1])
        elif len(func) == 3:
            val = func[0](*func[1], **func[2])
        else:
            raise ValueError("len(func) is not in (1,2,3)")
        if (name is not None) and val:
            self.HistoryAdd(name, val)
            if verbose:
                print("%s: %s" % (name, val))

    def HistoryAdd(self, record, data=()):
        """Append (record, iter, current cost) + data to self.history.

        A tuple ``data`` is spliced into the entry; any other value is
        appended as a single element.
        """
        if isinstance(data, tuple):
            self.history.append((record, self.iter, self.costfun()) + data)
        else:
            self.history.append((record, self.iter, self.costfun(), data))

    def TryPruning(self, *args, **kws):
        """Call self.prunefunc(self.net, *args, **kws), log and return it."""
        pruningout = self.prunefunc(self.net, *args, **kws)
        self.HistoryAdd("Pruning", pruningout)
        return pruningout

    def HookeJeeves(self):
        """Do one update with net.UpdateAllHookeJeeves() and log it.

        The result of UpdateAllHookeJeeves() is concatenated to the history
        entry, so it is expected to be a tuple; it is also returned.
        """
        tmp = self.time.clock()
        hookeout = self.net.UpdateAllHookeJeeves(
            exploresteps=self.exploresteps, returncost=1)
        timeused = self.time.clock() - tmp
        self.benchmark.append(timeused)
        self.HistoryAdd("HookeJeeves",
                        (timeused, self.time.clock()) + hookeout)
        if self.printhooke:
            print("%s : %s" % (self.iter, hookeout))
        self.iter += 1
        return hookeout

    def Iteration(self, debug=0):
        """Call net.UpdateAll() (UpdateAllDebug() if debug) and log it."""
        tmp = self.time.clock()
        if debug:
            self.net.UpdateAllDebug()
        else:
            self.net.UpdateAll()
        timeused = self.time.clock() - tmp
        self.benchmark.append(timeused)
        self.HistoryAdd("Iteration", (timeused, self.time.clock()))
        self.iter += 1

    def CheckDoHooke(self, hooke):
        """Return true if a Hooke-Jeeves step should be done at self.iter.

        hooke may be
          * a false value: never;
          * an int h: when self.iter % h == h // 2;
          * a list of iteration numbers in descending order (as prepared by
            LearnNet()): when self.iter is in the list.  Reached entries
            are removed from the list;
          * a callable: when hooke(self.iter) is true.

        Raises TypeError for any other type.
        """
        if not hooke:
            return 0
        if isinstance(hooke, int):
            if self.iter % hooke == hooke // 2:
                return 1
        elif isinstance(hooke, list):
            # The smallest pending iteration is at the end of the list.
            # Compare with <=: a >= test would fire on the first len(hooke)
            # iterations regardless of the requested schedule.
            due = 0
            while hooke and hooke[-1] <= self.iter:
                del hooke[-1]
                due = 1
            return due
        elif callable(hooke):
            return hooke(self.iter)
        else:
            raise TypeError("Bad type for hooke parameter")
        return 0

    def LearnNet(self, printcost=0, iters=200,
                 hooke=None, exploresteps=1, printhooke=None, raisekbd=SSTOP,
                 debug=0, verbosecall=1):
        """Learns the net.

        Parameters:
        printcost=0 - If nonzero, print the cost before learning starts and
                      every printcost iterations.
        iters=200 - Number of iteration where to stop (numbering continues
                    where LearnNet() last time stopped).  A negative value
                    -n means "run n more iterations".
        hooke=None - None, integer, list or function:
                     if None or 0 (or not given), no Hooke-Jeeves steps;
                     if integer != 0, Hooke-Jeeves is performed when
                     iter%hooke == hooke//2;
                     if list, Hooke-Jeeves is performed if (iter in hooke);
                     if function, Hooke-Jeeves is performed if function(iter).
        exploresteps=1 - Sent as parameter to UpdateAllHookeJeeves().
        printhooke=None - Print Hooke-Jeeves results; if None, print them
                          exactly when printcost is nonzero.
        raisekbd=SSTOP - Ctrl-C handling: SIGNORE does not install a
                         handler; SSTOP stops after the current iteration
                         and returns 1; SRAISE stops after the current
                         iteration and raises KeyboardInterrupt.  The
                         previous SIGINT handler is always restored.
        debug=0 - Use UpdateAllDebug()/CostDebug instead of UpdateAll()/Cost.
        verbosecall=1 - Print the results of calls scheduled with AddCall().

        Returns 1 if learning was stopped by Ctrl-C with SSTOP, else 0.
        """
        global sigintencountered
        self.stopnow = False
        if debug:
            self.costfun = self.net.CostDebug
        else:
            self.costfun = self.net.Cost
        if iters < 0:
            iters = self.iter - iters
        self.exploresteps = exploresteps
        if printhooke is not None:
            self.printhooke = printhooke
        elif printcost == 0:
            self.printhooke = 0
        else:
            self.printhooke = 1

        if raisekbd != SIGNORE:
            sigintencountered = 0
            oldsiginthandler = signal.signal(signal.SIGINT, siginthandler)
        try:
            if printcost:
                print("%s : %s" % (self.iter, self.costfun()))
            if isinstance(hooke, list):
                # Work on a descending copy so CheckDoHooke() can pop the
                # next due iteration from the end, and drop entries that
                # are already in the past.
                hooke = sorted(hooke, reverse=True)
                while hooke and hooke[-1] < self.iter:
                    del hooke[-1]
            self.SortCalls()
            while self.iter < iters:
                if (raisekbd != SIGNORE) and sigintencountered:
                    print("Learning stopped with ctrl-C")
                    sigintencountered = 0
                    if raisekbd == SSTOP:
                        return 1
                    raise KeyboardInterrupt
                if self.CheckDoHooke(hooke):
                    self.HookeJeeves()
                else:
                    self.Iteration(debug)
                # Run calls scheduled (via AddCall) for this or an earlier
                # iteration.  With an == test, a call whose iteration had
                # already passed would stay at the head of the sorted list
                # and block every later call forever.
                while self.calls and self.calls[0][0] <= self.iter:
                    call = self.calls[0]
                    self.CallFunc(call[2], call[1], verbose=verbosecall)
                    del self.calls[0]
                for timing, fname in self.function_timings:
                    if self.iter % timing[0] == timing[1]:
                        self.CallFunc(self.functions[fname], fname)
                if printcost and (self.iter % printcost == 0):
                    debugstring = "(debug) " if debug else ""
                    print("%s%d : %f"
                          % (debugstring, self.iter, self.costfun()))
                if self.stopnow:
                    break
        finally:
            # Restore the caller's handler on every exit path (normal end,
            # early return, or exception).
            if raisekbd != SIGNORE:
                signal.signal(signal.SIGINT, oldsiginthandler)
        if (raisekbd != SIGNORE) and sigintencountered:
            print("Learning ready when ctrl-C was sent")
            sigintencountered = 0
            if raisekbd:
                raise KeyboardInterrupt
        return 0


class LearnerOL(Learner):
    """Time-step-wise variant of Learner.

    For each time step, net.UpdateTimeDep() is repeated until the cost
    decreases by at most ``epsilon``; then net.UpdateTimeInd() and
    stepperfunc(*stepperargs) are called.  History entries are
    (record, timestep, iter, cost, ...).
    """

    def __init__(self, net, stepperfunc, stepperargs=(), epsilon=1e-2):
        self.net = net
        self.history = []
        self.benchmark = []
        self.timestep = 0
        self.iter = 0
        self.time = timewrap()
        self.epsilon = epsilon
        self.stepperfunc = stepperfunc
        self.stepperargs = stepperargs
        self.functions = {}
        self.function_timings = []
        # Attributes read by inherited Learner methods (__getstate__,
        # AddCall/SortCalls); without them pickling raises KeyError.
        self.calls = []
        self.costfun = self.net.Cost
        self.stopnow = False

    def HistoryAdd(self, record, data=()):
        """Append (record, timestep, iter, cost) + data to self.history."""
        if isinstance(data, tuple):
            self.history.append(
                (record, self.timestep, self.iter, self.net.Cost()) + data)
        else:
            self.history.append(
                (record, self.timestep, self.iter, self.net.Cost(), data))

    def PrintCost(self):
        """Print "timestep(iter): cost"."""
        print("%r(%r): %s" % (self.timestep, self.iter, self.net.Cost()))

    def Iteration(self, printiters=0):
        """Call net.UpdateTimeDep() once, printing every printiters iters."""
        self.net.UpdateTimeDep()
        if printiters and (self.iter % printiters == 0):
            self.PrintCost()
        self.iter += 1

    def StepTime(self, printsteps=1):
        """Call net.UpdateTimeInd(), then advance to the next time step."""
        self.net.UpdateTimeInd()
        if printsteps and (self.timestep % printsteps == 0):
            self.PrintCost()
        self.iter = 0
        self.timestep += 1
        self.stepperfunc(*self.stepperargs)

    def DoTimeStep(self, printiters=0, printsteps=1):
        """Iterate until the cost decrease is <= epsilon, then StepTime()."""
        oldcost = 1e300
        cost = self.net.Cost()
        while (oldcost - cost) > self.epsilon:
            self.Iteration(printiters=printiters)
            oldcost = cost
            cost = self.net.Cost()
        self.StepTime(printsteps=printsteps)

    def LearnNet(self, stopstep, printiters=0, printsteps=1,
                 raisekbd=SSTOP):
        """Run DoTimeStep() until self.timestep reaches stopstep.

        A negative stopstep -n means "run n more time steps".  raisekbd is
        SIGNORE, SSTOP or SRAISE, as for Learner.LearnNet(); the previous
        SIGINT handler is always restored.

        Returns 1 if stopped by Ctrl-C with SSTOP, else 0.
        Raises ValueError for an unknown raisekbd value.
        """
        global sigintencountered
        if stopstep < 0:
            stopstep = self.timestep - stopstep
        if raisekbd not in (SIGNORE, SSTOP, SRAISE):
            raise ValueError("Unkown raisekbd value")
        if raisekbd != SIGNORE:
            sigintencountered = 0
            oldsiginthandler = signal.signal(signal.SIGINT, siginthandler)
        try:
            while self.timestep < stopstep:
                if (raisekbd != SIGNORE) and sigintencountered:
                    print("Learning stopped with ctrl-C")
                    sigintencountered = 0
                    if raisekbd == SSTOP:
                        return 1
                    # raisekbd was validated above, so this is SRAISE.
                    raise KeyboardInterrupt
                self.DoTimeStep(printiters=printiters, printsteps=printsteps)
                for timing, fname in self.function_timings:
                    if self.timestep % timing[0] == timing[1]:
                        self.CallFunc(self.functions[fname], fname)
        finally:
            if raisekbd != SIGNORE:
                signal.signal(signal.SIGINT, oldsiginthandler)
        if (raisekbd != SIGNORE) and sigintencountered:
            print("Learning ready when ctrl-C was sent")
            sigintencountered = 0
            if raisekbd == SRAISE:
                raise KeyboardInterrupt
        return 0

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -