⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 node.h

📁 The library is a C++/Python implementation of the variational building block framework introduced in
💻 H
📖 第 1 页 / 共 3 页
字号:
// -*- C++ -*-//// This file is a part of the Bayes Blocks library//// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander// Ilin, Tapani Raiko, Harri Valpola and Tomas 謘tman.//// This program is free software; you can redistribute it and/or modify// it under the terms of the GNU General Public License as published by// the Free Software Foundation; either version 2, or (at your option)// any later version.//// This program is distributed in the hope that it will be useful,// but WITHOUT ANY WARRANTY; without even the implied warranty of// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the// GNU General Public License (included in file License.txt in the// program package) for more details.//// $Id: Node.h 7 2006-10-26 10:26:41Z ah $#ifndef NODE_H#define NODE_H#include <string>#include <map>#include "Templates.h"#include "Saver.h"#include "Loader.h"#include "Decay.h"#include "Net.h"class Node;#ifndef BUILDING_SWIG_INTERFACEtypedef bool BOOLASOBJ;#endifenum partype_e {  REAL_MV, REAL_ME, REAL_M, REALV_MV, REALV_ME, REALV_M, DISCRETE,  DISCRETEV};class NodeBase{public:  virtual ~NodeBase() { }  virtual int ParIdentity(const Node *ptr) = 0;  virtual size_t NumParents() = 0;  virtual Node *GetParent(size_t i) = 0;  virtual int RemoveParent(const Node *ptr) = 0;protected:  virtual void ReallyAddParent(Node *ptr) = 0;  virtual bool ParReplacePtr(const Node *oldptr, Node *newptr) = 0;};class Node : public virtual NodeBase{public:  friend Net::Net(NetLoader *loader);  virtual ~Node() { }  void NotifyDeath(Node *ptr, int verbose = 0);  virtual void NotifyTimeType(int tt, int verbose = 0);  void ReplacePtr(Node *oldptr, Node *newptr);  void AddChild(Node *ptr) {children.push_back(ptr);}protected:  Node(Net *ptr, Label label);#ifndef BUILDING_SWIG_INTERFACE  Node(Net *ptr, NetLoader *loader, bool isproxy = 0);#endif  void AddParent(Node *ptr, bool really=true);public:  virtual bool GetReal(DSSet &val, DFlags req) { return false; }  virtual void GradReal(DSSet &val, const Node *ptr) {}  virtual bool GetRealV(DVH &val, DFlags req) {    val.vec = 0; return GetReal(val.scalar, req); }  virtual void GradRealV(DVSet &val, const Node *ptr) {}#ifdef BUILDING_SWIG_INTERFACE  virtual BOOLASOBJ GetDiscrete(DD *&val) { return false; }#else  virtual bool GetDiscrete(DD *&val) { return false; }#endif  virtual void GradDiscrete(DD &val, const Node *ptr) {}  virtual bool GetDiscreteV(VDDH &val) {    val.vec = 0; return GetDiscrete(val.scalar); }  virtual void GradDiscreteV(VDD &val, const Node *ptr) {}  virtual void Outdate(const Node *ptr) { OutdateChild(); }  void CheckParent(size_t parnum, partype_e partype);  bool ParReal(int i, DSSet &val, const DFlags req) {    return GetParent(i)->GetReal(val, req);}  bool ParRealV(int i, DVH &val, const DFlags req) {    return GetParent(i)->GetRealV(val, req);}  bool ParDiscrete(int i, DD *&val) {    return GetParent(i)->GetDiscrete(val);}  bool ParDiscreteV(int i, VDDH &val) {    return GetParent(i)->GetDiscreteV(val);}  void ChildGradReal(DSSet &val);  void ChildGradRealV(DVSet &val);  void ChildGradDiscrete(DD &val);  void ChildGradDiscreteV(VDD &val);  Label GetLabel() const { return label; }  string GetIdent() const { return GetType() + " node " + GetLabel(); }  Net *GetNet() const { return net; }  virtual string GetType() const = 0;  int TimeType() { return timetype; }  int GetDying() { return dying; }  void Die(int verbose = 0);  void OutdateChild();  virtual void Save(NetSaver *saver);  size_t NumChildren() { return children.size(); }  Node *GetChild(size_t i) {return i < children.size() ? children[i] : 0;}  int GetPersist() { return persist; }  void SetPersist(int p) { persist = p; }protected:  vector<Node *> children;  Net *net;  Label label;  int persist, timetype;  bool dying;};class NullParNode : public virtual NodeBase{public:  virtual int ParIdentity(const Node *ptr) {return -1;}  virtual size_t NumParents() { return 0; }  virtual Node *GetParent(size_t i) {return 0;}  virtual int RemoveParent(const Node *ptr) {return 0;}protected:  virtual void ReallyAddParent(Node *ptr) {return;}  virtual bool ParReplacePtr(const Node *oldptr, Node *newptr) {return false;}};class UniParNode : public virtual NodeBase{private:  Node *parent;public:  UniParNode(Node *p) : parent(p) {}  virtual int ParIdentity(const Node *ptr) { return ptr==parent ? 0 : -1;}  virtual size_t NumParents() { return parent!=0; }  virtual Node *GetParent(size_t i) {return i==0 ? parent : 0;}  virtual int RemoveParent(const Node *ptr);protected:  virtual void ReallyAddParent(Node *ptr) { parent = ptr; }  virtual bool ParReplacePtr(const Node *oldptr, Node *newptr) {    return (parent==oldptr) ? (parent=newptr) : false; }};class BiParNode : public virtual NodeBase{private:  Node *parents[2];public:  BiParNode(Node *p1, Node *p2) { parents[0]=p1; parents[1]=p2; }  virtual int ParIdentity(const Node *ptr);  virtual size_t NumParents() { return parents[0] == 0 ? 0 :      (parents[1] == 0 ? 1 : 2); }  virtual Node *GetParent(size_t i) {return i < 2 ? parents[i] : 0;}  virtual int RemoveParent(const Node *ptr);protected:  virtual void ReallyAddParent(Node *ptr);  virtual bool ParReplacePtr(const Node *oldptr, Node *newptr);};class NParNode : public virtual NodeBase{private:  vector<Node *> parents;  map<const Node *, int> parent_inds;public:  NParNode(Node *p1=0, Node *p2=0, Node *p3=0, Node *p4=0, Node *p5=0) {    if (p1) parents.push_back(p1);    if (p2) parents.push_back(p2);    if (p3) parents.push_back(p3);    if (p4) parents.push_back(p4);    if (p5) parents.push_back(p5);  }  virtual int ParIdentity(const Node *ptr);  virtual size_t NumParents() { return parents.size(); }  virtual Node *GetParent(size_t i) {return i < parents.size() ? parents[i] : 0;}  virtual int RemoveParent(const Node *ptr);protected:  virtual void ReallyAddParent(Node *ptr) { parents.push_back(ptr); }  virtual bool ParReplacePtr(const Node *oldptr, Node *newptr);};class Constant : public Node, public NullParNode{public:  Constant(Net *net, Label label, double v) : Node(net, label) {cval = v;}#ifndef BUILDING_SWIG_INTERFACE  Constant(Net *net, NetLoader *loader);#endif  void NotifyTimeType(int tt, int verbose=0) {}  bool GetReal(DSSet &val, DFlags req) {    if (req.mean) {val.mean = cval; req.mean = false;}    if (req.var) {val.var = 0; req.var = false;}    if (req.ex) {val.ex = exp(cval); req.ex = false;}    return req.AllFalse();  }  void GradReal(DSSet &val, const Node *ptr) {}  void Save(NetSaver *saver);  string GetType() const { return "Constant"; }private:  double cval;};class ConstantV : public Node, public NullParNode{public:  ConstantV(Net *net, Label label, DV v);#ifndef BUILDING_SWIG_INTERFACE  ConstantV(Net *net, NetLoader *loader);#endif  void NotifyTimeType(int tt, int verbose=0) {}  bool GetRealV(DVH &val, DFlags req) {    val.vec = &myval;    req.mean = false;    req.var = false;    req.ex = false;    return req.AllFalse();  }  void Save(NetSaver *saver);  string GetType() const { return "ConstantV"; }private:  DVSet myval;};class Function : public Node{public:  void Outdate(const Node *ptr)   {    uptodate = DFlags(false,false,false);     OutdateChild();  }  virtual void Save(NetSaver *saver);protected:  Function(Net *ptr, Label label, Node *n1 = 0, Node *n2 = 0);#ifndef BUILDING_SWIG_INTERFACE  Function(Net *ptr, NetLoader *loader);#endif  DFlags uptodate;};class Prod : public Function, public BiParNode{public:  Prod(Net *ptr, Label label, Node *n1, Node *n2) :    Function(ptr, label, n1, n2), BiParNode(n1, n2) {mean = 0.0; var = 0.0;}#ifndef BUILDING_SWIG_INTERFACE  Prod(Net *ptr, NetLoader *loader);#endif  bool GetReal(DSSet &val, DFlags req);  void GradReal(DSSet &val, const Node *ptr);  void Save(NetSaver *saver);  string GetType() const { return "Prod"; }private:  double mean, var;};class Sum2 : public Function, public BiParNode{public:  Sum2(Net *ptr, Label label, Node *n1, Node *n2) :     Function(ptr, label, n1, n2), BiParNode(n1, n2) {    persist = 4 | 8; // Sum2 needs at least one child and cuts off if                     // there is only one parent  }#ifndef BUILDING_SWIG_INTERFACE  Sum2(Net *ptr, NetLoader *loader);#endif  bool GetReal(DSSet &val, DFlags req);  void GradReal(DSSet &val, const Node *ptr);  void Save(NetSaver *saver);  string GetType() const { return "Sum2"; }private:  DSSet myval;};class SumN : public Function, public NParNode{public:  SumN(Net *net, Label label) :     Function(net, label)  {    persist = 4 | 8; // SumN needs at least one child and cuts off if                     // there is only one parent    keepupdated = false;  }#ifndef BUILDING_SWIG_INTERFACE  SumN(Net *net, NetLoader *loader);#endif  bool AddParent(Node *n);  bool GetReal(DSSet &val, DFlags req);  void GradReal(DSSet &val, const Node *ptr);  void Save(NetSaver *saver);  string GetType() const { return "SumN"; }  void Outdate(const Node *ptr);  void SetKeepUpdated(const bool _keepupdated);private:  DSSet myval;  vector<DSSet> parentval;  bool keepupdated;};class Relay : public Function, public UniParNode{public:  Relay(Net *ptr, Label label, Node *n) :    Function(ptr, label, n), UniParNode(n) {}#ifndef BUILDING_SWIG_INTERFACE  Relay(Net *ptr, NetLoader *loader);#endif  bool GetReal(DSSet &val, DFlags req) {return ParReal(0, val, req);}  void GradReal(DSSet &val, const Node *ptr) {ChildGradReal(val);}  void Save(NetSaver *saver);  string GetType() const { return "Relay"; }};class Variable : public Node{public:  virtual double Cost() = 0;  virtual void Update() {    if (!clamped) {       MyUpdate();       OutdateChild();    }  }  virtual void PartialUpdate(IntV *indices) {    if (!clamped) {      MyPartialUpdate(indices);      OutdateChild();    }  }  void Clamp(double val)  {    if (!MyClamp(val)) {      ostringstream msg;      msg << GetIdent() << ": Double clamp not allowed";      throw TypeException(msg.str());    }    clamped = true; costuptodate = false;    OutdateChild();  }  void Clamp(double mean, double var)  {    if (!MyClamp(mean, var)) {      ostringstream msg;      msg << GetIdent() << ": Double double clamp not allowed";      throw TypeException(msg.str());    }    clamped = true; costuptodate = false;    OutdateChild();  }  void Clamp(const DV &val)  {    if (!MyClamp(val)) {      ostringstream msg;      msg << GetIdent() << ": DV clamp not allowed";      throw TypeException(msg.str());    }    clamped = true; costuptodate = false;    OutdateChild();  }  void Clamp(const DV &mean, const DV &var)  {    if (!MyClamp(mean, var)) {      ostringstream msg;      msg << GetIdent() << ": Double DV clamp not allowed";      throw TypeException(msg.str());    }    clamped = true; costuptodate = false;    OutdateChild();  }  void Clamp(const DD &val) {    if (!MyClamp(val)) {      ostringstream msg;      msg << GetIdent() << ": DD clamp not allowed";      throw TypeException(msg.str());    }    clamped = true; costuptodate = false;    OutdateChild();  }  void Clamp(int val) {    if (!MyClamp(val)) {      ostringstream msg;      msg << GetIdent() << ": Int clamp not allowed";      throw TypeException(msg.str());    }    clamped = true; costuptodate = false;    OutdateChild();  }  void Clamp(const VDD &val) {    if (!MyClamp(val)) {      ostringstream msg;      msg << GetIdent() << ": VDD clamp not allowed";      throw TypeException(msg.str());    }    clamped = true; costuptodate = false;    OutdateChild();  }  void Unclamp() {if (clamped) {clamped = false; MyUpdate(); OutdateChild();}}  void SaveState();  void SaveStep();  void RepeatStep(double alpha);  void SaveRepeatedState(double alpha);  void ClearStateAndStep();  virtual void Outdate(const Node *ptr) {costuptodate = false;}  virtual void Save(NetSaver *saver);  int GetHookeFlags() { return hookeflags; }  void SetHookeFlags(int h) { hookeflags = h; }  bool IsClamped() { return clamped; }  // These two methods are ment for copying things from one network  // to another similar one  // The allocation of the DV instance is left to user  // so it can be done in the jurisdiction of Python's GC.  // The DV is resized, so initially it can be of size zero, for example.  virtual void GetState(DV *state, size_t t = 0) {    ostringstream msg;    msg << "GetState not supported by " << GetType();    throw TypeException(msg.str());  }  virtual void SetState(DV *state, size_t t = 0) {    ostringstream msg;    msg << "SetState not supported by " << GetType();    throw TypeException(msg.str());  }protected:  Variable(Net *ptr, Label label, Node *n1 = 0, Node *n2 = 0);#ifndef BUILDING_SWIG_INTERFACE  Variable(Net *ptr, NetLoader *loader);#endif  virtual bool MyClamp(double val) {return false;}  virtual bool MyClamp(double mean, double var) {return false;}  virtual bool MyClamp(const DV &val) {return false;}  virtual bool MyClamp(const DV &mean, const DV &var) {return false;}  virtual bool MyClamp(const DD &val) {return false;}  virtual bool MyClamp(int val) {return false;}  virtual bool MyClamp(const VDD &val) {return false;}  virtual void MyUpdate() = 0;  virtual bool MySaveState() {return false;}  virtual bool MySaveStep() {return false;}  virtual bool MySaveRepeatedState(double alpha) {return false;}  virtual void MyRepeatStep(double alpha) {}  virtual bool MyClearStateAndStep() {return false; }  virtual void MyPartialUpdate(IntV *indices) {    ostringstream msg;    msg << "Partial updates not supported by " << GetType();    throw StructureException(msg.str());  }      bool clamped, costuptodate;  int hookeflags;};class Gaussian : public Variable, public BiParNode{public:  Gaussian(Net *net, Label label, Node *m, Node *v);#ifndef BUILDING_SWIG_INTERFACE  Gaussian(Net *net, NetLoader *loader);#endif  ~Gaussian() {    if (sstate) delete sstate;    if (sstep) delete sstep;  }  double Cost();  bool GetReal(DSSet &val, DFlags req);  void GradReal(DSSet &val, const Node *ptr);  void Save(NetSaver *saver);  string GetType() const { return "Gaussian"; }  void GetState(DV *state, size_t t);  void SetState(DV *state, size_t t);protected:  virtual bool MyClamp(double m);  virtual bool MyClamp(double m, double v);  virtual void MyUpdate();  bool MySaveState();  bool MySaveStep();  bool MySaveRepeatedState(double alpha);  void MyRepeatStep(double alpha);  bool MyClearStateAndStep();  void MyPartialUpdate(IntV *indices);  DSSet myval;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -