⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 node.cc

📁 The library is a C++/Python implementation of the variational building block framework introduced in
💻 CC
📖 第 1 页 / 共 5 页
字号:
    saver->SetNamedDSSet("myval", myval);  }  Function::Save(saver);}// class SumN : public Function : public Nodebool SumN::GetReal(DSSet &val, DFlags req){  bool needm = req.mean && !uptodate.mean;  bool needv = req.var && !uptodate.var;  bool neede = false; //req.ex && !uptodate.ex;  if (needm || needv || neede) {    if (needm) myval.mean = 0;    if (needv) myval.var = 0;    //    if (neede) myval.ex = 1;    for (size_t i = 0; i<NumParents(); i++) {      DSSet p;      if (!ParReal(i, p, DFlags(needm, needv, neede)))	return false;      if (needm) {	myval.mean += p.mean;      }      if (needv) {	myval.var += p.var;      }      //      if (neede) {      //	myval.ex *= p.ex;      //      }      if (keepupdated) {	parentval[i] = p;      }    }    if (needm) uptodate.mean = true;    if (needv) uptodate.var = true;    //    if (neede) uptodate.ex = true;  }  if (req.mean) {val.mean = myval.mean; req.mean = false;}  if (req.var) {val.var = myval.var; req.var = false;}  //  if (req.ex) {val.ex = myval.ex; req.ex = false;}  return req.AllFalse();}void SumN::GradReal(DSSet &val, const Node *ptr){  //int ide = ParIdentity(ptr);  DSSet grad;  ChildGradReal(grad);  val.mean += grad.mean;  val.var += grad.var;  //  if (grad.ex) {  //    DSSet p, tempself;  //    if (!uptodate.ex) {  //      GetReal(tempself, DFlags(false,false,true));  //    }  //    ParReal(ide, p, DFlags(false, false, true));  //    val.ex += grad.ex * myval.ex / p.ex;  //  }}bool SumN::AddParent(Node *n){  // The parents are checked in GetReal  Node::AddParent(n, true);  int ide = ParIdentity(n);  if (keepupdated) {    parentval.resize(ide+1,DSSet(0.0,0.0,1.0));    parentval[ide] = DSSet(0.0,0.0,1.0); // is this done already above?  }  uptodate = DFlags(false,false,false);  OutdateChild();  return true;}void SumN::SetKeepUpdated(const bool _keepupdated){  keepupdated = _keepupdated;  if (keepupdated)  {    // update now    myval.mean = 0.0;    myval.var = 0.0;    //    myval.ex = 1.0;    for (size_t i = 0; i<NumParents(); i++) {      DSSet p;      ParReal(i, p, DFlags(true,true,false)); //.ex      parentval[i] = p;      myval.mean += p.mean;      myval.var += p.var;      //      myval.ex *= p.ex;    }    uptodate.mean = true;    uptodate.var = true;    //    uptodate.ex = true;  }}void SumN::Outdate(const Node *ptr) {  if (uptodate.mean || uptodate.var) { //.ex    if (keepupdated) {      int ide = ParIdentity(ptr);      DSSet p;      ParReal(ide, p, DFlags(true, true, false)); //.ex      myval.mean += p.mean - parentval[ide].mean;      myval.var += p.var - parentval[ide].var;      if (myval.var<0) {	uptodate.var = false;      }      //      myval.ex *= p.ex / parentval[ide].ex;      parentval[ide] = p;    } else {      uptodate = DFlags(false,false,false);     }  }  OutdateChild();}void SumN::Save(NetSaver *saver){  if (saver->GetSaveFunctionValue()) {    saver->SetNamedDSSet("myval", myval);  }  Function::Save(saver);}// class Relay : public Function : public Nodevoid Relay::Save(NetSaver *saver){  Function::Save(saver);}// abstract class Variable : public NodeVariable::Variable(Net *ptr, Label label, Node *n1, Node *n2) :   Node(ptr, label){  persist = 1; // Usually variables need all parents to survive  hookeflags = 0;  net->AddVariable(this, label);  clamped = false;  costuptodate = false;  if (n1)    AddParent(n1, false);  if (n2)    AddParent(n2, false);}void Variable::Save(NetSaver *saver){  saver->SetNamedBool("clamped", clamped);  saver->SetNamedBool("costuptodate", costuptodate);  saver->SetNamedInt("hookeflags", hookeflags);  Node::Save(saver);}void Variable::SaveState() {  if (!clamped && !MySaveState() && (net->GetDebugLevel() > 5))    cerr << "SaveState not supported" << endl;}void Variable::SaveStep() {  if (!clamped && !MySaveStep() && (net->GetDebugLevel() > 5))    cerr << "SaveStep not supported" << endl;}void Variable::RepeatStep(double alpha) {  if (!clamped) { MyRepeatStep(alpha);  OutdateChild(); }}void Variable::SaveRepeatedState(double alpha) {  if (!clamped && !MySaveRepeatedState(alpha) && (net->GetDebugLevel() > 5))    cerr << label << ": SaveRepeatedState not supported" << endl;}void Variable::ClearStateAndStep() {  if (!clamped && !MyClearStateAndStep() && (net->GetDebugLevel() > 5))    cerr << "ClearStateAndStep not supported" << endl;}// class Gaussian : public Variable : public NodeGaussian::Gaussian(Net *net, Label label, Node *m, Node *v) :   Variable(net, label, m, v), BiParNode(m, v){  sstate = 0; sstep = 0;  cost = 0;      CheckParent(0, REAL_MV);  CheckParent(1, REAL_ME);  DSSet p0, p1;  ParReal(0, p0, DFlags(true));  ParReal(1, p1, DFlags(false, false, true));//  myval.mean = p0.mean;//  myval.var = 1/p1.ex;  myval.mean = 0.0;  myval.var = 1.0;  exuptodate = false;  costuptodate = false;}void Gaussian::GetState(DV *state, size_t t = 0){  BBASSERT2(t == 0);  state->resize(2);  (*state)[0] = myval.mean;  (*state)[1] = myval.var;}void Gaussian::SetState(DV *state, size_t t = 0){  BBASSERT2(t == 0);  BBASSERT2(state->size() == 2);  myval.mean = (*state)[0];  myval.var = (*state)[1];  costuptodate = false;  exuptodate = false;    OutdateChild();}double Gaussian::Cost(){  if (!clamped && children.empty())    return 0;  if (!costuptodate) {    DSSet p0, p1;    ParReal(0, p0, DFlags(true, true));    ParReal(1, p1, DFlags(true, false, true));    if (clamped) {      //assert(myval.var == 0);      cost = ((Sqr(myval.mean - p0.mean) + p0.var + myval.var)	      * p1.ex - p1.mean) / 2 + _5LOG2PI;    }    else      cost = ((Sqr(myval.mean - p0.mean) + p0.var + myval.var) *	      p1.ex - p1.mean - log(myval.var) - 1) / 2;    costuptodate = true;  }  return cost;}bool Gaussian::GetReal(DSSet &val, DFlags req){  if (req.ex && !exuptodate) {    myval.ex = exp(myval.mean+myval.var/2);    exuptodate = true;  }  if (req.mean) {val.mean = myval.mean; req.mean = false;}  if (req.var) {val.var = myval.var; req.var = false;}  if (req.ex) {val.ex = myval.ex; req.ex = false;}  return req.AllFalse();}void Gaussian::GradReal(DSSet &val, const Node *ptr){  if (!clamped && children.empty())    return;  if (ParIdentity(ptr)) { // ParMean(1)    DSSet p0;    ParReal(0, p0, DFlags(true, true));    val.mean -= 0.5;    if (clamped) {      //assert(myval.var == 0);      val.ex += (Sqr(myval.mean - p0.mean) + p0.var + myval.var) / 2;    }    else // !clamped      val.ex += (Sqr(myval.mean - p0.mean) + p0.var + myval.var) / 2;  }  else {                  // ParMean(0)    DSSet p0, p1;    ParReal(0, p0, DFlags(true));    ParReal(1, p1, DFlags(false, false, true));    val.mean += (p0.mean - myval.mean) * p1.ex;    val.var += p1.ex / 2;  }}void Gaussian::Save(NetSaver *saver){  saver->SetNamedDSSet("myval", myval);  saver->SetNamedDouble("cost", cost);  saver->SetNamedBool("exuptodate", exuptodate);  if (sstate)    saver->SetNamedDSSet("sstate", *sstate);  if (sstep)    saver->SetNamedDSSet("sstep", *sstep);  Variable::Save(saver);}void VarNewton(double &mean, double &var, double gme, double gva,	       double gex, Label label){  if (!gex) {    var = 0.5 / gva;    mean -= var * gme;  }  else {    // Solve the minimum of gex * exp(mean + var/2) + gme * mean +    // gva * [(mean - current_mean)^2 + var] - 0.5 * log(var)    double oldm = mean, oldv = var, mstep, vstep, coef;    double newc, oldc = gex * exp(mean + var/2) + gme * mean +      gva * var - 0.5 * log(var);    int i = 0;    do {      mstep = -(gme + 2 * gva * (mean - oldm) + gex * exp(mean+var/2)) /	(2 * gva + gex * exp(mean+var/2));      mean += (mstep > MAXSTEP) ? MAXSTEP : mstep;      vstep = 1 / (2 * gva + gex * exp(mean + var/2)) - var;      coef = var * (0.5 - gva * var);      if (coef > 0)	vstep /= 1 + coef;      var += (vstep > MAXSTEP) ? MAXSTEP : vstep;      if (++i >= 100) {	cerr << label << " VarNewton: M=" << oldm << "; V=" << oldv << "; GEX="	     << gex << "; GVA=" << gva << "; GME=" << gme	     << ": mstep = " << mstep << ", vstep = " << vstep << '\n';	mstep = vstep = 0;      }    }    while (fabs(mstep) > MINSTEP || fabs(vstep) > MINSTEP);    newc = gex * exp(mean + var/2) + gme * mean +      gva * (Sqr(mean - oldm) + var) - 0.5 * log(var);    if (newc > oldc + EPSILON)      cerr << label << " VarNewton: M=" << oldm << "; V=" << oldv << "; GEX="	   << gex << "; GVA=" << gva << "; GME=" << gme	   << ": diff = " << newc - oldc << '\n';  }}void Gaussian::MyPartialUpdate(IntV *indices){  MyUpdate();}void Gaussian::MyUpdate(){  if (NumChildren() == 0) {    return;  }  DSSet grad, p0, p1;  ChildGradReal(grad);  ParReal(0, p0, DFlags(true));  ParReal(1, p1, DFlags(false, false, true));  VarNewton(myval.mean, myval.var, (myval.mean - p0.mean) * p1.ex + grad.mean,	    p1.ex/2 + grad.var, grad.ex, label);  exuptodate = false; costuptodate = false;}bool Gaussian::MyClamp(double m){  myval.mean = m; myval.var = 0;  exuptodate = false;  return true;}bool Gaussian::MyClamp(double m, double v){  myval.mean = m; myval.var = v;  exuptodate = false;  return true;}bool Gaussian::MySaveState(){  if (!sstate)    sstate = new DSSet;  sstate->mean = myval.mean;  sstate->var  = log(myval.var);  return true;}bool Gaussian::MySaveStep(){  if (!sstate) return false;  if (!sstep)    sstep = new DSSet;  switch (hookeflags) {  case 0:  case 1:    sstep->mean = myval.mean - sstate->mean;    sstep->var  = log(myval.var) - sstate->var;    break;  case 2:  case 3:    if (sstate->mean == 0)      sstep->mean = -1.0;    else      sstep->mean = myval.mean / sstate->mean;    sstep->var  = log(myval.var) - sstate->var;    break;  }  return true;}bool Gaussian::MySaveRepeatedState(double alpha){  if (!sstate || !sstep) {    cerr << label << ": No saved state when trying to repeat step!" << endl;    return false;  }  switch (hookeflags) {  case 0:  case 2:    sstate->mean = sstate->mean + alpha * sstep->mean;    sstate->var = sstate->var + alpha * sstep->var;    break;  case 1:  case 3:    if (sstep->mean > 0)      sstate->mean = sstate->mean * exp(alpha * log(sstep->mean));    break;  }  return true;}void Gaussian::MyRepeatStep(double alpha){  if (!sstate || !sstep) {    cerr << label << ": No saved state when trying to repeat step!" << endl;    return;  }  switch (hookeflags) {  case 0:    myval.mean = sstate->mean + alpha * sstep->mean;    myval.var = exp(sstate->var + alpha * sstep->var);    break;  case 1:    myval.mean = sstate->mean + alpha * sstep->mean;    break;  case 2:    if (sstep->mean > 0)      myval.mean = sstate->mean * exp(alpha * log(sstep->mean));    myval.var = exp(sstate->var + alpha * sstep->var);    break;  case 3:    if (sstep->mean > 0)      myval.mean = sstate->mean * exp(alpha * log(sstep->mean));    break;  }  exuptodate = false; costuptodate = false;}bool Gaussian::MyClearStateAndStep(){  if (sstate) {    delete sstate;    sstate = 0;  }  if (sstep) {    delete sstep;    sstep = 0;  }  return true;}/* class RectifiedGaussian */RectifiedGaussian::RectifiedGaussian(Net *net, Label label, Node *m, Node *v) :  Variable(net, label, m, v), BiParNode(m, v){  cost = 0;  CheckParent(0, REAL_MV);  CheckParent(0, REAL_ME);  myval.mean = 0;  myval.var = 1;  UpdateExpectations();  MyUpdate();}void RectifiedGaussian::GetState(DV *state, size_t t = 0){  BBASSERT2(t == 0);    state->resize(2);  (*state)[0] = myval.mean;  (*state)[1] = myval.var;}void RectifiedGaussian::SetState(DV *state, size_t t = 0){  BBASSERT2(t == 0 && state->size() == 2);  myval.mean = (*state)[0];  myval.var = (*state)[1];  UpdateExpectations();  costuptodate = false;  OutdateChild();}void RectifiedGaussian::Save(NetSaver *saver){  saver->SetNamedDSSet("myval", myval);  saver->SetNamedDSSet("expectations", expectations);  saver->SetNamedDouble("cost", cost);  Variable::Save(saver);}string RectifiedGaussian::GetType() const {   return "RectifiedGaussian"; }double RectifiedGaussian::Cost(){  if (children.empty()) {    return 0;  }  if (!costuptodate) {    DSSet mpar, vpar;    ParReal(0, mpar, DFlags(true, true, false));    ParReal(1, vpar, DFlags(true, false, true));    /* C_p */    cost = 0.5 * (vpar.ex * (Sqr(expectations.mean - mpar.mean)			     + expectations.var + mpar.var)		  - vpar.mean + log(PI/2)); // + log(2) if m != const0#if RECTIFIED_BETTER_APPROX    cost += log(Erfc(-1/sqrt(2.0) * mpar.mean * 		     exp(vpar.mean / 2 + vpar.var / 8)));#endif    /* C_q */    cost += -1/(2*myval.var) * (expectations.var 				+ Sqr(expectations.mean - myval.mean))      + 0.5 * log(2/(PI*myval.var))       - log(Erfc(-myval.mean / sqrt(2*myval.var)));    costuptodate = true;  }  return cost;}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -