⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 node.cc

📁 The library is a C++/Python implementation of the variational building block framework introduced in
💻 CC
📖 第 1 页 / 共 5 页
字号:
bool RectifiedGaussian::GetMyval(DSSet &val){  val.mean = myval.mean;  val.var = myval.var;  return true;}bool RectifiedGaussian::GetReal(DSSet &val, DFlags req){  if (req.mean) {    val.mean = expectations.mean;  }  if (req.var) {     val.var = expectations.var;  }  /* Clients hoping for ex are out of luck. */  if (req.ex) {    return false;  } else {    return true;  }}void RectifiedGaussian::GradReal(DSSet &val, const Node *ptr){  if (!clamped && children.empty())    return;  if (ParIdentity(ptr) == 1) { // Variance parent    DSSet p0;    ParReal(0, p0, DFlags(true, true));    val.mean -= 0.5;    val.ex += 0.5 * (Sqr(expectations.mean - p0.mean) 		     + expectations.var + p0.var);  } else if (ParIdentity(ptr) == 0) { // Mean parent    DSSet p0, p1;    ParReal(0, p0, DFlags(true));    ParReal(1, p1, DFlags(false, false, true));    val.mean += p1.ex * (expectations.mean - p0.mean);    val.var += 0.5 * p1.ex;#if RECTIFIED_BETTER_APPROX    val.mean += sqrt(2 / PI)      * exp(-0.5 * exp(p1.mean + p1.var / 4) * Sqr(p0.mean) 	    + p1.mean / 2 + p1.var / 8)      / Erfc(-1 / sqrt(2.0) * exp(p1.mean / 2 + p1.var / 8) * p0.mean);#endif  } else {     BBASSERT2(false);  }}void RectifiedGaussian::MyPartialUpdate(IntV *indices){  MyUpdate();}void RectifiedGaussian::MyUpdate(){  DSSet grad, mpar, vpar;;  /* double linvvar, lmean; */  /* From the gradient we can deduce the likelihood and since the     likelihood is known to be Gaussian and the prior is Rectified     Gaussian the Rectified Gaussian posterior approximation matches     exactly to the correct posterior and the parameters can be set     according to the correct posterior. */  ChildGradReal(grad);  ParReal(0, mpar, DFlags(true));  ParReal(1, vpar, DFlags(false, false, true));  /* This is what happens below.  linvvar = 2 * grad.var;  lmean = expectations.mean - grad.mean / linvvar;  myval.var = 1 / (linvvar + vpar.ex);   myval.mean = myval.var * (linvvar * lmean + vpar.ex * mpar.mean);   */  myval.var = 1 / (2*grad.var + vpar.ex);  myval.mean = myval.var * (2*grad.var * expectations.mean - grad.mean			    + vpar.ex * mpar.mean);  /* myvals have changed, expectations needs to be updated */  UpdateExpectations();  costuptodate = false;}void RectifiedGaussian::UpdateExpectations(){  if ((myval.mean / sqrt(myval.var)) > RECTLIMIT) {    double scale;    scale = sqrt(2*myval.var/PI) / Erfcx(-myval.mean / sqrt(2*myval.var));    expectations.mean = myval.mean + scale;    expectations.var = Sqr(myval.mean) + myval.var + scale * myval.mean      - Sqr(expectations.mean);  } else { /* use exponential approximation */    expectations.mean = -myval.var / myval.mean;    expectations.var = Sqr(expectations.mean);  }}/* class RectifiedGaussianV */RectifiedGaussianV::RectifiedGaussianV(Net *net, Label label, 				       Node *m, Node *v) :  Variable(net, label, m, v), BiParNode(m, v){  cost = 0;  myval.mean.resize(net->Time());   myval.var.resize(net->Time());  expectations.mean.resize(net->Time());  expectations.var.resize(net->Time());  CheckParent(0, REALV_MV);  CheckParent(1, REALV_ME);  for (size_t i = 0; i < net->Time(); i++) {    myval.mean[i] = 0.0;    myval.var[i] = 1.0;  }  UpdateExpectations();  MyUpdate();}void RectifiedGaussianV::Save(NetSaver *saver){  saver->SetNamedDVSet("myval", myval);  saver->SetNamedDVSet("expectations", expectations);  saver->SetNamedDouble("cost", cost);  Variable::Save(saver);}string RectifiedGaussianV::GetType() const {   return "RectifiedGaussianV"; }double RectifiedGaussianV::Cost(){  if (children.empty()) {    return 0;  }  if (!costuptodate) {    DVH mpar, vpar;    ParRealV(0, mpar, DFlags(true, true));    ParRealV(1, vpar, DFlags(true, false, true));    double c = 0;    for (size_t i = 0; i < net->Time(); i++) {      /* C_p */      c += 0.5 * (vpar.Exp(i) * (Sqr(expectations.mean[i] - mpar.Mean(i))				 + expectations.var[i] + mpar.Var(i))		  - vpar.Mean(i) + log(PI/2)); // + log(2)#if RECTIFIED_BETTER_APPROX      c += log(Erfc(-1/sqrt(2.0) * mpar.Mean(i) * 		    exp(vpar.Mean(i) / 2 + vpar.Var(i) / 8)));#endif      /* C_q */      c += -1/(2*myval.var[i]) *	(expectations.var[i] + Sqr(expectations.mean[i] - myval.mean[i])) +	0.5 * log(2/(PI*myval.var[i])) - 	log(Erfc(-myval.mean[i] / sqrt(2*myval.var[i])));    }    cost = c;    costuptodate = true;  }  return cost;}bool RectifiedGaussianV::GetRealV(DVH &val, DFlags req){  val.vec = &expectations;  if (req.ex) {    return false;  } else {    return true;  }}bool RectifiedGaussianV::GetMyvalV(DVH &val){  val.vec = &myval;  return true;}void RectifiedGaussianV::GradRealV(DVSet &val, const Node *ptr){  if (!clamped && children.empty())    return;  if (ParIdentity(ptr) == 1) { // Variance parent    DVH p0;    ParRealV(0, p0, DFlags(true, true));    val.mean.resize(net->Time());    val.ex.resize(net->Time());    for (size_t i = 0; i < net->Time(); i++) {      val.mean[i] -= 0.5;      val.ex[i] += (Sqr(expectations.mean[i] - p0.Mean(i)) + p0.Var(i) +		    expectations.var[i]) / 2;    }  } else if (ParIdentity(ptr) == 0) { // Mean parent    DVH p0, p1;    ParRealV(0, p0, DFlags(true));    ParRealV(1, p1, DFlags(false, false, true));    val.mean.resize(net->Time());    val.var.resize(net->Time());    for (size_t i = 0; i < net->Time(); i++) {      val.mean[i] += (p0.Mean(i) - expectations.mean[i]) * p1.Exp(i);      val.var[i] += p1.Exp(i) / 2;#if RECTIFIED_BETTER_APPROX      val.mean[i] += sqrt(2 / PI)	* exp(-0.5 * exp(p1.Mean(i) + p1.Var(i) / 4) * Sqr(p0.Mean(i)) 	      + p1.Mean(i) / 2 + p1.Var(i) / 8)	/ Erfc(-1/sqrt(2.0) * exp(p1.Mean(i)/2 + p1.Var(i)/8) * p0.Mean(i));#endif    }  } else {    BBASSERT2(0);  }}void RectifiedGaussianV::GradReal(DSSet &val, const Node *ptr){  if (!clamped && children.empty())    return;  if (ParIdentity(ptr) == 1) { // Variance parent    DVH p0;    ParRealV(0, p0, DFlags(true, true));    for (size_t i = 0; i < net->Time(); i++) {      val.ex += (Sqr(expectations.mean[i] - p0.Mean(i)) + p0.Var(i) +		 expectations.var[i]) / 2;    }    val.mean -= 0.5 * net->Time();  } else if (ParIdentity(ptr) == 0) {  // Mean parent    DVH p0, p1;    ParRealV(0, p0, DFlags(true));    ParRealV(1, p1, DFlags(false, false, true));    for (size_t i = 0; i < net->Time(); i++) {      val.mean += (p0.Mean(i) - expectations.mean[i]) * p1.Exp(i);      val.var += p1.Exp(i) / 2;#if RECTIFIED_BETTER_APPROX      val.mean += sqrt(2 / PI)	* exp(-0.5 * exp(p1.Mean(i) + p1.Var(i) / 4) * Sqr(p0.Mean(i)) 	      + p1.Mean(i) / 2 + p1.Var(i) / 8)	/ Erfc(-1/sqrt(2.0) * exp(p1.Mean(i)/2 + p1.Var(i)/8) * p0.Mean(i));#endif    }  } else {    BBASSERT2(0);  }}void RectifiedGaussianV::MyUpdate(){  if (NumChildren() < 1) {    return;  }  DVSet grad;  ChildGradRealV(grad);  DVH mpar, vpar;  ParRealV(0, mpar, DFlags(true));  ParRealV(1, vpar, DFlags(false, false, true));  double gm = 0;  double gv = 0;  bool hasgradient = grad.mean.size() > 0;  for (size_t i = 0; i < net->Time(); i++) {    if (hasgradient) {      gm = grad.mean[i];      gv = grad.var[i];    }    myval.var[i] = 1 / (2*gv + vpar.Exp(i));    myval.mean[i] = myval.var[i] * (2*gv * expectations.mean[i] - gm				    + vpar.Exp(i) * mpar.Mean(i));  }  UpdateExpectations();  costuptodate = false;}void RectifiedGaussianV::UpdateExpectations(){  double s;  for (size_t i = 0; i < net->Time(); i++) {    if ((myval.mean[i] / sqrt(myval.var[i])) > RECTLIMIT) {      s = sqrt(2 * myval.var[i] / PI) 	/ Erfcx(-myval.mean[i] / sqrt(2 * myval.var[i]));      expectations.mean[i] = myval.mean[i] + s;      expectations.var[i] = Sqr(myval.mean[i]) + myval.var[i] 	+ s * myval.mean[i] - Sqr(expectations.mean[i]);    } else {      expectations.mean[i] = - myval.var[i] / myval.mean[i];      expectations.var[i] = Sqr(expectations.mean[i]);    }  }}/* Class GaussRectV */GaussRectV::GaussRectV(Net *net, Label label, Node *m, Node *v) :  Variable(net, label, m, v), BiParNode(m, v){  cost = 0.0;  posval.mean.resize(net->Time());   posval.var.resize(net->Time());  negval.mean.resize(net->Time());   negval.var.resize(net->Time());  posweights.resize(net->Time());  negweights.resize(net->Time());  for (size_t i = 0; i < net->Time(); i++) {    posval.mean[i] = 0.1;    posval.var[i] = 1.0;    negval.mean[i] = 0.1;    negval.var[i] = 1.0;    posweights[i] = 1.0;    negweights[i] = 1.0;  }  posmoments.resize(3);  negmoments.resize(3);  for (int i = 0; i < 3; i++) {    posmoments[i].resize(net->Time());    negmoments[i].resize(net->Time());  }  expts.mean.resize(net->Time());  expts.var.resize(net->Time());  rectexpts.mean.resize(net->Time());  rectexpts.var.resize(net->Time());  CheckParent(0, REALV_MV);  CheckParent(1, REALV_ME);//  MyUpdate();  UpdateMoments();  UpdateExpectations();}void GaussRectV::GetState(DV *state, size_t t){  BBASSERT2(t < net->Time());    state->resize(6);  (*state)[0] = posval.mean[t];  (*state)[1] = posval.var[t];  (*state)[2] = negval.mean[t];  (*state)[3] = negval.var[t];  (*state)[4] = posweights[t];  (*state)[5] = negweights[t];}void GaussRectV::SetState(DV *state, size_t t){  BBASSERT2(t < net->Time());  BBASSERT2(state->size() == 6);  posval.mean[t] = (*state)[0];  posval.var[t] = (*state)[1];  negval.mean[t] = (*state)[2];  negval.var[t] = (*state)[3];  posweights[t] = (*state)[4];  negweights[t] = (*state)[5];  UpdateMoments(); // change one, update all =)  UpdateExpectations();  costuptodate = false;  OutdateChild();}void GaussRectV::Save(NetSaver *saver){  saver->SetNamedDVSet("posval", posval);  saver->SetNamedDVSet("negval", negval);  saver->SetNamedDV("posweights", posweights);  saver->SetNamedDV("negweights", negweights);  saver->SetNamedDouble("cost", cost);  Variable::Save(saver);}string GaussRectV::GetType() const {   return "GaussRectV";}void GaussRectV::UpdateMoments(){  for (size_t t = 0; t < net->Time(); t++) {    double sp = Erfc(-posval.mean[t] / sqrt(2*posval.var[t]));    double a = 0.5 * posweights[t];    double b = sqrt(2*posval.var[t] / PI)       / exp(Sqr(posval.mean[t]) / (2 * posval.var[t]));    BBASSERT2(finite(sp));    BBASSERT2(finite(a));    BBASSERT2(finite(b));    posmoments[0][t] = a * sp;    posmoments[1][t] = a * (sp * posval.mean[t] + b);    posmoments[2][t] = a * (sp * (Sqr(posval.mean[t]) + posval.var[t])			    + b * posval.mean[t]);    BBASSERT2(finite(posmoments[0][t]));    BBASSERT2(finite(posmoments[1][t]));    BBASSERT2(finite(posmoments[2][t]));    double sn = Erfc(negval.mean[t] / sqrt(2 * negval.var[t]));    a = 0.5 * negweights[t];    b = sqrt(2*negval.var[t] / PI)       / exp(Sqr(negval.mean[t]) / (2 * negval.var[t]));    negmoments[0][t] = a * sn;    negmoments[1][t] = a * (sn * negval.mean[t] - b);    negmoments[2][t] = a * (sn * (Sqr(negval.mean[t]) + negval.var[t])			    - b * negval.mean[t]);    BBASSERT2(finite(negmoments[0][t]));    BBASSERT2(finite(negmoments[1][t]));    BBASSERT2(finite(negmoments[2][t]));  }}void GaussRectV::UpdateExpectations(){  for (size_t t = 0; t < net->Time(); t++) {    expts.mean[t] = posmoments[1][t] + negmoments[1][t];    expts.var[t] = posmoments[2][t] + negmoments[2][t] - Sqr(expts.mean[t]);    rectexpts.mean[t] = posmoments[1][t];    rectexpts.var[t] = posmoments[2][t] - Sqr(rectexpts.mean[t]);  }}double GaussRectV::Cost(){  if (children.empty()) {    return 0;  }  if (!costuptodate) {    DVH mpar, vpar;    ParRealV(0, mpar, DFlags(true, true));    ParRealV(1, vpar, DFlags(true, false, true));    double c = 0;    double S;    for (size_t i = 0; i < net->Time(); i++) {      /* C_p */      c += 0.5 * (vpar.Exp(i) * (Sqr(expts.mean[i] - mpar.Mean(i)) 				 + expts.var[i] + mpar.Var(i))		  - vpar.Mean(i) + log(2*PI));//       if (!finite(c)) {// 	cout << "expts: mean = " << expts.mean[i] << ", "// 	     << "var = " << expts.var[i] << endl;// 	cout << "mpar: mean = " << mpar.Mean(i) << ", "// 	     << "var = " << mpar.Var(i) << endl;// 	cout << "vpar: mean = " << vpar.Mean(i) << ", "// 	     << "exp = " << vpar.Exp(i) << endl;//       }      BBASSERT2(finite(c));      /* C_q^+ */      S = Erfc(-posval.mean[i] / sqrt(2 * posval.var[i]));      if (posweights[i] > EPSILON) {	c += S/2 * (posweights[i] * log(posweights[i])		    - posweights[i]/2 * log(2*PI*posval.var[i]));	BBASSERT2(finite(c));      } else {	c -= S/4 * posweights[i] * log(2*PI*posval.var[i]);	BBASSERT2(finite(c));      }      c += - Sqr(posval.mean[i]) / (2 * posval.var[i]) * posmoments[0][i]	+ posval.mean[i] / posval.var[i] * posmoments[1][i]	- posmoments[2][i] / (2*posval.var[i]);      BBASSERT2(finite(c));      /* C_q^- */      S = Erfc(negval.mean[i] / sqrt(2 * negval.var[i]));      if (negweights[i] > EPSILON) {	c += S/2 * (negweights[i] * log(negweights[i])		    - negweights[i]/2 * log(2*PI*negval.var[i]));	BBASSERT2(finite(c));      } else {	c -= S/4 * negweights[i] * log(2*PI*negval.var[i]);	BBASSERT2(finite(c));      }      c += - Sqr(negval.mean[i]) / (2 * negval.var[i]) * negmoments[0][i]	+ negval.mean[i] / negval.var[i] * negmoments[1][i]	- negmoments[2][i] / (2*negval.var[i]);      BBASSERT2(finite(c));    }    cost = c;    costuptodate = true;  }  BBASSERT2(finite(cost));  return cost;}bool GaussRectV::GetRealV(DVH &val, DFlags req){  val.vec = &expts;  return !req.ex;}bool GaussRectV::GetRectRealV(DVH &val, DFlags req){  val.vec = &rectexpts;  return !req.ex;}void GaussRectV::GradRealV(DVSet &val, const Node *ptr){  if (!clamped && children.empty()) {    return;  }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -