📄 node.cc
字号:
// NOTE(review): the statements below, up to the first closing brace of the
// function, are the tail of a function whose header lies before this chunk
// (it works per time step on vector-valued sets, like a GradRealV). It
// accumulates into `val` for the variance parent (identity 1) or the mean
// parent (identity 0), resizing the affected fields to net->Time() first.
  if (ParIdentity(ptr) == 1) { // Variance parent
    DVH p0;
    ParRealV(0, p0, DFlags(true, true));
    val.mean.resize(net->Time());
    val.ex.resize(net->Time());
    for (size_t i = 0; i < net->Time(); i++) {
      val.mean[i] -= 0.5;
      val.ex[i] += (Sqr(expts.mean[i] - p0.Mean(i)) + p0.Var(i) + expts.var[i]) / 2;
    }
  } else if (ParIdentity(ptr) == 0) { // Mean parent
    DVH p0, p1;
    ParRealV(0, p0, DFlags(true));
    ParRealV(1, p1, DFlags(false, false, true));
    val.mean.resize(net->Time());
    val.var.resize(net->Time());
    for (size_t i = 0; i < net->Time(); i++) {
      val.mean[i] += (p0.Mean(i) - expts.mean[i]) * p1.Exp(i);
      val.var[i] += p1.Exp(i) / 2;
    }
  } else {
    BBASSERT2(false);
  }
}

// Accumulates into the scalar set `val` this node's gradient contribution
// with respect to parent `ptr`, summed over all net->Time() time steps.
//
// - Does nothing if the node is neither clamped nor has children.
// - Parent identity 1 is the variance parent: adds half the expected squared
//   deviation (plus both variances) to val.ex and subtracts 0.5 per time step
//   from val.mean.
// - Parent identity 0 is the mean parent: adds the precision-weighted mean
//   difference to val.mean and half the precision to val.var, where
//   p1.Exp(i) (from the variance parent) acts as the precision.
// - Any other parent identity triggers BBASSERT2(false).
void GaussRectV::GradReal(DSSet &val, const Node *ptr)
{
  if (!clamped && children.empty())
    return;
  if (ParIdentity(ptr) == 1) { // Variance parent
    DVH p0;
    ParRealV(0, p0, DFlags(true, true));
    for (size_t i = 0; i < net->Time(); i++) {
      val.ex += (Sqr(expts.mean[i] - p0.Mean(i)) + p0.Var(i) + expts.var[i]) / 2;
    }
    // The -0.5 term is the same at every time step, so it is applied once, scaled.
    val.mean -= 0.5 * net->Time();
  } else if (ParIdentity(ptr) == 0) { // Mean parent
    DVH p0, p1;
    ParRealV(0, p0, DFlags(true));
    ParRealV(1, p1, DFlags(false, false, true));
    for (size_t i = 0; i < net->Time(); i++) {
      val.mean += (p0.Mean(i) - expts.mean[i]) * p1.Exp(i);
      val.var += p1.Exp(i) / 2;
    }
  } else {
    BBASSERT2(false);
  }
}

// Recomputes the positive/negative-part parameters (posval, negval,
// posweights, negweights) only at the time indices listed in *indices.
// It then calls UpdateMoments() and UpdateExpectations() and marks the
// cached cost stale.
//
// Returns without changing anything when the node has no children, or when
// no rectification child contributed a gradient (rg.mean is empty).
//
// NOTE(review): the per-index loop body is a copy of the one in
// GaussRectV::MyUpdate; consider factoring it into a shared helper.
// NOTE(review): UpdateMoments()/UpdateExpectations() take no index
// argument here; whether they touch every time step cannot be seen from
// this chunk.
void GaussRectV::MyPartialUpdate(IntV *indices)
{
  if (NumChildren() < 1) {
    return;
  }
  DVSet ng; // gradient from direct children
  DVSet rg; // gradient from children below the rectification
  ChildGradients(ng, rg);
  if (rg.mean.size() == 0) {
    return;
  }
  DVH mpar, vpar;
  ParRealV(0, mpar, DFlags(true));
  ParRealV(1, vpar, DFlags(false, false, true));
  double x, vx, ivx, z;
  bool limitexceded = false;
  for (size_t j = 0; j < indices->size(); j++) {
    int i = (*indices)[j];
    // negval: Gaussian on the unrectified value from the parents plus any
    // direct (non-rectification) children. Precisions are added;
    // vpar.Exp(i) is the parent precision.
    if (ng.mean.size() == 0) {
      // No direct-child gradient: only the parent terms remain.
      negval.var[i] = 1 / vpar.Exp(i);
      negval.mean[i] = mpar.Mean(i);
    } else {
      negval.var[i] = 1 / (vpar.Exp(i) + 2 * ng.var[i]);
      negval.mean[i] = negval.var[i] * (vpar.Exp(i) * mpar.Mean(i) + 2 * ng.var[i] * expts.mean[i] - ng.mean[i]);
    }
    // (x, vx): mean and variance of the Gaussian term implied by the
    // rectification child's gradient (ivx = 1 / vx is its precision).
    ivx = 2 * rg.var[i];
    vx = 1 / ivx;
    x = rectexpts.mean[i] - vx * rg.mean[i];
    // posval: negval combined with the (x, vx) term by adding precisions.
    posval.var[i] = 1 / (ivx + 1 / negval.var[i]);
    posval.mean[i] = posval.var[i] * (ivx * x + negval.mean[i] / negval.var[i]);
    // Unnormalised weights of the two parts. The negative part is evaluated
    // at 0, presumably because the rectified value is 0 there. NOTE(review):
    // assumes NormPdf(x, mean, var) takes a variance as its third argument.
    posweights[i] = NormPdf(x, negval.mean[i], vx + negval.var[i]);
    negweights[i] = NormPdf(x, 0, vx);
    // Normaliser: after scaling by 1/z the weighted Erfc terms sum to one,
    // unless z had to be clamped below.
    z = 0.5 * posweights[i] * Erfc(-posval.mean[i] / sqrt(2 * posval.var[i])) + 0.5 * negweights[i] * Erfc(negval.mean[i] / sqrt(2 * negval.var[i]));
    if (z > GAUSSRECTLIMIT) {
      z = 1/z;
    } else {
      // Normaliser too small: clamp the scale and report once after the loop.
      limitexceded = true;
      z = 1 / GAUSSRECTLIMIT;
    }
    // if (!finite(z)) {
    //   cout << "posval: mean = " << posval.mean[i] << ", "
    //        << "var = " << posval.var[i] << endl;
    //   cout << "negval: mean = " << negval.mean[i] << ", "
    //        << "var = " << negval.var[i] << endl;
    // }
    BBASSERT2(finite(z));
    posweights[i] *= z;
    negweights[i] *= z;
    BBASSERT2(finite(posweights[i]));
    BBASSERT2(finite(negweights[i]));
  }
  // NOTE(review): "exceded" is misspelled in this runtime message; kept
  // as-is because changing it alters program output.
  if (limitexceded) {
    cout << "Warning: Limit exceded in " << GetLabel() << endl;
  }
  UpdateMoments();
  UpdateExpectations();
  costuptodate = false;
}

// Recomputes the positive/negative-part parameters (posval, negval,
// posweights, negweights) for every time step 0 .. net->Time()-1. It then
// calls UpdateMoments() and UpdateExpectations() and marks the cached cost
// stale.
//
// Returns without changing anything when the node has no children, or when
// no rectification child contributed a gradient (rg.mean is empty).
//
// NOTE(review): the loop body is identical to the one in
// GaussRectV::MyPartialUpdate except for the index source; consider
// sharing it.
void GaussRectV::MyUpdate()
{
  if (NumChildren() < 1) {
    return;
  }
  DVSet ng; // gradient from direct children
  DVSet rg; // gradient from children below the rectification
  ChildGradients(ng, rg);
  if (rg.mean.size() == 0) {
    return;
  }
  DVH mpar, vpar;
  ParRealV(0, mpar, DFlags(true));
  ParRealV(1, vpar, DFlags(false, false, true));
  double x, vx, ivx, z;
  bool limitexceded = false;
  for (size_t i = 0; i < net->Time(); i++) {
    // negval: parents plus direct children (precisions added).
    if (ng.mean.size() == 0) {
      // No direct-child gradient: only the parent terms remain.
      negval.var[i] = 1 / vpar.Exp(i);
      negval.mean[i] = mpar.Mean(i);
    } else {
      negval.var[i] = 1 / (vpar.Exp(i) + 2 * ng.var[i]);
      negval.mean[i] = negval.var[i] * (vpar.Exp(i) * mpar.Mean(i) + 2 * ng.var[i] * expts.mean[i] - ng.mean[i]);
    }
    // (x, vx): Gaussian term implied by the rectification child's gradient.
    ivx = 2 * rg.var[i];
    vx = 1 / ivx;
    x = rectexpts.mean[i] - vx * rg.mean[i];
    // posval: negval combined with (x, vx) by adding precisions.
    posval.var[i] = 1 / (ivx + 1 / negval.var[i]);
    posval.mean[i] = posval.var[i] * (ivx * x + negval.mean[i] / negval.var[i]);
    posweights[i] = NormPdf(x, negval.mean[i], vx + negval.var[i]);
    negweights[i] =
      NormPdf(x, 0, vx);
    // Normaliser for the two weights (clamped when too small).
    z = 0.5 * posweights[i] * Erfc(-posval.mean[i] / sqrt(2 * posval.var[i])) + 0.5 * negweights[i] * Erfc(negval.mean[i] / sqrt(2 * negval.var[i]));
    if (z > GAUSSRECTLIMIT) {
      z = 1/z;
    } else {
      limitexceded = true;
      z = 1 / GAUSSRECTLIMIT;
    }
    // if (!finite(z)) {
    //   cout << "posval: mean = " << posval.mean[i] << ", "
    //        << "var = " << posval.var[i] << endl;
    //   cout << "negval: mean = " << negval.mean[i] << ", "
    //        << "var = " << negval.var[i] << endl;
    // }
    BBASSERT2(finite(z));
    posweights[i] *= z;
    negweights[i] *= z;
    BBASSERT2(finite(posweights[i]));
    BBASSERT2(finite(negweights[i]));
  }
  if (limitexceded) {
    cout << "Warning: Limit exceded in " << GetLabel() << endl;
  }
  UpdateMoments();
  UpdateExpectations();
  costuptodate = false;
}

// Collects the children's gradients with respect to this node into two sets:
// - `rect` receives the gradient of any child that is a RectificationV
//   (detected with dynamic_cast);
// - `norm` receives the gradients of all other children.
void GaussRectV::ChildGradients(DVSet &norm, DVSet &rect)
{
  for (size_t i = 0; i < children.size(); i++) {
    RectificationV *rnode = dynamic_cast<RectificationV *>(children[i]);
    if (rnode == 0) { // not the rectification node
      children[i]->GradRealV(norm, this);
    } else { // yes, this is the rectification node
      rnode->GradRealV(rect, this);
    }
  }
}

/* Class GaussRect */

// Scalar counterpart of GaussRectV.
//
// Initialises both parts to mean 0.1 / variance 1.0 with unit weights and
// gives each moment array three entries (orders 0, 1, 2). It checks parent 0
// against REAL_MV and parent 1 against REAL_ME, then derives the initial
// moments and expectations.
GaussRect::GaussRect(Net *net, Label label, Node *m, Node *v) :
  Variable(net, label, m, v), BiParNode(m, v)
{
  cost = 0.0;
  posval.mean = 0.1;
  posval.var = 1.0;
  negval.mean = 0.1;
  negval.var = 1.0;
  posweight = 1.0;
  negweight = 1.0;
  posmoments.resize(3);
  negmoments.resize(3);
  CheckParent(0, REAL_MV);
  CheckParent(1, REAL_ME);
  UpdateMoments();
  UpdateExpectations();
}

// Writes this node's own fields under fixed names, then delegates to
// Variable::Save for the base-class state.
void GaussRect::Save(NetSaver *saver)
{
  saver->SetNamedDSSet("posval", posval);
  saver->SetNamedDSSet("negval", negval);
  saver->SetNamedDouble("posweight", posweight);
  saver->SetNamedDouble("negweight", negweight);
  saver->SetNamedDouble("cost", cost);
  Variable::Save(saver);
}

// Returns the type name string "GaussRect".
string GaussRect::GetType() const
{
  return "GaussRect";
}

// Recomputes the weighted moments of orders 0, 1, 2 of both parts.
// Assuming Erfc is the standard complementary error function, the formulas
// give:
//   posmoments[k] = posweight * integral over (0, +inf)  of x^k N(x; posval.mean, posval.var) dx
//   negmoments[k] = negweight * integral over (-inf, 0)  of x^k N(x; negval.mean, negval.var) dx
// Here `sp`/`sn` give the Gaussian tail mass (times 2), and `b` is
// sqrt(2 var / pi) * exp(-mean^2 / (2 var)).
void GaussRect::UpdateMoments()
{
  double sp = Erfc(-posval.mean / sqrt(2*posval.var));
  double a = 0.5 * posweight;
  double b = sqrt(2*posval.var / PI) / exp(Sqr(posval.mean) / (2 * posval.var));
  posmoments[0] = a * sp;
  posmoments[1] = a * (sp * posval.mean + b);
  posmoments[2] = a * (sp * (Sqr(posval.mean) + posval.var) + b * posval.mean);
  double sn = Erfc(negval.mean / sqrt(2 * negval.var));
  a = 0.5 * negweight;
  b = sqrt(2*negval.var / PI) / exp(Sqr(negval.mean) / (2 * negval.var));
  negmoments[0] = a * sn;
  negmoments[1] = a * (sn * negval.mean - b);
  negmoments[2] = a * (sn * (Sqr(negval.mean) + negval.var) - b * negval.mean);
}

// Derives the reported expectations from the moments:
// - expts: mean/variance using both parts;
// - rectexpts: mean/variance using the positive part only, i.e. the moments
//   of the rectified value max(x, 0).
void GaussRect::UpdateExpectations()
{
  expts.mean = posmoments[1] + negmoments[1];
  expts.var = posmoments[2] + negmoments[2] - Sqr(expts.mean);
  rectexpts.mean = posmoments[1];
  rectexpts.var = posmoments[2] - Sqr(rectexpts.mean);
}

// Returns this node's cost, recomputing it only when costuptodate is false.
//
// The sum has three labelled parts:
// - C_p: the term involving the mean and variance parents;
// - C_q^+: the term from the positive part;
// - C_q^-: the term from the negative part.
// When a weight is at most EPSILON, its w*log(w) term is omitted to avoid
// log(0).
//
// NOTE(review): this returns 0 whenever there are no children, whereas
// GradReal and the other methods use `!clamped && children.empty()`;
// confirm that the asymmetry is intended.
double GaussRect::Cost()
{
  if (children.empty()) {
    return 0;
  }
  if (!costuptodate) {
    DSSet mpar, vpar;
    ParReal(0, mpar, DFlags(true, true));
    ParReal(1, vpar, DFlags(true, false, true));
    double c = 0;
    double S;
    /* C_p */
    c += 0.5 * (vpar.ex * (Sqr(expts.mean - mpar.mean) + expts.var + mpar.var) - vpar.mean + log(2*PI));
    /* C_q^+ */
    S = Erfc(-posval.mean / sqrt(2 * posval.var));
    if (posweight > EPSILON) {
      c += S/2 * (posweight * log(posweight) - posweight/2 * log(2*PI*posval.var));
    } else {
      c -= S/4 * posweight * log(2*PI*posval.var);
    }
    c += - Sqr(posval.mean) / (2 * posval.var) * posmoments[0] + posval.mean / posval.var * posmoments[1] - posmoments[2] / (2*posval.var);
    /* C_q^- */
    S = Erfc(negval.mean / sqrt(2 * negval.var));
    if (negweight > EPSILON) {
      c += S/2 * (negweight * log(negweight) - negweight/2 * log(2*PI*negval.var));
    } else {
      c -= S/4 * negweight * log(2*PI*negval.var);
    }
    c += - Sqr(negval.mean) / (2 * negval.var) * negmoments[0] + negval.mean / negval.var * negmoments[1] - negmoments[2] / (2*negval.var);
    cost = c;
    costuptodate = true;
    BBASSERT2(finite(cost));
  }
  return cost;
}

// Copies the expectations of the unrectified value into `val`. Returns
// false when the caller requested `ex`, which this node does not provide.
bool GaussRect::GetReal(DSSet &val, DFlags req)
{
  val = expts;
  return !req.ex;
}

// Like GetReal, but hands out the expectations of the rectified value.
bool GaussRect::GetRectReal(DSSet &val, DFlags req)
{
  val = rectexpts;
  return !req.ex;
}

// Scalar counterpart of GaussRectV::GradReal: accumulates into `val` this
// node's gradient contribution with respect to parent `ptr`.
// - Identity 1 is the variance parent; identity 0 is the mean parent.
// - Any other identity triggers BBASSERT2(false).
// - Does nothing if the node is neither clamped nor has children.
void GaussRect::GradReal(DSSet &val, const Node *ptr)
{
  if (!clamped && children.empty())
    return;
  if (ParIdentity(ptr) == 1) { // Variance parent
    DSSet p0;
    ParReal(0, p0, DFlags(true, true));
    val.ex += (Sqr(expts.mean - p0.mean) + p0.var + expts.var) / 2;
    val.mean -= 0.5;
  } else if (ParIdentity(ptr) == 0) { // Mean parent
    DSSet p0, p1;
    ParReal(0, p0, DFlags(true));
    ParReal(1, p1, DFlags(false, false, true));
    val.mean += (p0.mean - expts.mean) * p1.ex;
    val.var += p1.ex / 2;
  } else {
    BBASSERT2(false);
  }
}

// Scalar counterpart of GaussRectV::MyUpdate. It recomputes negval, posval,
// posweight and negweight from the parents and child gradients, then
// refreshes the moments and expectations and marks the cost stale.
//
// NOTE(review): unlike the vector version, there is no early return when
// no rectification child contributed (rg left at its default). If DSSet
// defaults to zeros (TODO confirm), then ivx == 0 and vx = 1 / 0 below.
// Likewise, negval is always computed with the ng terms.
void GaussRect::MyUpdate()
{
  if (NumChildren() < 1) {
    return;
  }
  DSSet ng; // gradient from direct children
  DSSet rg; // gradient from children below the rectification
  ChildGradients(ng, rg);
  DSSet mpar, vpar;
  ParReal(0, mpar, DFlags(true));
  ParReal(1, vpar, DFlags(false, false, true));
  double x, vx, ivx, z;
  bool limitexceded = false;
  // negval: parents plus direct children (precisions added).
  negval.var = 1 / (vpar.ex + 2 * ng.var);
  negval.mean = negval.var * (vpar.ex * mpar.mean + 2 * ng.var * expts.mean - ng.mean);
  // (x, vx): Gaussian term implied by the rectification child's gradient.
  ivx = 2 * rg.var;
  vx = 1 / ivx;
  x = rectexpts.mean - vx * rg.mean;
  // posval: negval combined with (x, vx) by adding precisions.
  posval.var = 1 / (ivx + 1 / negval.var);
  posval.mean = posval.var * (ivx * x + negval.mean / negval.var);
  posweight = NormPdf(x, negval.mean, vx + negval.var);
  negweight = NormPdf(x, 0, vx);
  // Normaliser for the two weights (clamped when too small).
  z = 0.5 * posweight * Erfc(-posval.mean / sqrt(2 * posval.var)) + 0.5 * negweight * Erfc(negval.mean / sqrt(2 * negval.var));
  if (z > GAUSSRECTLIMIT) {
    z = 1/z;
  } else {
    limitexceded = true;
    z = 1 / GAUSSRECTLIMIT;
  }
  BBASSERT2(finite(z));
  posweight *= z;
  negweight *= z;
  BBASSERT2(finite(posweight));
  BBASSERT2(finite(negweight));
  if (limitexceded) {
    cout << "Warning: Limit exceded in " << GetLabel() << endl;
  }
  UpdateMoments();
  UpdateExpectations();
  costuptodate = false;
}

// Scalar counterpart of GaussRectV::ChildGradients:
// - children that are a Rectification write into `rect`;
// - all other children write into `norm`.
void GaussRect::ChildGradients(DSSet &norm, DSSet &rect)
{
  for (size_t i = 0; i < children.size(); i++) {
    Rectification *rnode = dynamic_cast<Rectification *>(children[i]);
    if (rnode == 0) { // not the rectification node
      children[i]->GradReal(norm, this);
    } else { // yes, this is the rectification node
      rnode->GradReal(rect, this);
    }
  }
}

/* class GaussRectVState */

// Accessor wrapper around a GaussRectV. It stores the pointer without taking
// ownership, and each getter returns a mutable reference directly into the
// wrapped node's storage.
GaussRectVState::GaussRectVState(GaussRectV *n)
{
  node = n;
}

DVSet &GaussRectVState::GetPosVal()
{
  return node->posval;
}

DVSet &GaussRectVState::GetNegVal()
{
  return node->negval;
}

DV &GaussRectVState::GetPosWeights()
{
  return node->posweights;
}

DV &GaussRectVState::GetNegWeights()
{
  return node->negweights;
}

// `i` is used as a raw index with no bounds check.
DV &GaussRectVState::GetPosMoment(int i)
{
  return node->posmoments[i];
}

DV &GaussRectVState::GetNegMoment(int i)
{
  return node->negmoments[i];
}

// Normalises *w in place so that its entries sum to one. Entries whose
// normalised value is not strictly greater than CATEGLIMIT are then set to
// 0.0. After that pruning the remaining entries are not renormalised, so
// they may sum to slightly less than one. The sum must be non-zero
// (asserted).
inline void ScaleWeights(DV *w)
{
  //cout << "ScaleWeights begin" << endl;
  double sum = 0.0;
  for (size_t i = 0; i < w->size(); i++) {
    sum += (*w)[i];
  }
  BBASSERT(sum != 0.0);
  for (size_t i = 0; i < w->size(); i++) {
    double val = (*w)[i] / sum;
    (*w)[i] = val > CATEGLIMIT ? val : 0.0;
  }
  //cout << "ScaleWeights end" << endl;
}

/* class MoGV */

// Requires parent 0 to be DISCRETEV. Caches the component count from
// NumComponents() and sizes the expectation vectors to net->Time().
MoGV::MoGV(Net *net, Label label, Node *d) :
  Variable(net, label, d), NParNode(d)
{
  cost = 0;
  CheckParent(0, DISCRETEV);
  numComponents = NumComponents();
  expts.mean.resize(net->Time());
  expts.var.resize(net->Time());
}

// NOTE(review): only the beginning of MoGV::Cost is visible in this chunk.
// It returns 0 when the node is neither clamped nor has children. Otherwise,
// when the cached cost is stale, it loops over components k and time steps i
// and computes the per-entry C_p term; the rest of the body continues past
// this chunk.
double MoGV::Cost()
{
  if (!clamped && children.empty())
    return 0;
  if (!costuptodate) {
    DVH mpar, vpar;
    DFlags mflags(true, true);
    DFlags vflags(true, false, true);
    VDDH dpar;
    ParDiscreteV(0, dpar);
    double c = 0;
    cost = 0;
    double log2pi = log(2*PI);
    for (size_t k = 0; k < numComponents; k++) {
      means[k]->GetRealV(mpar, mflags);
      vars[k]->GetRealV(vpar, vflags);
      for (size_t i = 0; i < net->Time(); i++) {
        /* C_p */
        c = vpar.Exp(i) * (Sqr(myval[k]->mean[i] - mpar.Mean(i)) + mpar.Var(i) + myval[k]->var[i]) - vpar.Mean(i) + log2pi;
        // (MoGV::Cost continues beyond the end of this chunk.)
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -