⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 node.h

📁 The library is a C++/Python implementation of the variational building block framework introduced in
💻 H
📖 第 1 页 / 共 3 页
字号:
  bool MyClamp(const DV &m, const DV &v);  void MyUpdate();  void MyPartialUpdate(IntV *indices);  bool MySaveState();  bool MySaveStep();  bool MySaveRepeatedState(double alpha);  void MyRepeatStep(double alpha);  bool MyClearStateAndStep();  DVSet myval;  double cost;  bool exuptodate;private:  DVSet *sstate, *sstep;};class SparseGaussV : public GaussianV{public:  SparseGaussV(Net *_net, Label label, Node *m, Node *v) :     GaussianV(_net, label, m, v) {}#ifndef BUILDING_SWIG_INTERFACE  SparseGaussV(Net *_net, NetLoader *loader);#endif  double Cost();  void Update();  void GradReal(DSSet &val, const Node *ptr);  void GradRealV(DVSet &val, const Node *ptr);  void Save(NetSaver *saver);  string GetType() const { return "SparseGaussV"; }  void SparseClampDV(const DV &mean, const IntV &mis);  IntV& GetMissing() { return missing; }  void SetMissing(IntV &mis);private:  IntV missing;};class DelayGaussV : public Variable, public NParNode{public:  DelayGaussV(Net *_net, Label label, Node *m, Node *v, Node *a,	      Node *m0, Node *v0);#ifndef BUILDING_SWIG_INTERFACE  DelayGaussV(Net *_net, NetLoader *loader);#endif  double Cost();  void GradReal(DSSet &val, const Node *ptr);  bool GetRealV(DVH &val, DFlags req);  void GradRealV(DVSet &val, const Node *ptr);  void Save(NetSaver *saver);  string GetType() const { return "DelayGaussV"; }protected:  bool MyClamp(double m)  {    fill(myval.mean.begin(), myval.mean.end(), m);    fill(myval.var.begin(), myval.var.end(), 0);    exuptodate = false;    return true;  }  bool MyClamp(const DV &m)  {    if (m.size() == myval.mean.size())      copy(m.begin(), m.end(), myval.mean.begin());    else {      ostringstream msg;      msg << "DelayGaussV::MyClamp: wrong vector size " << m.size() << " != "	  << myval.mean.size();      throw TypeException(msg.str());    }    fill(myval.var.begin(), myval.var.end(), 0);    return true;  }  bool MyClamp(const DV &m, const DV &v)  {    if (m.size() == myval.mean.size() && v.size() == myval.var.size()) {      copy(m.begin(), m.end(), myval.mean.begin());      copy(v.begin(), v.end(), myval.var.begin());    } else {      ostringstream msg;      msg << "DelayGaussV::MyClamp: wrong vector size " << m.size() << " != "	  << myval.mean.size();      throw TypeException(msg.str());    }    return true;  }  void MyUpdate();  bool MySaveState();  bool MySaveStep();  bool MySaveRepeatedState(double alpha);  void MyRepeatStep(double alpha);  bool MyClearStateAndStep();private:  DVSet myval;  DVSet *sstate, *sstep;  double cost;  bool exuptodate;};class GaussNonlin : public Variable, public BiParNode// nonlinearity: myval2 = exp(-myval1*myval1){public:  GaussNonlin(Net *_net, Label label, Node *m, Node *v) :     Variable(_net, label, m, v), BiParNode(m, v)  {    sstate = 0; sstep = 0;    cost = 0;    CheckParent(0, REAL_MV);    CheckParent(1, REAL_ME);    MyUpdate();  }#ifndef BUILDING_SWIG_INTERFACE  GaussNonlin(Net *_net, NetLoader *loader);#endif  ~GaussNonlin() {    if (sstate) delete sstate;    if (sstep) delete sstep;  }  double Cost();  bool GetReal(DSSet &val, DFlags req);  void GradReal(DSSet &val, const Node *ptr);  void Save(NetSaver *saver);  string GetType() const { return "GaussNonlin"; }protected:  bool MyClamp(double m) {    myval1.mean = m;    myval1.var = 0;    meanuptodate = false;    varuptodate = false;    return true;  }  void MyUpdate();  void UpdateMean();  void UpdateVar();  bool MySaveState();  bool MySaveStep();  bool MySaveRepeatedState(double alpha);  void MyRepeatStep(double alpha);  bool MyClearStateAndStep();private:  DSSet myval1, myval2;  // 1 before nonlinearity, 2 after  DSSet *sstate, *sstep;  double cost;  bool meanuptodate, varuptodate;  // refer to myval2};class GaussNonlinV : public Variable, public BiParNode// nonlinearity: myval2 = exp(-myval1*myval1){public:  GaussNonlinV(Net *_net, Label label, Node *m, Node *v);#ifndef BUILDING_SWIG_INTERFACE  GaussNonlinV(Net *_net, NetLoader *loader);#endif  ~GaussNonlinV() {    if (sstate) delete sstate;    if (sstep) delete sstep;  }  double Cost();  void GradReal(DSSet &val, const Node *ptr);  bool GetRealV(DVH &val, DFlags req);  void GradRealV(DVSet &val, const Node *ptr);  void Save(NetSaver *saver);  string GetType() const { return "GaussNonlinV"; }protected:  bool MyClamp(double m)  {    fill(myval1.mean.begin(), myval1.mean.end(), m);    fill(myval1.var.begin(), myval1.var.end(), 0);    meanuptodate = false;    varuptodate = false;    return true;  }  bool MyClamp(const DV &m)  {    if (m.size() == myval1.mean.size())      copy(m.begin(), m.end(), myval1.mean.begin());    else {      ostringstream msg;      msg << "GaussianV::MyClamp: wrong vector size " << m.size() << " != "	  << myval1.mean.size();      throw TypeException(msg.str());    }    fill(myval1.var.begin(), myval1.var.end(), 0);    meanuptodate = false;    varuptodate = false;    return true;  }  void MyUpdate();  void UpdateMean();  void UpdateVar();  bool MySaveState();  bool MySaveStep();  bool MySaveRepeatedState(double alpha);  void MyRepeatStep(double alpha);  bool MyClearStateAndStep();private:  DVSet myval1, myval2;  // 1 before nonlinearity, 2 after  DVSet *sstate, *sstep;  double cost;  bool meanuptodate, varuptodate;  // refer to myval2};class Discrete : public Variable, public NParNode{public:  Discrete(Net *_net, Label label, Node *n=0) :     Variable(_net, label, n), NParNode(n)  {    cost = 0; exsum = 0;    if (n) {      CheckParent(0, REAL_ME);      exuptodate = false;      MyUpdate();    }  }#ifndef BUILDING_SWIG_INTERFACE  Discrete(Net *_net, NetLoader *loader);#endif  bool AddParent(Node *n) {    Node::AddParent(n);    CheckParent(NumParents()-1, REAL_ME);    MyUpdate();    return true;  }  double Cost();  void GradReal(DSSet &val, const Node *ptr);#ifdef BUILDING_SWIG_INTERFACE  BOOLASOBJ GetDiscrete(DD *&val);#else  bool GetDiscrete(DD *&val);#endif  void Save(NetSaver *saver);  string GetType() const { return "Discrete"; }protected:  bool MyClamp(double m) { return false; }  bool MyClamp(const DD &m) { myval = m; return true; }  bool MyClamp(int n) {     if (n >= (int)NumParents()) {      throw TypeException("Too large value for clamping a Discrete");    }    myval.Resize(NumParents());    for (size_t j=NumParents(); j>0; j--) {      myval[j-1] = 0;    }    myval[n] = 1;    return true;  }  void MyUpdate();  void UpdateExpSum();private:  DD myval;  double cost, exsum;  bool exuptodate;};class DiscreteV : public Variable, public NParNode{public:  DiscreteV(Net *_net, Label label, Node *n=0);#ifndef BUILDING_SWIG_INTERFACE  DiscreteV(Net *_net, NetLoader *loader);#endif  bool AddParent(Node *n) {    DVH tmp;    if (! n->GetRealV(tmp, DFlags(true, false, true))) {      ostringstream msg;      msg << "Wrong type of parents in " << GetType() << " Node "          << label << std::endl;      msg << " Parent " << n->GetLabel() << ":" << n->GetType();      throw StructureException(msg.str());    }    Node::AddParent(n);    MyUpdate();    return true;  }  double Cost();  void GradReal(DSSet &val, const Node *ptr);  void GradRealV(DVSet &val, const Node *ptr);  bool GetDiscreteV(VDDH &val);  void Save(NetSaver *saver);  string GetType() const { return "DiscreteV"; }protected:  bool MyClamp(double m) { return false; }  bool MyClamp(const VDD &m) { myval = m; return true; }  void MyUpdate();  void UpdateExpSum();private:  VDD myval;  double cost;  DV exsum;  bool exuptodate;};class Memory : public Variable, public UniParNode{public:  Memory(Net *_net, Label label, Node *n) :     Variable(_net, label, n), UniParNode(n)  {    if (n->TimeType()) {      ostringstream msg;      msg << GetIdent() << ": parent must be independent of time";      throw StructureException(msg.str());    }    timetype = 2;    oldcost = 0; cost = 0;  }#ifndef BUILDING_SWIG_INTERFACE  Memory(Net * net, NetLoader *loader);#endif  void NotifyTimeType(int tt, int verbose=0)  {    if (GetParent(0)->TimeType()) {      ostringstream msg;      msg << GetIdent() << ": parent must be independent of time";      throw StructureException(msg.str());    }  }  double Cost();  void MyUpdate();  bool GetReal(DSSet &val, DFlags req) {return ParReal(0, val, req);}  void GradReal(DSSet &grad, const Node *ptr);  void Save(NetSaver *saver);  string GetType() const { return "Memory"; }  void Outdate(const Node *ptr) { costuptodate = false; OutdateChild(); }  DSSet oldval;  double oldcost;  double cost;};class OnLineDelay : public Node{public:  virtual void Save(NetSaver *saver);  virtual void StepTime() = 0;  virtual void ResetTime() = 0;protected:  OnLineDelay(Net *ptr, Label label, Node *n1 = 0, Node *n2 = 0);#ifndef BUILDING_SWIG_INTERFACE  OnLineDelay(Net *ptr, NetLoader *loader);#endif};class OLDelayS : public OnLineDelay, public BiParNode{public:  OLDelayS(Net *ptr, Label label, Node *n1 = 0, Node *n2 = 0) :    OnLineDelay(ptr, label, n1, n2), BiParNode(n1, n2)  {    CheckParent(0, REAL_M);    CheckParent(1, REAL_M);    DSSet tmp;    ParReal(0, tmp, DFlags(true));    oldmean = tmp.mean;    exuptodate = false;  }#ifndef BUILDING_SWIG_INTERFACE  OLDelayS(Net *ptr, NetLoader *loader);#endif  virtual void Save(NetSaver *saver);  virtual void StepTime();  virtual void ResetTime();  virtual bool GetReal(DSSet &val, DFlags req);  string GetType() const { return "OLDelayS"; }private:  double oldmean;  double oldexp;  bool exuptodate;};class OLDelayD : public OnLineDelay, public BiParNode{public:  OLDelayD(Net *ptr, Label label, Node *n1 = 0, Node *n2 = 0) :    OnLineDelay(ptr, label, n1, n2), BiParNode(n1, n2)  {    CheckParent(0, DISCRETE);    CheckParent(1, DISCRETE);    DD *tmp;    ParDiscrete(0, tmp);    oldval = *tmp;  }#ifndef BUILDING_SWIG_INTERFACE  OLDelayD(Net *ptr, NetLoader *loader);#endif  virtual void Save(NetSaver *saver);  virtual void StepTime();  virtual void ResetTime();#ifdef BUILDING_SWIG_INTERFACE  virtual BOOLASOBJ GetDiscrete(DD *&val);#else  virtual bool GetDiscrete(DD *&val);#endif  string GetType() const { return "OLDelayD"; }private:  DD oldval;};class Proxy : public Node, public UniParNode{public:  Proxy(Net *ptr, Label label, Label rlabel);#ifndef BUILDING_SWIG_INTERFACE  Proxy(Net *ptr, NetLoader *loader);#endif  void Save(NetSaver *saver);  string GetType() const { return "Proxy"; }  bool GetReal(DSSet &val, DFlags req);  bool GetRealV(DVH &val, DFlags req);#ifdef BUILDING_SWIG_INTERFACE  BOOLASOBJ GetDiscrete(DD *&val);#else  bool GetDiscrete(DD *&val);#endif  bool GetDiscreteV(VDDH &val);  void GradReal(DSSet &val, const Node *ptr) { ChildGradReal(val); }  void GradRealV(DVSet &val, const Node *ptr) { ChildGradRealV(val); }  void GradDiscrete(DD &val, const Node *ptr) { ChildGradDiscrete(val); }  void GradDiscreteV(VDD &val, const Node *ptr) { ChildGradDiscreteV(val); }  bool CheckRef();private:  string reflabel;  bool req_discrete, req_discretev;  DFlags real_flags, realv_flags;};class Evidence : public Variable, public Decayer, public UniParNode{public:  Evidence(Net *ptr, Label label, Node *p) :    Variable(ptr, label, p), Decayer(ptr), UniParNode(p)  {    alpha = 1e-10;    decay = 0;    myval = 0;    cost = 0;  }#ifndef BUILDING_SWIG_INTERFACE  Evidence(Net *ptr, NetLoader *loader);#endif  void Save(NetSaver *saver);  string GetType() const { return "Evidence"; }  void GradReal(DSSet &val, const Node *ptr);  double Cost();  void SetDecayTime(double iters) { decay = alpha / iters; }  virtual bool DoDecay(string hook);private:  void MyUpdate() {}  bool MyClamp(double mean, double var);  double cost;  double myval;  double alpha;  double decay;};class EvidenceV : public Variable, public Decayer, public UniParNode{public:  EvidenceV(Net *ptr, Label label, Node *p);#ifndef BUILDING_SWIG_INTERFACE  EvidenceV(Net *ptr, NetLoader *loader);#endif  void Save(NetSaver *saver);  string GetType() const { return "EvidenceV"; }  void GradRealV(DVSet &val, const Node *ptr);  double Cost();  void SetDecayTime(const DV &iters);  virtual bool DoDecay(string hook);private:  void MyUpdate() {}  bool MyClamp(double mean, double var);  bool MyClamp(const DV &mean, const DV &var);  double cost;  DV myval;  DV alpha;  DV decay;};#endif // NODE_H

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -