📄 node.cc
字号:
bool RectifiedGaussian::GetMyval(DSSet &val){ val.mean = myval.mean; val.var = myval.var; return true;}bool RectifiedGaussian::GetReal(DSSet &val, DFlags req){ if (req.mean) { val.mean = expectations.mean; } if (req.var) { val.var = expectations.var; } /* Clients hoping for ex are out of luck. */ if (req.ex) { return false; } else { return true; }}void RectifiedGaussian::GradReal(DSSet &val, const Node *ptr){ if (!clamped && children.empty()) return; if (ParIdentity(ptr) == 1) { // Variance parent DSSet p0; ParReal(0, p0, DFlags(true, true)); val.mean -= 0.5; val.ex += 0.5 * (Sqr(expectations.mean - p0.mean) + expectations.var + p0.var); } else if (ParIdentity(ptr) == 0) { // Mean parent DSSet p0, p1; ParReal(0, p0, DFlags(true)); ParReal(1, p1, DFlags(false, false, true)); val.mean += p1.ex * (expectations.mean - p0.mean); val.var += 0.5 * p1.ex;#if RECTIFIED_BETTER_APPROX val.mean += sqrt(2 / PI) * exp(-0.5 * exp(p1.mean + p1.var / 4) * Sqr(p0.mean) + p1.mean / 2 + p1.var / 8) / Erfc(-1 / sqrt(2.0) * exp(p1.mean / 2 + p1.var / 8) * p0.mean);#endif } else { BBASSERT2(false); }}void RectifiedGaussian::MyPartialUpdate(IntV *indices){ MyUpdate();}void RectifiedGaussian::MyUpdate(){ DSSet grad, mpar, vpar;; /* double linvvar, lmean; */ /* From the gradient we can deduce the likelihood and since the likelihood is known to be Gaussian and the prior is Rectified Gaussian the Rectified Gaussian posterior approximation matches exactly to the correct posterior and the parameters can be set according to the correct posterior. */ ChildGradReal(grad); ParReal(0, mpar, DFlags(true)); ParReal(1, vpar, DFlags(false, false, true)); /* This is what happens below. linvvar = 2 * grad.var; lmean = expectations.mean - grad.mean / linvvar; myval.var = 1 / (linvvar + vpar.ex); myval.mean = myval.var * (linvvar * lmean + vpar.ex * mpar.mean); */ myval.var = 1 / (2*grad.var + vpar.ex); myval.mean = myval.var * (2*grad.var * expectations.mean - grad.mean + vpar.ex * mpar.mean); /* myvals have changed, expectations needs to be updated */ UpdateExpectations(); costuptodate = false;}void RectifiedGaussian::UpdateExpectations(){ if ((myval.mean / sqrt(myval.var)) > RECTLIMIT) { double scale; scale = sqrt(2*myval.var/PI) / Erfcx(-myval.mean / sqrt(2*myval.var)); expectations.mean = myval.mean + scale; expectations.var = Sqr(myval.mean) + myval.var + scale * myval.mean - Sqr(expectations.mean); } else { /* use exponential approximation */ expectations.mean = -myval.var / myval.mean; expectations.var = Sqr(expectations.mean); }}/* class RectifiedGaussianV */RectifiedGaussianV::RectifiedGaussianV(Net *net, Label label, Node *m, Node *v) : Variable(net, label, m, v), BiParNode(m, v){ cost = 0; myval.mean.resize(net->Time()); myval.var.resize(net->Time()); expectations.mean.resize(net->Time()); expectations.var.resize(net->Time()); CheckParent(0, REALV_MV); CheckParent(1, REALV_ME); for (size_t i = 0; i < net->Time(); i++) { myval.mean[i] = 0.0; myval.var[i] = 1.0; } UpdateExpectations(); MyUpdate();}void RectifiedGaussianV::Save(NetSaver *saver){ saver->SetNamedDVSet("myval", myval); saver->SetNamedDVSet("expectations", expectations); saver->SetNamedDouble("cost", cost); Variable::Save(saver);}string RectifiedGaussianV::GetType() const { return "RectifiedGaussianV"; }double RectifiedGaussianV::Cost(){ if (children.empty()) { return 0; } if (!costuptodate) { DVH mpar, vpar; ParRealV(0, mpar, DFlags(true, true)); ParRealV(1, vpar, DFlags(true, false, true)); double c = 0; for (size_t i = 0; i < net->Time(); i++) { /* C_p */ c += 0.5 * (vpar.Exp(i) * (Sqr(expectations.mean[i] - mpar.Mean(i)) + expectations.var[i] + mpar.Var(i)) - vpar.Mean(i) + log(PI/2)); // + log(2)#if RECTIFIED_BETTER_APPROX c += log(Erfc(-1/sqrt(2.0) * mpar.Mean(i) * exp(vpar.Mean(i) / 2 + vpar.Var(i) / 8)));#endif /* C_q */ c += -1/(2*myval.var[i]) * (expectations.var[i] + Sqr(expectations.mean[i] - myval.mean[i])) + 0.5 * log(2/(PI*myval.var[i])) - log(Erfc(-myval.mean[i] / sqrt(2*myval.var[i]))); } cost = c; costuptodate = true; } return cost;}bool RectifiedGaussianV::GetRealV(DVH &val, DFlags req){ val.vec = &expectations; if (req.ex) { return false; } else { return true; }}bool RectifiedGaussianV::GetMyvalV(DVH &val){ val.vec = &myval; return true;}void RectifiedGaussianV::GradRealV(DVSet &val, const Node *ptr){ if (!clamped && children.empty()) return; if (ParIdentity(ptr) == 1) { // Variance parent DVH p0; ParRealV(0, p0, DFlags(true, true)); val.mean.resize(net->Time()); val.ex.resize(net->Time()); for (size_t i = 0; i < net->Time(); i++) { val.mean[i] -= 0.5; val.ex[i] += (Sqr(expectations.mean[i] - p0.Mean(i)) + p0.Var(i) + expectations.var[i]) / 2; } } else if (ParIdentity(ptr) == 0) { // Mean parent DVH p0, p1; ParRealV(0, p0, DFlags(true)); ParRealV(1, p1, DFlags(false, false, true)); val.mean.resize(net->Time()); val.var.resize(net->Time()); for (size_t i = 0; i < net->Time(); i++) { val.mean[i] += (p0.Mean(i) - expectations.mean[i]) * p1.Exp(i); val.var[i] += p1.Exp(i) / 2;#if RECTIFIED_BETTER_APPROX val.mean[i] += sqrt(2 / PI) * exp(-0.5 * exp(p1.Mean(i) + p1.Var(i) / 4) * Sqr(p0.Mean(i)) + p1.Mean(i) / 2 + p1.Var(i) / 8) / Erfc(-1/sqrt(2.0) * exp(p1.Mean(i)/2 + p1.Var(i)/8) * p0.Mean(i));#endif } } else { BBASSERT2(0); }}void RectifiedGaussianV::GradReal(DSSet &val, const Node *ptr){ if (!clamped && children.empty()) return; if (ParIdentity(ptr) == 1) { // Variance parent DVH p0; ParRealV(0, p0, DFlags(true, true)); for (size_t i = 0; i < net->Time(); i++) { val.ex += (Sqr(expectations.mean[i] - p0.Mean(i)) + p0.Var(i) + expectations.var[i]) / 2; } val.mean -= 0.5 * net->Time(); } else if (ParIdentity(ptr) == 0) { // Mean parent DVH p0, p1; ParRealV(0, p0, DFlags(true)); ParRealV(1, p1, DFlags(false, false, true)); for (size_t i = 0; i < net->Time(); i++) { val.mean += (p0.Mean(i) - expectations.mean[i]) * p1.Exp(i); val.var += p1.Exp(i) / 2;#if RECTIFIED_BETTER_APPROX val.mean += sqrt(2 / PI) * exp(-0.5 * exp(p1.Mean(i) + p1.Var(i) / 4) * Sqr(p0.Mean(i)) + p1.Mean(i) / 2 + p1.Var(i) / 8) / Erfc(-1/sqrt(2.0) * exp(p1.Mean(i)/2 + p1.Var(i)/8) * p0.Mean(i));#endif } } else { BBASSERT2(0); }}void RectifiedGaussianV::MyUpdate(){ if (NumChildren() < 1) { return; } DVSet grad; ChildGradRealV(grad); DVH mpar, vpar; ParRealV(0, mpar, DFlags(true)); ParRealV(1, vpar, DFlags(false, false, true)); double gm = 0; double gv = 0; bool hasgradient = grad.mean.size() > 0; for (size_t i = 0; i < net->Time(); i++) { if (hasgradient) { gm = grad.mean[i]; gv = grad.var[i]; } myval.var[i] = 1 / (2*gv + vpar.Exp(i)); myval.mean[i] = myval.var[i] * (2*gv * expectations.mean[i] - gm + vpar.Exp(i) * mpar.Mean(i)); } UpdateExpectations(); costuptodate = false;}void RectifiedGaussianV::UpdateExpectations(){ double s; for (size_t i = 0; i < net->Time(); i++) { if ((myval.mean[i] / sqrt(myval.var[i])) > RECTLIMIT) { s = sqrt(2 * myval.var[i] / PI) / Erfcx(-myval.mean[i] / sqrt(2 * myval.var[i])); expectations.mean[i] = myval.mean[i] + s; expectations.var[i] = Sqr(myval.mean[i]) + myval.var[i] + s * myval.mean[i] - Sqr(expectations.mean[i]); } else { expectations.mean[i] = - myval.var[i] / myval.mean[i]; expectations.var[i] = Sqr(expectations.mean[i]); } }}/* Class GaussRectV */GaussRectV::GaussRectV(Net *net, Label label, Node *m, Node *v) : Variable(net, label, m, v), BiParNode(m, v){ cost = 0.0; posval.mean.resize(net->Time()); posval.var.resize(net->Time()); negval.mean.resize(net->Time()); negval.var.resize(net->Time()); posweights.resize(net->Time()); negweights.resize(net->Time()); for (size_t i = 0; i < net->Time(); i++) { posval.mean[i] = 0.1; posval.var[i] = 1.0; negval.mean[i] = 0.1; negval.var[i] = 1.0; posweights[i] = 1.0; negweights[i] = 1.0; } posmoments.resize(3); negmoments.resize(3); for (int i = 0; i < 3; i++) { posmoments[i].resize(net->Time()); negmoments[i].resize(net->Time()); } expts.mean.resize(net->Time()); expts.var.resize(net->Time()); rectexpts.mean.resize(net->Time()); rectexpts.var.resize(net->Time()); CheckParent(0, REALV_MV); CheckParent(1, REALV_ME);// MyUpdate(); UpdateMoments(); UpdateExpectations();}void GaussRectV::GetState(DV *state, size_t t){ BBASSERT2(t < net->Time()); state->resize(6); (*state)[0] = posval.mean[t]; (*state)[1] = posval.var[t]; (*state)[2] = negval.mean[t]; (*state)[3] = negval.var[t]; (*state)[4] = posweights[t]; (*state)[5] = negweights[t];}void GaussRectV::SetState(DV *state, size_t t){ BBASSERT2(t < net->Time()); BBASSERT2(state->size() == 6); posval.mean[t] = (*state)[0]; posval.var[t] = (*state)[1]; negval.mean[t] = (*state)[2]; negval.var[t] = (*state)[3]; posweights[t] = (*state)[4]; negweights[t] = (*state)[5]; UpdateMoments(); // change one, update all =) UpdateExpectations(); costuptodate = false; OutdateChild();}void GaussRectV::Save(NetSaver *saver){ saver->SetNamedDVSet("posval", posval); saver->SetNamedDVSet("negval", negval); saver->SetNamedDV("posweights", posweights); saver->SetNamedDV("negweights", negweights); saver->SetNamedDouble("cost", cost); Variable::Save(saver);}string GaussRectV::GetType() const { return "GaussRectV";}void GaussRectV::UpdateMoments(){ for (size_t t = 0; t < net->Time(); t++) { double sp = Erfc(-posval.mean[t] / sqrt(2*posval.var[t])); double a = 0.5 * posweights[t]; double b = sqrt(2*posval.var[t] / PI) / exp(Sqr(posval.mean[t]) / (2 * posval.var[t])); BBASSERT2(finite(sp)); BBASSERT2(finite(a)); BBASSERT2(finite(b)); posmoments[0][t] = a * sp; posmoments[1][t] = a * (sp * posval.mean[t] + b); posmoments[2][t] = a * (sp * (Sqr(posval.mean[t]) + posval.var[t]) + b * posval.mean[t]); BBASSERT2(finite(posmoments[0][t])); BBASSERT2(finite(posmoments[1][t])); BBASSERT2(finite(posmoments[2][t])); double sn = Erfc(negval.mean[t] / sqrt(2 * negval.var[t])); a = 0.5 * negweights[t]; b = sqrt(2*negval.var[t] / PI) / exp(Sqr(negval.mean[t]) / (2 * negval.var[t])); negmoments[0][t] = a * sn; negmoments[1][t] = a * (sn * negval.mean[t] - b); negmoments[2][t] = a * (sn * (Sqr(negval.mean[t]) + negval.var[t]) - b * negval.mean[t]); BBASSERT2(finite(negmoments[0][t])); BBASSERT2(finite(negmoments[1][t])); BBASSERT2(finite(negmoments[2][t])); }}void GaussRectV::UpdateExpectations(){ for (size_t t = 0; t < net->Time(); t++) { expts.mean[t] = posmoments[1][t] + negmoments[1][t]; expts.var[t] = posmoments[2][t] + negmoments[2][t] - Sqr(expts.mean[t]); rectexpts.mean[t] = posmoments[1][t]; rectexpts.var[t] = posmoments[2][t] - Sqr(rectexpts.mean[t]); }}double GaussRectV::Cost(){ if (children.empty()) { return 0; } if (!costuptodate) { DVH mpar, vpar; ParRealV(0, mpar, DFlags(true, true)); ParRealV(1, vpar, DFlags(true, false, true)); double c = 0; double S; for (size_t i = 0; i < net->Time(); i++) { /* C_p */ c += 0.5 * (vpar.Exp(i) * (Sqr(expts.mean[i] - mpar.Mean(i)) + expts.var[i] + mpar.Var(i)) - vpar.Mean(i) + log(2*PI));// if (!finite(c)) {// cout << "expts: mean = " << expts.mean[i] << ", "// << "var = " << expts.var[i] << endl;// cout << "mpar: mean = " << mpar.Mean(i) << ", "// << "var = " << mpar.Var(i) << endl;// cout << "vpar: mean = " << vpar.Mean(i) << ", "// << "exp = " << vpar.Exp(i) << endl;// } BBASSERT2(finite(c)); /* C_q^+ */ S = Erfc(-posval.mean[i] / sqrt(2 * posval.var[i])); if (posweights[i] > EPSILON) { c += S/2 * (posweights[i] * log(posweights[i]) - posweights[i]/2 * log(2*PI*posval.var[i])); BBASSERT2(finite(c)); } else { c -= S/4 * posweights[i] * log(2*PI*posval.var[i]); BBASSERT2(finite(c)); } c += - Sqr(posval.mean[i]) / (2 * posval.var[i]) * posmoments[0][i] + posval.mean[i] / posval.var[i] * posmoments[1][i] - posmoments[2][i] / (2*posval.var[i]); BBASSERT2(finite(c)); /* C_q^- */ S = Erfc(negval.mean[i] / sqrt(2 * negval.var[i])); if (negweights[i] > EPSILON) { c += S/2 * (negweights[i] * log(negweights[i]) - negweights[i]/2 * log(2*PI*negval.var[i])); BBASSERT2(finite(c)); } else { c -= S/4 * negweights[i] * log(2*PI*negval.var[i]); BBASSERT2(finite(c)); } c += - Sqr(negval.mean[i]) / (2 * negval.var[i]) * negmoments[0][i] + negval.mean[i] / negval.var[i] * negmoments[1][i] - negmoments[2][i] / (2*negval.var[i]); BBASSERT2(finite(c)); } cost = c; costuptodate = true; } BBASSERT2(finite(cost)); return cost;}bool GaussRectV::GetRealV(DVH &val, DFlags req){ val.vec = &expts; return !req.ex;}bool GaussRectV::GetRectRealV(DVH &val, DFlags req){ val.vec = &rectexpts; return !req.ex;}void GaussRectV::GradRealV(DVSet &val, const Node *ptr){ if (!clamped && children.empty()) { return; }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -