📄 node.cc
字号:
if (!clamped) { /* C_q */ c -= log(myval[k]->var[i]) + log2pi + 1; } cost += 0.5 * dpar[i][k] * c; } } costuptodate = true; } return cost;}bool MoGV::GetRealV(DVH &val, DFlags req){ val.vec = &expts; return !req.ex;}void MoGV::GetMyvalV(DVH &val, int k){ val.vec = myval[k];}bool MoGV::IsMeanParent(const Node *ptr){ return WhichMeanParent(ptr) != -1;}bool MoGV::IsVarParent(const Node *ptr){ return WhichVarParent(ptr) != -1;}int MoGV::WhichParent(const Node *ptr, const vector<Node*> &parents){ for (size_t i = 0; i < parents.size(); i++) { if (ptr == parents[i]) { return i; } } return -1;}int MoGV::WhichMeanParent(const Node *ptr){ return WhichParent(ptr, means);}int MoGV::WhichVarParent(const Node *ptr){ return WhichParent(ptr, vars);}void MoGV::GradReal(DSSet &val, const Node *ptr){ //cout << "MoGV::GradReal() begin" << endl; if (!clamped && children.empty()) { return; } if (IsMeanParent(ptr)) { //cout << "mean" << endl; int k = WhichMeanParent(ptr); DVH mpar, vpar; VDDH dpar; means[k]->GetRealV(mpar, DFlags(true)); vars[k]->GetRealV(vpar, DFlags(false, false, true)); ParDiscreteV(0, dpar); for (size_t i = 0; i < net->Time(); i++) { val.mean += dpar[i][k] * (mpar.Mean(i) - myval[k]->mean[i]) * vpar.Exp(i); val.var += dpar[i][k] * vpar.Exp(i) / 2; } } else if (IsVarParent(ptr)) { //cout << "var" << endl; int k = WhichVarParent(ptr); //cout << "k = " << k << endl; DVH mpar; VDDH dpar; //cout << "GetRealV" << endl; means[k]->GetRealV(mpar, DFlags(true, true)); //cout << "ParDiscreteV" << endl; ParDiscreteV(0, dpar); for (size_t i = 0; i < net->Time(); i++) { val.mean -= dpar[i][k] * 0.5; val.ex += dpar[i][k] * (Sqr(myval[k]->mean[i] - mpar.Mean(i)) + mpar.Var(i) + myval[k]->var[i]) / 2; } } else { BBASSERT2(false); } //cout << "MoGV::GradReal() end" << endl;}void MoGV::GradRealV(DVSet &val, const Node *ptr){ //cout << "MoGV::GradRealV()" << endl; if (!clamped && children.empty()) { return; } if (IsMeanParent(ptr)) { int k = WhichMeanParent(ptr); DVH mpar, vpar; means[k]->GetRealV(mpar, DFlags(true)); vars[k]->GetRealV(vpar, DFlags(false, false, true)); VDDH dpar; ParDiscreteV(0, dpar); val.mean.resize(net->Time()); val.var.resize(net->Time()); for (size_t i = 0; i < net->Time(); i++) { val.mean[i] += dpar[i][k] * (mpar.Mean(i) - myval[k]->mean[i]) * vpar.Exp(i); val.var[i] += dpar[i][k] * vpar.Exp(i) / 2; } } else if (IsVarParent(ptr)) { int k = WhichVarParent(ptr); DVH mpar; means[k]->GetRealV(mpar, DFlags(true, true)); VDDH dpar; ParDiscreteV(0, dpar); val.mean.resize(net->Time()); val.ex.resize(net->Time()); for (size_t i = 0; i < net->Time(); i++) { val.mean[i] -= dpar[i][k] * 0.5; val.ex[i] += dpar[i][k] * (Sqr(myval[k]->mean[i] - mpar.Mean(i)) + mpar.Var(i) + myval[k]->var[i]) / 2; } } else { BBASSERT2(false); }}void MoGV::GradDiscreteV(VDD &val, const Node *ptr){ //cout << "MoGV::GradDiscreteV begin" << endl; if (!clamped && children.empty()) { return; } BBASSERT2(ptr == GetParent(0)); val.Resize(net->Time()); val.ResizeDD(numComponents); DVSet grad; ChildGradRealV(grad); for (size_t k = 0; k < numComponents; k++) { DVH mpar, vpar; means[k]->GetRealV(mpar, DFlags(true, true)); vars[k]->GetRealV(vpar, DFlags(true, false, true)); DVSet &mv = *(myval[k]); for (size_t i = 0; i < net->Time(); i++) { // constants w.r.t. k are dropped because they only affect the scaling // C_ps val[i][k] += 0.5 * (vpar.Exp(i) * (Sqr(mv.mean[i] - mpar.Mean(i)) + mpar.Var(i) + mv.var[i]) - vpar.Mean(i)); // C_qs val[i][k] -= 0.5 * log(mv.var[i]); // C_px double evx = 2 * grad.var[i]; double x = expts.mean[i] - grad.mean[i] / evx; val[i][k] += 0.5 * (evx * (Sqr(x - mv.mean[i]) + mv.var[i])); } } //cout << "MoGV::GradDiscreteV end" << endl;}string MoGV::GetType() const{ return "MoGV";}void MoGV::Save(NetSaver *saver){ saver->SetNamedDVSet("expts", expts); saver->SetNamedDouble("cost", cost); BBASSERT2(means.size() == numComponents); BBASSERT2(vars.size() == numComponents); BBASSERT2(myval.size() == numComponents); saver->StartEnumCont(numComponents, "means"); for (size_t i = 0; i < numComponents; i++) { saver->SetLabel(means[i]->GetLabel()); } saver->FinishEnumCont("means"); saver->StartEnumCont(numComponents, "vars"); for (size_t i = 0; i < numComponents; i++) { saver->SetLabel(vars[i]->GetLabel()); } saver->FinishEnumCont("vars"); saver->StartEnumCont(numComponents, "myval"); for (size_t i = 0; i < numComponents; i++) { saver->SetDVSet(*myval[i]); } saver->FinishEnumCont("myval"); Variable::Save(saver);}void MoGV::AddComponent(Node *m, Node *v){ if (means.size() + 1 > NumComponents()) throw StructureException("MoGV::AddComponent: too many components"); BBASSERT2(means.size() == vars.size()); BBASSERT2(NumParents() == 2 * means.size() + 1); means.push_back(m); vars.push_back(v); DVSet* val = new DVSet(); val->mean.resize(net->Time()); val->var.resize(net->Time()); myval.push_back(val); BBASSERT2(myval.size() == means.size()); AddParent(m, true); AddParent(v, true);} size_t MoGV::NumComponents(){ VDDH dhandle; Node* d = GetParent(0); BBASSERT2(d != NULL); d->GetDiscreteV(dhandle); BBASSERT2(dhandle.vec != NULL); numComponents = dhandle.vec->DDsize(); return numComponents;}void MoGV::MyUpdate(){ //cout << "MoGV::MyUpdate() begin" << endl; if (NumChildren() == 0) { return; } size_t K = NumComponents(); size_t N = net->Time(); BBASSERT2(means.size() == K && vars.size() == K); DVH mpar, vpar; DFlags mflags(true); DFlags vflags(false, false, true); DVSet grad; ChildGradRealV(grad); BBASSERT2(grad.mean.size() > 0); for (size_t i = 0; i < N; i++) { for (size_t k = 0; k < K; k++) { means[k]->GetRealV(mpar, mflags); vars[k]->GetRealV(vpar, vflags); myval[k]->var[i] = 1 / (2 * grad.var[i] + vpar.Exp(i)); myval[k]->mean[i] = myval[k]->var[i] * (2 * grad.var[i] * expts.mean[i] - grad.mean[i] + vpar.Exp(i) * mpar.Mean(i)); } } ComputeExpectations(); costuptodate = false; //cout << "MoGV::MyUpdate() end" << endl;}void MoGV::ComputeExpectations(){ VDDH dpar; ParDiscreteV(0, dpar); size_t N = net->Time(); size_t K = NumComponents(); for (size_t i = 0; i < N; i++) { double mean = 0.0; double var = 0.0; for (size_t k = 0; k < K; k++) { mean += dpar[i][k] * myval[k]->mean[i]; var += dpar[i][k] * (myval[k]->var[i] + Sqr(myval[k]->mean[i])); } var -= Sqr(mean); expts.mean[i] = mean; expts.var[i] = var; }}bool MoGV::MyClamp(const DV &m){ if (m.size() == expts.mean.size()) { copy(m.begin(), m.end(), expts.mean.begin()); } else { ostringstream msg; msg << "MoGV::MyClamp: wrong vector size " << m.size() << " != " << expts.mean.size(); throw TypeException(msg.str()); } fill(expts.var.begin(), expts.var.end(), 0.0); return true;}/* MoG */MoG::MoG(Net *net, Label label, Node *d) : Variable(net, label, d), NParNode(d){ cost = 0; CheckParent(0, DISCRETE); numComponents = NumComponents();}double MoG::Cost(){ if (!clamped && children.empty()) return 0; if (!costuptodate) { DSSet mpar, vpar; DFlags mflags(true, true); DFlags vflags(true, false, true); DD *dpar; ParDiscrete(0, dpar); double c = 0; cost = 0; double log2pi = log(2*PI); for (size_t k = 0; k < numComponents; k++) { means[k]->GetReal(mpar, mflags); vars[k]->GetReal(vpar, vflags); /* C_p */ c = vpar.ex * (Sqr(myval[k]->mean - mpar.mean) + mpar.var + myval[k]->var) - vpar.mean + log2pi; if (!clamped) { /* C_q */ c -= log(myval[k]->var) + log2pi + 1; } cost += 0.5 * dpar->Get(k) * c; } } costuptodate = true; return cost;}bool MoG::GetReal(DSSet &val, DFlags req){ val = expts; return !req.ex;}bool MoG::IsMeanParent(const Node *ptr){ return WhichMeanParent(ptr) != -1;}bool MoG::IsVarParent(const Node *ptr){ return WhichVarParent(ptr) != -1;}int MoG::WhichParent(const Node *ptr, const vector<Node*> &parents){ for (size_t i = 0; i < parents.size(); i++) { if (ptr == parents[i]) { return i; } } return -1;}int MoG::WhichMeanParent(const Node *ptr){ return WhichParent(ptr, means);}int MoG::WhichVarParent(const Node *ptr){ return WhichParent(ptr, vars);}void MoG::GradReal(DSSet &val, const Node *ptr){ //cout << "MoG::GradReal() begin" << endl; if (!clamped && children.empty()) { return; } if (IsMeanParent(ptr)) { //cout << "mean" << endl; int k = WhichMeanParent(ptr); DSSet mpar, vpar; DD *dpar; means[k]->GetReal(mpar, DFlags(true)); vars[k]->GetReal(vpar, DFlags(false, false, true)); ParDiscrete(0, dpar); val.mean += dpar->Get(k) * (mpar.mean - myval[k]->mean) * vpar.ex; val.var += dpar->Get(k) * vpar.ex / 2; } else if (IsVarParent(ptr)) { //cout << "var" << endl; int k = WhichVarParent(ptr); //cout << "k = " << k << endl; DSSet mpar; DD *dpar; //cout << "GetRealV" << endl; means[k]->GetReal(mpar, DFlags(true, true)); //cout << "ParDiscreteV" << endl; ParDiscrete(0, dpar); val.mean -= dpar->Get(k) * 0.5; val.ex += dpar->Get(k) * (Sqr(myval[k]->mean - mpar.mean) + mpar.var + myval[k]->var) / 2; } else { BBASSERT2(false); } //cout << "MoG::GradReal() end" << endl;}void MoG::GradDiscrete(DD &val, const Node *ptr){ //cout << "MoG::GradDiscreteV begin" << endl; if (!clamped && children.empty()) { return; } val.Resize(numComponents); BBASSERT2(ptr == GetParent(0)); DSSet grad; ChildGradReal(grad); for (size_t k = 0; k < numComponents; k++) { DSSet mpar, vpar; means[k]->GetReal(mpar, DFlags(true, true)); vars[k]->GetReal(vpar, DFlags(true, false, true)); DSSet &mv = *(myval[k]); double tmp; // C_ps tmp = 0.5 * (vpar.ex * (Sqr(mv.mean - mpar.mean) + mpar.var + mv.var) - vpar.mean); // C_qs tmp -= 0.5 * log(mv.var); // C_px double evx = 2 * grad.var; double x = expts.mean - grad.mean / evx; tmp += 0.5 * (evx * (Sqr(x - mv.mean) + mv.var)); val.Set(k, tmp); } //cout << "MoG::GradDiscreteV end" << endl;}string MoG::GetType() const{ return "MoG";}void MoG::Save(NetSaver *saver){ saver->SetNamedDSSet("expts", expts); saver->SetNamedDouble("cost", cost); BBASSERT2(means.size() == numComponents); BBASSERT2(vars.size() == numComponents); BBASSERT2(myval.size() == numComponents); saver->StartEnumCont(numComponents, "means"); for (size_t i = 0; i < numComponents; i++) { saver->SetLabel(means[i]->GetLabel()); } saver->FinishEnumCont("means"); saver->StartEnumCont(numComponents, "vars"); for (size_t i = 0; i < numComponents; i++) { saver->SetLabel(vars[i]->GetLabel()); } saver->FinishEnumCont("vars"); saver->StartEnumCont(numComponents, "myval"); for (size_t i = 0; i < numComponents; i++) { saver->SetDSSet(*myval[i]); } saver->FinishEnumCont("myval"); Variable::Save(saver);}void MoG::AddComponent(Node *m, Node *v){ if (means.size() + 1 > NumComponents()) { throw StructureException("MoG::AddComponent: too many components"); } BBASSERT2(means.size() == vars.size()); BBASSERT2(NumParents() == 2 * means.size() + 1); means.push_back(m); vars.push_back(v); DSSet* val = new DSSet(); myval.push_back(val); BBASSERT2(myval.size() == means.size()); AddParent(m,
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -