⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 node.h

📁 The library is a C++/Python implementation of the variational building block framework introduced in
💻 H
📖 第 1 页 / 共 3 页
字号:
  DSSet *sstate, *sstep;  double cost;  bool exuptodate;};class RectifiedGaussian : public Variable, public BiParNode{public:  RectifiedGaussian(Net *net, Label label, Node *m, Node *v);#ifndef BUILDING_SWIG_INTERFACE  RectifiedGaussian(Net *net, NetLoader *loader);#endif  double Cost();  bool GetReal(DSSet &val, DFlags req);  /* Returns the actual posterior parameters. */  bool GetMyval(DSSet &val);  void GradReal(DSSet &val, const Node *ptr);  string GetType() const;  void Save(NetSaver *saver);  void GetState(DV *state, size_t t);  void SetState(DV *state, size_t t);protected:  virtual void MyUpdate();  void MyPartialUpdate(IntV *indices);  void UpdateExpectations();  /* Parameters of the rectified Gaussian posterior approximation.      For debug purposes. */  DSSet myval;  /* Expectations (stored to gain speed).     Note that the posterior mean- or variance parameter is not     the same as the mean or variance because the posterior is     approximated with a rectified Gaussian. */  DSSet expectations;  double cost;};class RectifiedGaussianV : public Variable, public BiParNode{public:  RectifiedGaussianV(Net *net, Label label, Node *m, Node *v);#ifndef BUILDING_SWIG_INTERFACE  RectifiedGaussianV(Net *net, NetLoader *loader);#endif  double Cost();  bool GetRealV(DVH &val, DFlags req);  bool GetMyvalV(DVH &val);  void GradReal(DSSet &val, const Node *ptr);  void GradRealV(DVSet &val, const Node *ptr);  string GetType() const;  void Save(NetSaver *saver);protected:  void MyUpdate();  void UpdateExpectations();  DVSet myval;  DVSet expectations;  double cost;};class GaussRect : public Variable, public BiParNode{public:  GaussRect(Net *net, Label label, Node *m, Node *v);#ifndef BUILDING_SWIG_INTERFACE  GaussRect(Net *net, NetLoader *loader);#endif  double Cost();  bool GetReal(DSSet &val, DFlags req);  bool GetRectReal(DSSet &val, DFlags req);  void GradReal(DSSet &val, const Node *ptr);  string GetType() const;  void Save(NetSaver *saver);protected:  void MyUpdate();  void UpdateMoments();  void UpdateExpectations();  void ChildGradients(DSSet &norm, DSSet &rect);  DSSet posval;  DSSet negval;  double posweight;  double negweight;  vector<double> posmoments;  vector<double> negmoments;  DSSet expts;  DSSet rectexpts;  double cost;};class GaussRectV : public Variable, public BiParNode{public:  friend class GaussRectVState;  GaussRectV(Net *net, Label label, Node *m, Node *v);#ifndef BUILDING_SWIG_INTERFACE  GaussRectV(Net *net, NetLoader *loader);#endif  double Cost();  bool GetRealV(DVH &val, DFlags req);  bool GetRectRealV(DVH &val, DFlags req);  void GradReal(DSSet &val, const Node *ptr);  void GradRealV(DVSet &val, const Node *ptr);  string GetType() const;  void Save(NetSaver *saver);  void GetState(DV *state, size_t t);  void SetState(DV *state, size_t t);protected:  void MyUpdate();  void MyPartialUpdate(IntV *indices);  void UpdateMoments();  void UpdateExpectations();  void ChildGradients(DVSet &norm, DVSet &rect);  DVSet posval;  DVSet negval;  DV posweights;  DV negweights;  vector<DV> posmoments;  vector<DV> negmoments;  DVSet expts;  DVSet rectexpts;  double cost;};// Making the internals of GaussRectV public does not// seem temptating but writing unittests without them// is impossible. Hence, GaussRectVState (a friend of GaussRectV)// provides access to the internals of GaussRectV without// cluttering the interface of GaussRectV.class GaussRectVState{public:  GaussRectVState(GaussRectV *n);  DVSet &GetPosVal();  DVSet &GetNegVal();  DV &GetPosWeights();  DV &GetNegWeights();  DV &GetPosMoment(int i);  DV &GetNegMoment(int i);private:  GaussRectV *node;};class MoG : public Variable, public NParNode{public:  MoG(Net *net, Label label, Node *d);#ifndef BUILDING_SWIG_INTERFACE  MoG(Net *net, NetLoader *loader);#endif    double Cost();  bool GetReal(DSSet &val, DFlags req);  void GradReal(DSSet &val, const Node *ptr);  void GradDiscrete(DD &val, const Node *ptr);  string GetType() const;  void Save(NetSaver *saver);  void AddComponent(Node *m, Node *v);  size_t NumComponents();protected:  void MyUpdate();  vector<DSSet*> myval;  vector<Node*> means;  vector<Node*> vars;private:  bool IsMeanParent(const Node *ptr);  bool IsVarParent(const Node *ptr);  int WhichMeanParent(const Node *ptr);  int WhichVarParent(const Node *ptr);  int WhichParent(const Node *ptr, const vector<Node*> &parents);  void ComputeExpectations();  DSSet expts;  size_t numComponents;  double cost;};class MoGV : public Variable, public NParNode{public:  MoGV(Net *net, Label label, Node *d);#ifndef BUILDING_SWIG_INTERFACE  MoGV(Net *net, NetLoader *loader);#endif  double Cost();  bool GetRealV(DVH &val, DFlags req);  void GetMyvalV(DVH &val, int k);  void GradReal(DSSet &val, const Node *ptr);  void GradRealV(DVSet &val, const Node *ptr);  void GradDiscreteV(VDD &val, const Node *ptr);  string GetType() const;  void Save(NetSaver *saver);    /* Parents MUST be added with this method. */  void AddComponent(Node *m, Node *v);  size_t NumComponents();protected:  void MyUpdate();  bool MyClamp(const DV &m);  /* Posterior parameters (weights are got from Categorical) */  vector<DVSet*> myval;  vector<Node*> means;  vector<Node*> vars;private:  bool IsMeanParent(const Node *ptr);  bool IsVarParent(const Node *ptr);  int WhichMeanParent(const Node *ptr);  int WhichVarParent(const Node *ptr);  int WhichParent(const Node *ptr, const vector<Node*> &parents);  /* Updates expts. */  void ComputeExpectations();  /* Expectations calculated from the posterior. */  DVSet expts;  /* Number of mixture components. */  size_t numComponents;  double cost;};class Dirichlet : public Variable, public NParNode{public:  Dirichlet(Net *net, Label label, ConstantV *n);#ifndef BUILDING_SWIG_INTERFACE  Dirichlet(Net *net, NetLoader *loader);#endif    double Cost();  /* Returns expectations of different components.    <log c_i> is in ex field, naturally. */  bool GetRealV(DVH &val, DFlags req);    string GetType() const;  void Save(NetSaver *saver);protected:  void MyUpdate();private:  /* Updates expts. */  void ComputeExpectations();  /* Posterior parameters. */  DV myval;  /* Expectations calculated from the posterior. */  DVSet expts;  /* Number of components. */  size_t numComponents;  double cost;};class DiscreteDirichlet : public Variable, public NParNode{public:  DiscreteDirichlet(Net *net, Label label, Dirichlet *n);#ifndef BUILDING_SWIG_INTERFACE  DiscreteDirichlet(Net *net, NetLoader *loader);#endif  double Cost();  bool GetDiscrete(DD *&val);  void GradRealV(DVSet &val, const Node *ptr);  string GetType() const;  void Save(NetSaver *saver);protected:  void MyUpdate();  bool MyClamp(const DD &v);  DD myval;  double cost;};/* A discrete variable with dirichlet prior for its prior weights */class DiscreteDirichletV : public Variable, public NParNode{public:  DiscreteDirichletV(Net *net, Label label, Dirichlet *n);#ifndef BUILDING_SWIG_INTERFACE  DiscreteDirichletV(Net *net, NetLoader *loader);#endif  double Cost();  bool GetDiscreteV(VDDH &val);  void GradRealV(DVSet &val, const Node *ptr);  string GetType() const;  void Save(NetSaver *saver);protected:  void MyUpdate();  bool MyClamp(const VDD &v);  VDD myval;  double cost;};class Rectification : public Function, public UniParNode{public:  Rectification(Net *net, Label label, Node *n);#ifndef BUILDING_SWIG_INTERFACE  Rectification(Net *net, NetLoader *loader);#endif  bool GetReal(DSSet &val, DFlags req);  void GradReal(DSSet &val, const Node *ptr);  void Save(NetSaver *saver);  string GetType() const;};class RectificationV : public Function, public UniParNode{public:  RectificationV(Net *net, Label label, Node *n);#ifndef BUILDING_SWIG_INTERFACE  RectificationV(Net *net, NetLoader *loader);#endif  bool GetRealV(DVH &val, DFlags req);  void GradRealV(DVSet &val, const Node *ptr);  void Save(NetSaver *saver);  string GetType() const;};class ProdV : public Function, public BiParNode{public:  ProdV(Net *ptr, Label label, Node *n1, Node *n2) :     Function(ptr, label, n1, n2), BiParNode(n1, n2) {}#ifndef BUILDING_SWIG_INTERFACE  ProdV(Net *ptr, NetLoader *loader);#endif  void GradReal(DSSet &val, const Node *ptr);  bool GetRealV(DVH &val, DFlags req);  void GradRealV(DVSet &val, const Node *ptr);  void Save(NetSaver *saver);  string GetType() const { return "ProdV"; }private:  DVSet myval;};class Sum2V : public Function, public BiParNode{public:  Sum2V(Net *ptr, Label label, Node *n1, Node *n2) :    Function(ptr, label, n1, n2), BiParNode(n1, n2) {    persist = 4 | 8; // Sum2V needs at least one child and cuts off if                     // there is only one parent  }#ifndef BUILDING_SWIG_INTERFACE  Sum2V(Net *ptr, NetLoader *loader);#endif  void GradReal(DSSet &val, const Node *ptr);  bool GetRealV(DVH &val, DFlags req);  void GradRealV(DVSet &val, const Node *ptr);  void Save(NetSaver *saver);  string GetType() const { return "Sum2V"; }private:  DVSet myval;};class SumNV : public Function, public NParNode{public:  SumNV(Net *net, Label label) :     Function(net, label)  {    persist = 4 | 8; // SumN needs at least one child and cuts off if                     // there is only one parent    keepupdated = false;  }#ifndef BUILDING_SWIG_INTERFACE  SumNV(Net *net, NetLoader *loader);#endif  bool AddParent(Node *n);  bool GetRealV(DVH &val, DFlags req);  void GradReal(DSSet &val, const Node *ptr);  void GradRealV(DVSet &val, const Node *ptr);  void Save(NetSaver *saver);  string GetType() const { return "SumNV"; }  void Outdate(const Node *ptr);  void SetKeepUpdated(const bool _keepupdated);private:  void UpdateFromScratch(DFlags req);  DVSet myval;  vector<DVSet> parentval;  bool keepupdated;};class DelayV : public Function, public BiParNode{public:  DelayV(Net *ptr, Label label, Node *n1, Node *n2) :     Function(ptr, label, n1, n2), BiParNode(n1, n2)   {    lendelay = 1;  }#ifndef BUILDING_SWIG_INTERFACE  DelayV(Net *ptr, NetLoader *loader);#endif  void GradReal(DSSet &val, const Node *ptr);  bool GetRealV(DVH &val, DFlags req);  void GradRealV(DVSet &val, const Node *ptr);  void Save(NetSaver *saver);  string GetType() const { return "DelayV"; }  int GetDelayLength();  void SetDelayLength(int len);private:  DVSet myval;  int lendelay;};class GaussianV : public Variable, public BiParNode{public:  GaussianV(Net *_net, Label label, Node *m, Node *v);#ifndef BUILDING_SWIG_INTERFACE  GaussianV(Net *_net, NetLoader *loader);#endif  ~GaussianV() {    if (sstate) delete sstate;    if (sstep) delete sstep;  }  double Cost();  void GradReal(DSSet &val, const Node *ptr);  bool GetRealV(DVH &val, DFlags req);  void GradRealV(DVSet &val, const Node *ptr);  void Save(NetSaver *saver);  string GetType() const { return "GaussianV"; }  void GetState(DV *state, size_t t);  void SetState(DV *state, size_t t);protected:  bool MyClamp(double m);  bool MyClamp(double m, double v);  bool MyClamp(const DV &m);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -