⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 node.cc

📁 The library is a C++/Python implementation of the variational building block framework introduced in
💻 CC
📖 第 1 页 / 共 5 页
字号:
  // NOTE(review): this chunk begins inside a function whose signature is not
  // visible here. From its shape it looks like the per-time-step (DVSet)
  // counterpart of GaussRectV::GradReal below -- confirm against the full
  // file. For every time step i it adds into `val` this node's gradient with
  // respect to the variance parent (ParIdentity == 1) or the mean parent
  // (ParIdentity == 0); any other parent trips BBASSERT2.
  if (ParIdentity(ptr) == 1) { // Variance parent
    DVH p0;
    ParRealV(0, p0, DFlags(true, true));
    // NOTE(review): assuming DV behaves like std::vector, resize() keeps any
    // existing entries, so the loop accumulates on top of what `val` holds.
    val.mean.resize(net->Time());
    val.ex.resize(net->Time());
    for (size_t i = 0; i < net->Time(); i++) {
      val.mean[i] -= 0.5;
      // Half of the expected squared error:
      // (E[x] - E[m])^2 + Var[m] + Var[x].
      val.ex[i] += (Sqr(expts.mean[i] - p0.Mean(i)) + p0.Var(i)
                    + expts.var[i]) / 2;
    }
  } else if (ParIdentity(ptr) == 0) { // Mean parent
    DVH p0, p1;
    ParRealV(0, p0, DFlags(true));
    ParRealV(1, p1, DFlags(false, false, true));
    val.mean.resize(net->Time());
    val.var.resize(net->Time());
    for (size_t i = 0; i < net->Time(); i++) {
      val.mean[i] += (p0.Mean(i) - expts.mean[i]) * p1.Exp(i);
      val.var[i] += p1.Exp(i) / 2;
    }
  } else {
    BBASSERT2(false);
  }
}

// GaussRectV::GradReal
// Adds this node's cost gradient with respect to parent `ptr` into the
// scalar set `val`, summing the per-time-step terms over all net->Time()
// steps. Returns without touching `val` when the node is not clamped and
// has no children.
//  - Variance parent (ParIdentity == 1): val.ex gains half of
//    (E[x] - E[m])^2 + Var[m] + Var[x] per step; val.mean loses 0.5 per step.
//  - Mean parent (ParIdentity == 0): val.mean gains (E[m] - E[x]) * p1.Exp,
//    val.var gains p1.Exp / 2, where p1.Exp is the variance parent's `ex`
//    statistic (it multiplies the squared error in the C_p term of
//    GaussRect::Cost, i.e. it plays the role of an expected precision).
// Any other parent trips BBASSERT2.
void GaussRectV::GradReal(DSSet &val, const Node *ptr)
{
  if (!clamped && children.empty())
    return;
  if (ParIdentity(ptr) == 1) { // Variance parent
    DVH p0;
    ParRealV(0, p0, DFlags(true, true));
    for (size_t i = 0; i < net->Time(); i++) {
      val.ex += (Sqr(expts.mean[i] - p0.Mean(i)) + p0.Var(i)
                 + expts.var[i]) / 2;
    }
    // -0.5 per time step, applied once for all steps.
    val.mean -= 0.5 * net->Time();
  } else if (ParIdentity(ptr) == 0) {  // Mean parent
    DVH p0, p1;
    ParRealV(0, p0, DFlags(true));
    ParRealV(1, p1, DFlags(false, false, true));
    for (size_t i = 0; i < net->Time(); i++) {
      val.mean += (p0.Mean(i) - expts.mean[i]) * p1.Exp(i);
      val.var += p1.Exp(i) / 2;
    }
  } else {
    BBASSERT2(false);
  }
}

// GaussRectV::MyPartialUpdate
// Re-estimates the posterior only at the time indices listed in *indices.
// Returns early, leaving all state unchanged, if the node has no children
// or if no gradient arrives from a rectification child (rg is empty).
//
// For each listed index i the posterior is a two-branch mixture:
//  - negval: Gaussian obtained from the prior (mean mpar.Mean(i), precision
//    vpar.Exp(i)), combined with the gradient ng from ordinary children
//    when there is one;
//  - posval: negval further combined with a Gaussian pseudo-observation x
//    of variance vx, both derived from the rectification gradient rg.
// posweights[i] / negweights[i] are densities of x under the two branches
// (NOTE(review): assumes NormPdf(x, mean, var) is the Gaussian density).
// z is the total mass of the two truncated branches (same expression as
// posmoments[0] + negmoments[0] in GaussRect::UpdateMoments); the weights
// are multiplied by 1/z to normalise the mixture. When z does not exceed
// GAUSSRECTLIMIT the factor is clamped to 1/GAUSSRECTLIMIT instead (so the
// mixture is then not exactly normalised) and a single warning is printed
// after the loop. Finally moments and expectations are refreshed and the
// cached cost is invalidated.
//
// NOTE(review): only an empty rg is guarded; an entry with rg.var[i] == 0
// gives ivx == 0 and an infinite vx, which would trip the BBASSERT2 checks.
// NOTE(review): finite() is a non-standard BSD/glibc function; std::isfinite
// from <cmath> is the portable equivalent.
void GaussRectV::MyPartialUpdate(IntV *indices)
{
  if (NumChildren() < 1) {
    return;
  }
  DVSet ng; // gradient from direct children
  DVSet rg; // gradient from children below the rectification
  ChildGradients(ng, rg);
  if (rg.mean.size() == 0) {
    return;
  }
  DVH mpar, vpar;
  ParRealV(0, mpar, DFlags(true));
  ParRealV(1, vpar, DFlags(false, false, true));
  double x, vx, ivx, z;
  bool limitexceded = false;
  for (size_t j = 0; j < indices->size(); j++) {
    int i = (*indices)[j];
    if (ng.mean.size() == 0) {
      // No ordinary children: the negative branch is just the prior.
      negval.var[i] = 1 / vpar.Exp(i);
      negval.mean[i] = mpar.Mean(i);
    } else {
      negval.var[i] = 1 / (vpar.Exp(i) + 2 * ng.var[i]);
      negval.mean[i] = negval.var[i] *
        (vpar.Exp(i) * mpar.Mean(i)
         + 2 * ng.var[i] * expts.mean[i] - ng.mean[i]);
    }
    // Pseudo-observation from the rectification gradient: precision ivx,
    // variance vx, location x.
    ivx = 2 * rg.var[i];
    vx = 1 / ivx;
    x = rectexpts.mean[i] - vx * rg.mean[i];
    // Positive branch = negative-branch Gaussian combined with x.
    posval.var[i] = 1 / (ivx + 1 / negval.var[i]);
    posval.mean[i] = posval.var[i] *
      (ivx * x + negval.mean[i] / negval.var[i]);
    posweights[i] = NormPdf(x, negval.mean[i], vx + negval.var[i]);
    negweights[i] = NormPdf(x, 0, vx);
    // Total mass of the x > 0 part of posval plus the x < 0 part of negval.
    z = 0.5 * posweights[i] * Erfc(-posval.mean[i] / sqrt(2 * posval.var[i]))
      + 0.5 * negweights[i] * Erfc(negval.mean[i] / sqrt(2 * negval.var[i]));
    if (z > GAUSSRECTLIMIT) {
      z = 1/z;
    } else {
      // Mass too small to normalise safely: clamp the scale factor.
      limitexceded = true;
      z = 1 / GAUSSRECTLIMIT;
    }
    //     if (!finite(z)) {
    //       cout << "posval: mean = " << posval.mean[i] << ", "
    //         << "var = " << posval.var[i] << endl;
    //       cout << "negval: mean = " << negval.mean[i] << ", "
    //         << "var = " << negval.var[i] << endl;
    //     }
    BBASSERT2(finite(z));
    posweights[i] *= z;
    negweights[i] *= z;
    BBASSERT2(finite(posweights[i]));
    BBASSERT2(finite(negweights[i]));
  }
  if (limitexceded) {
    cout << "Warning: Limit exceded in " << GetLabel() << endl;
  }
  UpdateMoments();
  UpdateExpectations();
  costuptodate = false;
}

// GaussRectV::MyUpdate
// Full update: performs the same per-step posterior computation as
// MyPartialUpdate, but for every time step 0 .. net->Time()-1 instead of
// a list of indices. Same early returns (no children, or empty rg), same
// GAUSSRECTLIMIT clamping and warning, and it likewise refreshes moments and
// expectations and invalidates the cached cost.
// NOTE(review): the loop body duplicates MyPartialUpdate's verbatim;
// factoring it into a shared per-index helper would keep the two in sync.
void GaussRectV::MyUpdate()
{
  if (NumChildren() < 1) {
    return;
  }
  DVSet ng; // gradient from direct children
  DVSet rg; // gradient from children below the rectification
  ChildGradients(ng, rg);
  if (rg.mean.size() == 0) {
    return;
  }
  DVH mpar, vpar;
  ParRealV(0, mpar, DFlags(true));
  ParRealV(1, vpar, DFlags(false, false, true));
  double x, vx, ivx, z;
  bool limitexceded = false;
  for (size_t i = 0; i < net->Time(); i++) {
    if (ng.mean.size() == 0) {
      // No ordinary children: the negative branch is just the prior.
      negval.var[i] = 1 / vpar.Exp(i);
      negval.mean[i] = mpar.Mean(i);
    } else {
      negval.var[i] = 1 / (vpar.Exp(i) + 2 * ng.var[i]);
      negval.mean[i] = negval.var[i] *
        (vpar.Exp(i) * mpar.Mean(i)
         + 2 * ng.var[i] * expts.mean[i] - ng.mean[i]);
    }
    // Pseudo-observation from the rectification gradient.
    ivx = 2 * rg.var[i];
    vx = 1 / ivx;
    x = rectexpts.mean[i] - vx * rg.mean[i];
    posval.var[i] = 1 / (ivx + 1 / negval.var[i]);
    posval.mean[i] = posval.var[i] *
      (ivx * x + negval.mean[i] / negval.var[i]);
    posweights[i] = NormPdf(x, negval.mean[i], vx + negval.var[i]);
    negweights[i] = NormPdf(x, 0, vx);
    // Total mass of the two truncated branches.
    z = 0.5 * posweights[i] * Erfc(-posval.mean[i] / sqrt(2 * posval.var[i]))
      + 0.5 * negweights[i] * Erfc(negval.mean[i] / sqrt(2 * negval.var[i]));
    if (z > GAUSSRECTLIMIT) {
      z = 1/z;
    } else {
      limitexceded = true;
      z = 1 / GAUSSRECTLIMIT;
    }
    //     if (!finite(z)) {
    //       cout << "posval: mean = " << posval.mean[i] << ", "
    //         << "var = " << posval.var[i] << endl;
    //       cout << "negval: mean = " << negval.mean[i] << ", "
    //         << "var = " << negval.var[i] << endl;
    //     }
    BBASSERT2(finite(z));
    posweights[i] *= z;
    negweights[i] *= z;
    BBASSERT2(finite(posweights[i]));
    BBASSERT2(finite(negweights[i]));
  }
  if (limitexceded) {
    cout << "Warning: Limit exceded in " << GetLabel() << endl;
  }
  UpdateMoments();
  UpdateExpectations();
  costuptodate = false;
}

// GaussRectV::ChildGradients
// Collects the children's gradients with respect to this node into two sets:
// children that dynamic_cast to RectificationV contribute to `rect`, all
// other children contribute to `norm`. Neither set is cleared first, so the
// caller is expected to pass freshly constructed sets (as MyUpdate and
// MyPartialUpdate do).
void GaussRectV::ChildGradients(DVSet &norm, DVSet &rect)
{
  for (size_t i = 0; i < children.size(); i++) {
    RectificationV *rnode = dynamic_cast<RectificationV *>(children[i]);
    if (rnode == 0) { // not the rectification node
      children[i]->GradRealV(norm, this);
    } else { // yes, this is the rectification node
      rnode->GradRealV(rect, this);
    }
  }
}

/* Class GaussRect */

// GaussRect constructor: scalar variable with mean parent m (parent 0,
// checked as REAL_MV) and variance parent v (parent 1, checked as REAL_ME).
// Starts both branches at mean 0.1 / variance 1.0 with unit weights, sizes
// the moment arrays to 3 (mass, first and second moment; see UpdateMoments)
// and computes the initial moments and expectations.
GaussRect::GaussRect(Net *net, Label label, Node *m, Node *v) :
  Variable(net, label, m, v), BiParNode(m, v)
{
  cost = 0.0;
  posval.mean = 0.1;
  posval.var = 1.0;
  negval.mean = 0.1;
  negval.var = 1.0;
  posweight = 1.0;
  negweight = 1.0;
  posmoments.resize(3);
  negmoments.resize(3);
  CheckParent(0, REAL_MV);
  CheckParent(1, REAL_ME);
  UpdateMoments();
  UpdateExpectations();
}

// Writes the branch parameters, branch weights and cached cost under named
// keys, then delegates to Variable::Save for the base-class state.
void GaussRect::Save(NetSaver *saver)
{
  saver->SetNamedDSSet("posval", posval);
  saver->SetNamedDSSet("negval", negval);
  saver->SetNamedDouble("posweight", posweight);
  saver->SetNamedDouble("negweight", negweight);
  saver->SetNamedDouble("cost", cost);
  Variable::Save(saver);
}

// Type tag of this node class.
string GaussRect::GetType() const
{
  return "GaussRect";
}

// GaussRect::UpdateMoments
// Computes the weighted truncated-Gaussian moments of both branches:
//   posmoments[k] = posweight * integral over x > 0 of x^k N(x; posval)
//   negmoments[k] = negweight * integral over x < 0 of x^k N(x; negval)
// for k = 0, 1, 2. With s = erfc(-/+ m / sqrt(2 v)) and
// b = sqrt(2 v / pi) * exp(-m^2 / (2 v)) these are
//   0.5 w s,  0.5 w (s m +/- b),  0.5 w (s (m^2 + v) +/- b m).
// b is written as a division by exp(...): if the exponent overflows, b
// becomes 0, which is the correct limit.
void GaussRect::UpdateMoments()
{
  double sp = Erfc(-posval.mean / sqrt(2*posval.var));
  double a = 0.5 * posweight;
  double b = sqrt(2*posval.var / PI)
    / exp(Sqr(posval.mean) / (2 * posval.var));
  posmoments[0] = a * sp;
  posmoments[1] = a * (sp * posval.mean + b);
  posmoments[2] = a * (sp * (Sqr(posval.mean) + posval.var)
                       + b * posval.mean);
  double sn = Erfc(negval.mean / sqrt(2 * negval.var));
  a = 0.5 * negweight;
  b = sqrt(2*negval.var / PI)
    / exp(Sqr(negval.mean) / (2 * negval.var));
  negmoments[0] = a * sn;
  negmoments[1] = a * (sn * negval.mean - b);
  negmoments[2] = a * (sn * (Sqr(negval.mean) + negval.var)
                       - b * negval.mean);
}

// GaussRect::UpdateExpectations
// expts: mean and variance of the whole two-branch posterior.
// rectexpts: mean and variance using only the positive branch, consistent
// with the rectified value being zero wherever x < 0.
void GaussRect::UpdateExpectations()
{
  expts.mean = posmoments[1] + negmoments[1];
  expts.var = posmoments[2] + negmoments[2] - Sqr(expts.mean);
  rectexpts.mean = posmoments[1];
  rectexpts.var = posmoments[2] - Sqr(rectexpts.mean);
}

// GaussRect::Cost
// Returns 0 when the node has no children. Otherwise returns the cached
// cost, recomputing it when costuptodate is false:
//  - C_p: 0.5 * (<prec> * E[(x - m)^2] - <log prec> + log(2 pi)), with
//    vpar.ex as <prec> and vpar.mean as <log prec>;
//  - C_q^+ / C_q^-: expectation of log q over the x > 0 part (posval) and
//    the x < 0 part (negval). When a weight is <= EPSILON its w*log(w) term
//    is dropped, which avoids log(0).
// NOTE(review): unlike GradReal and MoGV::Cost this early return does not
// check `clamped`; confirm whether that asymmetry is intended.
double GaussRect::Cost()
{
  if (children.empty()) {
    return 0;
  }
  if (!costuptodate) {
    DSSet mpar, vpar;
    ParReal(0, mpar, DFlags(true, true));
    ParReal(1, vpar, DFlags(true, false, true));
    double c = 0;
    double S;
    /* C_p */
    c += 0.5 * (vpar.ex * (Sqr(expts.mean - mpar.mean)
                           + expts.var + mpar.var)
                - vpar.mean + log(2*PI));
    /* C_q^+ */
    S = Erfc(-posval.mean / sqrt(2 * posval.var));
    if (posweight > EPSILON) {
      c += S/2 * (posweight * log(posweight)
                  - posweight/2 * log(2*PI*posval.var));
    } else {
      c -= S/4 * posweight * log(2*PI*posval.var);
    }
    c += - Sqr(posval.mean) / (2 * posval.var) * posmoments[0]
      + posval.mean / posval.var * posmoments[1]
      - posmoments[2] / (2*posval.var);

    /* C_q^- */
    S = Erfc(negval.mean / sqrt(2 * negval.var));
    if (negweight > EPSILON) {
      c += S/2 * (negweight * log(negweight)
                  - negweight/2 * log(2*PI*negval.var));
    } else {
      c -= S/4 * negweight * log(2*PI*negval.var);
    }
    c += - Sqr(negval.mean) / (2 * negval.var) * negmoments[0]
      + negval.mean / negval.var * negmoments[1]
      - negmoments[2] / (2*negval.var);
    cost = c;
    costuptodate = true;
    BBASSERT2(finite(cost));
  }
  return cost;
}

// Copies the full-posterior expectations into `val`. Returns false when the
// caller requests the `ex` statistic, which this node does not supply.
bool GaussRect::GetReal(DSSet &val, DFlags req)
{
  val = expts;
  return !req.ex;
}

// Like GetReal, but copies the rectified (positive-branch) expectations.
bool GaussRect::GetRectReal(DSSet &val, DFlags req)
{
  val = rectexpts;
  return !req.ex;
}

// GaussRect::GradReal
// Scalar analogue of GaussRectV::GradReal: adds the gradient with respect to
// the variance parent (ParIdentity == 1) or the mean parent
// (ParIdentity == 0) into `val`. Returns without touching `val` when the
// node is not clamped and has no children. Any other parent trips BBASSERT2.
void GaussRect::GradReal(DSSet &val, const Node *ptr)
{
  if (!clamped && children.empty())
    return;
  if (ParIdentity(ptr) == 1) { // Variance parent
    DSSet p0;
    ParReal(0, p0, DFlags(true, true));
    val.ex += (Sqr(expts.mean - p0.mean) + p0.var + expts.var) / 2;
    val.mean -= 0.5;
  } else if (ParIdentity(ptr) == 0) {  // Mean parent
    DSSet p0, p1;
    ParReal(0, p0, DFlags(true));
    ParReal(1, p1, DFlags(false, false, true));
    val.mean += (p0.mean - expts.mean) * p1.ex;
    val.var += p1.ex / 2;
  } else {
    BBASSERT2(false);
  }
}

// GaussRect::MyUpdate
// Scalar analogue of GaussRectV::MyUpdate: builds negval from the prior and
// the ordinary children's gradient ng, builds posval by also combining the
// pseudo-observation derived from the rectification gradient rg, and sets
// and normalises the branch weights (clamping with GAUSSRECTLIMIT and
// printing a warning when the branch mass is too small). Then refreshes
// moments and expectations and invalidates the cached cost.
// NOTE(review): there is no guard equivalent to GaussRectV's
// `rg.mean.size() == 0` check. If no Rectification child contributes,
// rg.var presumably stays 0 (depends on DSSet's default -- confirm), which
// makes ivx 0 and vx infinite and would trip the BBASSERT2 checks below.
void GaussRect::MyUpdate()
{
  if (NumChildren() < 1) {
    return;
  }
  DSSet ng; // gradient from direct children
  DSSet rg; // gradient from children below the rectification
  ChildGradients(ng, rg);
  DSSet mpar, vpar;
  ParReal(0, mpar, DFlags(true));
  ParReal(1, vpar, DFlags(false, false, true));
  double x, vx, ivx, z;
  bool limitexceded = false;
  negval.var = 1 / (vpar.ex + 2 * ng.var);
  negval.mean = negval.var *
    (vpar.ex * mpar.mean + 2 * ng.var * expts.mean - ng.mean);
  // Pseudo-observation from the rectification gradient.
  ivx = 2 * rg.var;
  vx = 1 / ivx;
  x = rectexpts.mean - vx * rg.mean;
  posval.var = 1 / (ivx + 1 / negval.var);
  posval.mean = posval.var * (ivx * x + negval.mean / negval.var);
  posweight = NormPdf(x, negval.mean, vx + negval.var);
  negweight = NormPdf(x, 0, vx);
  // Total mass of the two truncated branches.
  z = 0.5 * posweight * Erfc(-posval.mean / sqrt(2 * posval.var))
    + 0.5 * negweight * Erfc(negval.mean / sqrt(2 * negval.var));
  if (z > GAUSSRECTLIMIT) {
    z = 1/z;
  } else {
    limitexceded = true;
    z = 1 / GAUSSRECTLIMIT;
  }

  BBASSERT2(finite(z));
  posweight *= z;
  negweight *= z;
  BBASSERT2(finite(posweight));
  BBASSERT2(finite(negweight));
  if (limitexceded) {
    cout << "Warning: Limit exceded in " << GetLabel() << endl;
  }
  UpdateMoments();
  UpdateExpectations();
  costuptodate = false;
}

// GaussRect::ChildGradients
// Scalar analogue of GaussRectV::ChildGradients: children that
// dynamic_cast to Rectification add into `rect`, all others into `norm`.
// Neither set is cleared first.
void GaussRect::ChildGradients(DSSet &norm, DSSet &rect)
{
  for (size_t i = 0; i < children.size(); i++) {
    Rectification *rnode = dynamic_cast<Rectification *>(children[i]);
    if (rnode == 0) { // not the rectification node
      children[i]->GradReal(norm, this);
    } else { // yes, this is the rectification node
      rnode->GradReal(rect, this);
    }
  }
}

/* class GaussRectVState */

// Accessor object over a GaussRectV's internal posterior state. It stores
// the raw pointer `n` (no delete is visible here, so presumably
// non-owning) and hands out mutable references into that node.
GaussRectVState::GaussRectVState(GaussRectV *n)
{
  node = n;
}

DVSet &GaussRectVState::GetPosVal()
{
  return node->posval;
}

DVSet &GaussRectVState::GetNegVal()
{
  return node->negval;
}

DV &GaussRectVState::GetPosWeights()
{
  return node->posweights;
}

DV &GaussRectVState::GetNegWeights()
{
  return node->negweights;
}

// NOTE(review): `i` is used as an index without a bounds check; presumably
// 0..2 (mass, first and second moment, as in GaussRect) -- confirm.
DV &GaussRectVState::GetPosMoment(int i)
{
  return node->posmoments[i];
}

DV &GaussRectVState::GetNegMoment(int i)
{
  return node->negmoments[i];
}

// ScaleWeights
// Normalises *w in place so that its entries sum to 1, then zeroes every
// entry whose normalised value is not greater than CATEGLIMIT. After that
// pruning the surviving entries no longer sum exactly to 1.
// BBASSERTs that the original sum is non-zero.
inline void ScaleWeights(DV *w)
{
  //cout << "ScaleWeights begin" << endl;
  double sum = 0.0;
  for (size_t i = 0; i < w->size(); i++) {
    sum += (*w)[i];
  }
  BBASSERT(sum != 0.0);
  for (size_t i = 0; i < w->size(); i++) {
    double val = (*w)[i] / sum;
    (*w)[i] = val > CATEGLIMIT ? val : 0.0;
  }
  //cout << "ScaleWeights end" << endl;
}

/* class MoGV */

// MoGV constructor: single discrete parent d (checked as DISCRETEV).
// Caches the component count and sizes the expectation vectors to
// net->Time() steps.
MoGV::MoGV(Net *net, Label label, Node *d)
  : Variable(net, label, d), NParNode(d)
{
  cost = 0;
  CheckParent(0, DISCRETEV);
  numComponents = NumComponents();
  expts.mean.resize(net->Time());
  expts.var.resize(net->Time());
}

// MoGV::Cost
// NOTE(review): this function continues beyond the end of this chunk; only
// its opening is visible here. It returns 0 when the node is not clamped
// and has no children. Otherwise, when the cost is stale, it loops over the
// mixture components and time steps, evaluating a C_p-style term for each
// component k from that component's mean and variance parents.
double MoGV::Cost()
{
  if (!clamped && children.empty())
    return 0;
  if (!costuptodate) {
    DVH mpar, vpar;
    DFlags mflags(true, true);
    DFlags vflags(true, false, true);
    VDDH dpar;
    ParDiscreteV(0, dpar);
    double c = 0;
    cost = 0;
    double log2pi = log(2*PI);

    for (size_t k = 0; k < numComponents; k++) {
      means[k]->GetRealV(mpar, mflags);
      vars[k]->GetRealV(vpar, vflags);
      for (size_t i = 0; i < net->Time(); i++) {
        /* C_p */
        c = vpar.Exp(i) * (Sqr(myval[k]->mean[i] - mpar.Mean(i))
                           + mpar.Var(i) + myval[k]->var[i])
          - vpar.Mean(i) + log2pi;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -