📄 node.cc
字号:
saver->SetNamedDSSet("myval", myval); } Function::Save(saver);}// class SumN : public Function : public Nodebool SumN::GetReal(DSSet &val, DFlags req){ bool needm = req.mean && !uptodate.mean; bool needv = req.var && !uptodate.var; bool neede = false; //req.ex && !uptodate.ex; if (needm || needv || neede) { if (needm) myval.mean = 0; if (needv) myval.var = 0; // if (neede) myval.ex = 1; for (size_t i = 0; i<NumParents(); i++) { DSSet p; if (!ParReal(i, p, DFlags(needm, needv, neede))) return false; if (needm) { myval.mean += p.mean; } if (needv) { myval.var += p.var; } // if (neede) { // myval.ex *= p.ex; // } if (keepupdated) { parentval[i] = p; } } if (needm) uptodate.mean = true; if (needv) uptodate.var = true; // if (neede) uptodate.ex = true; } if (req.mean) {val.mean = myval.mean; req.mean = false;} if (req.var) {val.var = myval.var; req.var = false;} // if (req.ex) {val.ex = myval.ex; req.ex = false;} return req.AllFalse();}void SumN::GradReal(DSSet &val, const Node *ptr){ //int ide = ParIdentity(ptr); DSSet grad; ChildGradReal(grad); val.mean += grad.mean; val.var += grad.var; // if (grad.ex) { // DSSet p, tempself; // if (!uptodate.ex) { // GetReal(tempself, DFlags(false,false,true)); // } // ParReal(ide, p, DFlags(false, false, true)); // val.ex += grad.ex * myval.ex / p.ex; // }}bool SumN::AddParent(Node *n){ // The parents are checked in GetReal Node::AddParent(n, true); int ide = ParIdentity(n); if (keepupdated) { parentval.resize(ide+1,DSSet(0.0,0.0,1.0)); parentval[ide] = DSSet(0.0,0.0,1.0); // is this done already above? } uptodate = DFlags(false,false,false); OutdateChild(); return true;}void SumN::SetKeepUpdated(const bool _keepupdated){ keepupdated = _keepupdated; if (keepupdated) { // update now myval.mean = 0.0; myval.var = 0.0; // myval.ex = 1.0; for (size_t i = 0; i<NumParents(); i++) { DSSet p; ParReal(i, p, DFlags(true,true,false)); //.ex parentval[i] = p; myval.mean += p.mean; myval.var += p.var; // myval.ex *= p.ex; } uptodate.mean = true; uptodate.var = true; // uptodate.ex = true; }}void SumN::Outdate(const Node *ptr) { if (uptodate.mean || uptodate.var) { //.ex if (keepupdated) { int ide = ParIdentity(ptr); DSSet p; ParReal(ide, p, DFlags(true, true, false)); //.ex myval.mean += p.mean - parentval[ide].mean; myval.var += p.var - parentval[ide].var; if (myval.var<0) { uptodate.var = false; } // myval.ex *= p.ex / parentval[ide].ex; parentval[ide] = p; } else { uptodate = DFlags(false,false,false); } } OutdateChild();}void SumN::Save(NetSaver *saver){ if (saver->GetSaveFunctionValue()) { saver->SetNamedDSSet("myval", myval); } Function::Save(saver);}// class Relay : public Function : public Nodevoid Relay::Save(NetSaver *saver){ Function::Save(saver);}// abstract class Variable : public NodeVariable::Variable(Net *ptr, Label label, Node *n1, Node *n2) : Node(ptr, label){ persist = 1; // Usually variables need all parents to survive hookeflags = 0; net->AddVariable(this, label); clamped = false; costuptodate = false; if (n1) AddParent(n1, false); if (n2) AddParent(n2, false);}void Variable::Save(NetSaver *saver){ saver->SetNamedBool("clamped", clamped); saver->SetNamedBool("costuptodate", costuptodate); saver->SetNamedInt("hookeflags", hookeflags); Node::Save(saver);}void Variable::SaveState() { if (!clamped && !MySaveState() && (net->GetDebugLevel() > 5)) cerr << "SaveState not supported" << endl;}void Variable::SaveStep() { if (!clamped && !MySaveStep() && (net->GetDebugLevel() > 5)) cerr << "SaveStep not supported" << endl;}void Variable::RepeatStep(double alpha) { if (!clamped) { MyRepeatStep(alpha); OutdateChild(); }}void Variable::SaveRepeatedState(double alpha) { if (!clamped && !MySaveRepeatedState(alpha) && (net->GetDebugLevel() > 5)) cerr << label << ": SaveRepeatedState not supported" << endl;}void Variable::ClearStateAndStep() { if (!clamped && !MyClearStateAndStep() && (net->GetDebugLevel() > 5)) cerr << "ClearStateAndStep not supported" << endl;}// class Gaussian : public Variable : public NodeGaussian::Gaussian(Net *net, Label label, Node *m, Node *v) : Variable(net, label, m, v), BiParNode(m, v){ sstate = 0; sstep = 0; cost = 0; CheckParent(0, REAL_MV); CheckParent(1, REAL_ME); DSSet p0, p1; ParReal(0, p0, DFlags(true)); ParReal(1, p1, DFlags(false, false, true));// myval.mean = p0.mean;// myval.var = 1/p1.ex; myval.mean = 0.0; myval.var = 1.0; exuptodate = false; costuptodate = false;}void Gaussian::GetState(DV *state, size_t t = 0){ BBASSERT2(t == 0); state->resize(2); (*state)[0] = myval.mean; (*state)[1] = myval.var;}void Gaussian::SetState(DV *state, size_t t = 0){ BBASSERT2(t == 0); BBASSERT2(state->size() == 2); myval.mean = (*state)[0]; myval.var = (*state)[1]; costuptodate = false; exuptodate = false; OutdateChild();}double Gaussian::Cost(){ if (!clamped && children.empty()) return 0; if (!costuptodate) { DSSet p0, p1; ParReal(0, p0, DFlags(true, true)); ParReal(1, p1, DFlags(true, false, true)); if (clamped) { //assert(myval.var == 0); cost = ((Sqr(myval.mean - p0.mean) + p0.var + myval.var) * p1.ex - p1.mean) / 2 + _5LOG2PI; } else cost = ((Sqr(myval.mean - p0.mean) + p0.var + myval.var) * p1.ex - p1.mean - log(myval.var) - 1) / 2; costuptodate = true; } return cost;}bool Gaussian::GetReal(DSSet &val, DFlags req){ if (req.ex && !exuptodate) { myval.ex = exp(myval.mean+myval.var/2); exuptodate = true; } if (req.mean) {val.mean = myval.mean; req.mean = false;} if (req.var) {val.var = myval.var; req.var = false;} if (req.ex) {val.ex = myval.ex; req.ex = false;} return req.AllFalse();}void Gaussian::GradReal(DSSet &val, const Node *ptr){ if (!clamped && children.empty()) return; if (ParIdentity(ptr)) { // ParMean(1) DSSet p0; ParReal(0, p0, DFlags(true, true)); val.mean -= 0.5; if (clamped) { //assert(myval.var == 0); val.ex += (Sqr(myval.mean - p0.mean) + p0.var + myval.var) / 2; } else // !clamped val.ex += (Sqr(myval.mean - p0.mean) + p0.var + myval.var) / 2; } else { // ParMean(0) DSSet p0, p1; ParReal(0, p0, DFlags(true)); ParReal(1, p1, DFlags(false, false, true)); val.mean += (p0.mean - myval.mean) * p1.ex; val.var += p1.ex / 2; }}void Gaussian::Save(NetSaver *saver){ saver->SetNamedDSSet("myval", myval); saver->SetNamedDouble("cost", cost); saver->SetNamedBool("exuptodate", exuptodate); if (sstate) saver->SetNamedDSSet("sstate", *sstate); if (sstep) saver->SetNamedDSSet("sstep", *sstep); Variable::Save(saver);}void VarNewton(double &mean, double &var, double gme, double gva, double gex, Label label){ if (!gex) { var = 0.5 / gva; mean -= var * gme; } else { // Solve the minimum of gex * exp(mean + var/2) + gme * mean + // gva * [(mean - current_mean)^2 + var] - 0.5 * log(var) double oldm = mean, oldv = var, mstep, vstep, coef; double newc, oldc = gex * exp(mean + var/2) + gme * mean + gva * var - 0.5 * log(var); int i = 0; do { mstep = -(gme + 2 * gva * (mean - oldm) + gex * exp(mean+var/2)) / (2 * gva + gex * exp(mean+var/2)); mean += (mstep > MAXSTEP) ? MAXSTEP : mstep; vstep = 1 / (2 * gva + gex * exp(mean + var/2)) - var; coef = var * (0.5 - gva * var); if (coef > 0) vstep /= 1 + coef; var += (vstep > MAXSTEP) ? MAXSTEP : vstep; if (++i >= 100) { cerr << label << " VarNewton: M=" << oldm << "; V=" << oldv << "; GEX=" << gex << "; GVA=" << gva << "; GME=" << gme << ": mstep = " << mstep << ", vstep = " << vstep << '\n'; mstep = vstep = 0; } } while (fabs(mstep) > MINSTEP || fabs(vstep) > MINSTEP); newc = gex * exp(mean + var/2) + gme * mean + gva * (Sqr(mean - oldm) + var) - 0.5 * log(var); if (newc > oldc + EPSILON) cerr << label << " VarNewton: M=" << oldm << "; V=" << oldv << "; GEX=" << gex << "; GVA=" << gva << "; GME=" << gme << ": diff = " << newc - oldc << '\n'; }}void Gaussian::MyPartialUpdate(IntV *indices){ MyUpdate();}void Gaussian::MyUpdate(){ if (NumChildren() == 0) { return; } DSSet grad, p0, p1; ChildGradReal(grad); ParReal(0, p0, DFlags(true)); ParReal(1, p1, DFlags(false, false, true)); VarNewton(myval.mean, myval.var, (myval.mean - p0.mean) * p1.ex + grad.mean, p1.ex/2 + grad.var, grad.ex, label); exuptodate = false; costuptodate = false;}bool Gaussian::MyClamp(double m){ myval.mean = m; myval.var = 0; exuptodate = false; return true;}bool Gaussian::MyClamp(double m, double v){ myval.mean = m; myval.var = v; exuptodate = false; return true;}bool Gaussian::MySaveState(){ if (!sstate) sstate = new DSSet; sstate->mean = myval.mean; sstate->var = log(myval.var); return true;}bool Gaussian::MySaveStep(){ if (!sstate) return false; if (!sstep) sstep = new DSSet; switch (hookeflags) { case 0: case 1: sstep->mean = myval.mean - sstate->mean; sstep->var = log(myval.var) - sstate->var; break; case 2: case 3: if (sstate->mean == 0) sstep->mean = -1.0; else sstep->mean = myval.mean / sstate->mean; sstep->var = log(myval.var) - sstate->var; break; } return true;}bool Gaussian::MySaveRepeatedState(double alpha){ if (!sstate || !sstep) { cerr << label << ": No saved state when trying to repeat step!" << endl; return false; } switch (hookeflags) { case 0: case 2: sstate->mean = sstate->mean + alpha * sstep->mean; sstate->var = sstate->var + alpha * sstep->var; break; case 1: case 3: if (sstep->mean > 0) sstate->mean = sstate->mean * exp(alpha * log(sstep->mean)); break; } return true;}void Gaussian::MyRepeatStep(double alpha){ if (!sstate || !sstep) { cerr << label << ": No saved state when trying to repeat step!" << endl; return; } switch (hookeflags) { case 0: myval.mean = sstate->mean + alpha * sstep->mean; myval.var = exp(sstate->var + alpha * sstep->var); break; case 1: myval.mean = sstate->mean + alpha * sstep->mean; break; case 2: if (sstep->mean > 0) myval.mean = sstate->mean * exp(alpha * log(sstep->mean)); myval.var = exp(sstate->var + alpha * sstep->var); break; case 3: if (sstep->mean > 0) myval.mean = sstate->mean * exp(alpha * log(sstep->mean)); break; } exuptodate = false; costuptodate = false;}bool Gaussian::MyClearStateAndStep(){ if (sstate) { delete sstate; sstate = 0; } if (sstep) { delete sstep; sstep = 0; } return true;}/* class RectifiedGaussian */RectifiedGaussian::RectifiedGaussian(Net *net, Label label, Node *m, Node *v) : Variable(net, label, m, v), BiParNode(m, v){ cost = 0; CheckParent(0, REAL_MV); CheckParent(0, REAL_ME); myval.mean = 0; myval.var = 1; UpdateExpectations(); MyUpdate();}void RectifiedGaussian::GetState(DV *state, size_t t = 0){ BBASSERT2(t == 0); state->resize(2); (*state)[0] = myval.mean; (*state)[1] = myval.var;}void RectifiedGaussian::SetState(DV *state, size_t t = 0){ BBASSERT2(t == 0 && state->size() == 2); myval.mean = (*state)[0]; myval.var = (*state)[1]; UpdateExpectations(); costuptodate = false; OutdateChild();}void RectifiedGaussian::Save(NetSaver *saver){ saver->SetNamedDSSet("myval", myval); saver->SetNamedDSSet("expectations", expectations); saver->SetNamedDouble("cost", cost); Variable::Save(saver);}string RectifiedGaussian::GetType() const { return "RectifiedGaussian"; }double RectifiedGaussian::Cost(){ if (children.empty()) { return 0; } if (!costuptodate) { DSSet mpar, vpar; ParReal(0, mpar, DFlags(true, true, false)); ParReal(1, vpar, DFlags(true, false, true)); /* C_p */ cost = 0.5 * (vpar.ex * (Sqr(expectations.mean - mpar.mean) + expectations.var + mpar.var) - vpar.mean + log(PI/2)); // + log(2) if m != const0#if RECTIFIED_BETTER_APPROX cost += log(Erfc(-1/sqrt(2.0) * mpar.mean * exp(vpar.mean / 2 + vpar.var / 8)));#endif /* C_q */ cost += -1/(2*myval.var) * (expectations.var + Sqr(expectations.mean - myval.mean)) + 0.5 * log(2/(PI*myval.var)) - log(Erfc(-myval.mean / sqrt(2*myval.var))); costuptodate = true; } return cost;}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -