📄 node.cc
字号:
//// This file is a part of the Bayes Blocks library//// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander// Ilin, Tapani Raiko, Harri Valpola and Tomas 謘tman.//// This program is free software; you can redistribute it and/or modify// it under the terms of the GNU General Public License as published by// the Free Software Foundation; either version 2, or (at your option)// any later version.//// This program is distributed in the hope that it will be useful,// but WITHOUT ANY WARRANTY; without even the implied warranty of// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the// GNU General Public License (included in file License.txt in the// program package) for more details.//// $Id: Node.cc 7 2006-10-26 10:26:41Z ah $#include "config.h"#ifdef WITH_PYTHON#include <Python.h>#define __PYTHON_H_INCLUDED__#endif#include "Templates.h"#include "Net.h"#include "Node.h"#include "Saver.h"#include <algorithm>//#include <assert.h>#include "SpecFun.h"const double MINSTEP = 1e-4;const double MAXSTEP = 4;const double EPSILON = 1.5e-8; // ~ sqrt(eps) = 1.4901e-08const double NL_EPSILON = 1e-10;const double PI = 3.14159265358979323846;const double _5LOG2PI = 0.5 * log(2*PI);const double RECTLIMIT = -30;#define RECTIFIED_BETTER_APPROX 0const double CATEGLIMIT = 1e-40;const double GAUSSRECTLIMIT = 1e-40;inline double sign(double d){ if (d < 0) return -1; if (d > 0) return 1; else return 0;}// abstract class NodeNode::Node(Net *ptr, Label mylabel){ persist = 0; net = ptr; if (net->GetNode(mylabel)) { label = net->GetNextLabel(mylabel); } else label = mylabel; dying = false; net->AddNode(this, label); timetype = 0;}void Node::Die(int verbose){ size_t i; if (verbose) cout << "Node " << GetLabel() << " of type " << GetType() << " is dying" << endl; dying = true; net->NotifyDeath(this); for (i = 0; i < NumParents(); i++) GetParent(i)->NotifyDeath(this, verbose); for (i = 0; i < children.size(); i++) children[i]->NotifyDeath(this, verbose);}void Node::NotifyDeath(Node *ptr, int verbose){ int par, child; NodeIterator it; if (dying) return; par = RemoveParent(ptr); it = remove(children.begin(), children.end(), ptr); child = children.end() - it; children.erase(it, children.end()); if (par || child) if ((1 & persist) && par) Die(verbose); else if ((2 & persist) && NumParents() == 0) Die(verbose); else if ((4 & persist) && children.empty()) Die(verbose); else if ((8 & persist) && NumParents() == 1) { while (children.size()) { children[0]->ReplacePtr(this, GetParent(0)); GetParent(0)->AddChild(children[0]); children.erase(children.begin()); } Die(verbose); } if (par && !dying) Outdate(ptr);}void Node::ReplacePtr(Node *oldptr, Node *newptr){ size_t i; if (ParReplacePtr(oldptr, newptr)) Outdate(newptr); for (i = 0; i < children.size(); i++) if (children[i] == oldptr) children[i] = newptr;}void Node::NotifyTimeType(int tt, int verbose){ size_t i; if (timetype || !tt) return; timetype = tt; if (verbose) cout << "Node " << GetLabel() << " of type " << GetType() << ": time type = " << tt << endl; for (i = 0; i < NumParents(); i++) GetParent(i)->NotifyTimeType(tt, verbose); for (i = 0; i < children.size(); i++) children[i]->NotifyTimeType(tt, verbose);}void Node::AddParent(Node *ptr, bool really) { if (ptr->GetDying()) { ostringstream msg; msg << "Parent " << ptr->GetLabel() << " for Node " << GetLabel() << " is dying"; throw StructureException(msg.str()); } if (really) ReallyAddParent(ptr); ptr->AddChild(this); if (ptr->TimeType() && !timetype) NotifyTimeType(1); if (!ptr->TimeType() && timetype) { if (net->GetDebugLevel() > -1) { cerr << "Warning: changing parent " << ptr->GetLabel() << " time type due to" << endl; cerr << "addition of a new child " << label << endl; } ptr->NotifyTimeType(1); }}int UniParNode::RemoveParent(const Node *ptr){ if (parent == ptr) { parent = 0; return 1; } else { return 0; }}bool BiParNode::ParReplacePtr(const Node *oldptr, Node *newptr){ bool ret = false; size_t i; for (i = 0; i < NumParents(); i++) if (parents[i] == oldptr) { parents[i] = newptr; ret = true; } return ret;}int BiParNode::ParIdentity(const Node *ptr){ if (parents[0] == ptr) return 0; if (parents[1] == ptr) return 1; return -1;}void BiParNode::ReallyAddParent(Node *ptr){ parents[0] == 0 ? parents[0] = ptr : parents[1] = ptr;}int BiParNode::RemoveParent(const Node *ptr){ int par = 0; if (parents[1] == ptr) { parents[1] = 0; par++; } if (parents[0] == ptr) { parents[0] = parents[1]; parents[1] = 0; par++; } return par;}bool NParNode::ParReplacePtr(const Node *oldptr, Node *newptr){ bool ret = false; size_t i; for (i = 0; i < NumParents(); i++) if (parents[i] == oldptr) { parents[i] = newptr; parent_inds.erase(oldptr); parent_inds[newptr] = i; ret = true; } return ret;}int NParNode::ParIdentity(const Node *ptr){ map<const Node *, int>::iterator p = parent_inds.find(ptr); // The node was found if (p != parent_inds.end()) return p->second; else { return -1; }}int NParNode::RemoveParent(const Node *ptr){ int par; NodeIterator it; it = remove(parents.begin(), parents.end(), ptr); par = parents.end() - it; parents.erase(it, parents.end()); parent_inds.clear(); for (size_t i=0; i < parents.size(); i++) parent_inds[parents[i]] = i; return par;}void Node::CheckParent(size_t parnum, partype_e partype){ DVH tmp_dvh; DSSet tmp_dss; DD *tmp_dd; VDDH tmp_vddh; Node *ptr = GetParent(parnum); ostringstream msg; switch(partype) { case REAL_MV: if (!ParReal(parnum, tmp_dss, DFlags(true, true))) { msg << "Wrong type of parents in " << GetType() << " Node " << label << endl; msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":" << ptr->GetType() << endl; msg << " (Expected a scalar parent giving mean and variance)"; throw StructureException(msg.str()); } break; case REAL_ME: if (!ParReal(parnum, tmp_dss, DFlags(true, false, true))) { msg << "Wrong type of parents in " << GetType() << " Node " << label << endl; msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":" << ptr->GetType() << endl; msg << " (Expected a scalar parent giving mean and exp)"; throw StructureException(msg.str()); } break; case REAL_M: if (!ParReal(parnum, tmp_dss, DFlags(true))) { msg << "Wrong type of parents in " << GetType() << " Node " << label << endl; msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":" << ptr->GetType() << endl; msg << " (Expected a scalar parent giving mean)"; throw StructureException(msg.str()); } break; case REALV_MV: if (!ParRealV(parnum, tmp_dvh, DFlags(true, true))) { msg << "Wrong type of parents in " << GetType() << " Node " << label << endl; msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":" << ptr->GetType() << endl; msg << " (Expected a vector parent giving mean and variance)"; throw StructureException(msg.str()); } break; case REALV_ME: if (!ParRealV(parnum, tmp_dvh, DFlags(true, false, true))) { msg << "Wrong type of parents in " << GetType() << " Node " << label << endl; msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":" << ptr->GetType() << endl; msg << " (Expected a vector parent giving mean and exp)"; throw StructureException(msg.str()); } break; case REALV_M: if (!ParRealV(parnum, tmp_dvh, DFlags(true))) { msg << "Wrong type of parents in " << GetType() << " Node " << label << endl; msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":" << ptr->GetType() << endl; msg << " (Expected a vector parent giving mean)"; throw StructureException(msg.str()); } break; case DISCRETE: if (!ParDiscrete(parnum, tmp_dd)) { msg << "Wrong type of parents in " << GetType() << " Node " << label << endl; msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":" << ptr->GetType() << endl; msg << " (Expected a discrete parent)"; throw StructureException(msg.str()); } break; case DISCRETEV: if (!ParDiscreteV(parnum, tmp_vddh)) { msg << "Wrong type of parents in " << GetType() << " Node " << label << endl; msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":" << ptr->GetType() << endl; msg << " (Expected a discrete vector parent)"; throw StructureException(msg.str()); } break; }}void Node::ChildGradReal(DSSet &val){ for (size_t i = 0; i < children.size(); i++) children[i]->GradReal(val, this);}void Node::ChildGradRealV(DVSet &val){ for (size_t i = 0; i < children.size(); i++) children[i]->GradRealV(val, this);}void Node::ChildGradDiscrete(DD &val){ for (size_t i = 0; i < children.size(); i++) children[i]->GradDiscrete(val, this);}void Node::ChildGradDiscreteV(VDD &val){ for (size_t i = 0; i < children.size(); i++) children[i]->GradDiscreteV(val, this);}void Node::OutdateChild(){ for (size_t i = 0; i < children.size(); i++) children[i]->Outdate(this);}void Node::Save(NetSaver *saver){ size_t i; saver->SetNamedLabel("label", label); saver->StartEnumCont(NumParents(), "parents"); for (i=0; i<NumParents(); i++) saver->SetLabel(GetParent(i)->GetLabel()); saver->FinishEnumCont("parents"); saver->StartEnumCont(children.size(), "children"); for (i=0; i<children.size(); i++) saver->SetLabel(children[i]->GetLabel()); saver->FinishEnumCont("children"); saver->SetNamedInt("persist", persist); saver->SetNamedInt("timetype", timetype); saver->SetNamedBool("dying", dying);}// class Constant : public Nodevoid Constant::Save(NetSaver *saver){ saver->SetNamedDouble("cval", cval); Node::Save(saver);}// class ConstantV : public NodeConstantV::ConstantV(Net *net, Label label, DV v) : Node(net, label){ myval.mean = v; myval.var.resize(v.size()); myval.ex.resize(v.size()); for (size_t i = 0; i < v.size(); i++) { myval.ex[i] = exp(v[i]); }}void ConstantV::Save(NetSaver *saver){ saver->SetNamedDVSet("myval", myval); Node::Save(saver);}// abstract class Function : public NodeFunction::Function(Net *ptr, Label label, Node *n1, Node *n2) : Node(ptr, label){ if (n1) AddParent(n1, false); if (n2) AddParent(n2, false); uptodate = DFlags(false,false,false); persist = 1 | 4; // Functions usually need all parents and at least one child}void Function::Save(NetSaver *saver){ if (saver->GetSaveFunctionValue()) { saver->SetNamedDFlags("uptodate", uptodate); } Node::Save(saver);}// class Prod : public Function : public Nodebool Prod::GetReal(DSSet &val, DFlags req){ bool needm = req.mean && !uptodate.mean; bool needv = req.var && !uptodate.var; if (needm || needv) { DSSet p0, p1; if (!ParReal(0, p0, DFlags(true, needv)) || !ParReal(1, p1, DFlags(true, needv))) return false; if (needm) { mean = p0.mean * p1.mean; uptodate.mean = true; } if (needv) { var = (Sqr(p0.mean) + p0.var) * p1.var + p0.var * Sqr(p1.mean); uptodate.var = true; } } if (req.mean) {val.mean = mean; req.mean = false;} if (req.var) {val.var = var; req.var = false;} return req.AllFalse();}void Prod::GradReal(DSSet &val, const Node *ptr){ int ide = ParIdentity(ptr); DSSet grad, p0, p1; ChildGradReal(grad); ParReal(ide, p0, DFlags(true)); ParReal(1-ide, p1, DFlags(true, true)); val.mean += grad.mean * p1.mean + 2 * grad.var * p0.mean * p1.var; val.var += grad.var * (Sqr(p1.mean) + p1.var);}void Prod::Save(NetSaver *saver){ if (saver->GetSaveFunctionValue()) { saver->SetNamedDouble("mean", mean); saver->SetNamedDouble("var", var); } Function::Save(saver);}// class Sum2 : public Function : public Nodebool Sum2::GetReal(DSSet &val, DFlags req){ bool needm = req.mean && !uptodate.mean; bool needv = req.var && !uptodate.var; bool neede = req.ex && !uptodate.ex; if (needm || needv || neede) { DSSet p0, p1; if (!ParReal(0, p0, DFlags(needm, needv, neede)) || !ParReal(1, p1, DFlags(needm, needv, neede))) return false; if (needm) { myval.mean = p0.mean + p1.mean; uptodate.mean = true; } if (needv) { myval.var = p0.var + p1.var; uptodate.var = true; } if (neede) { myval.ex = p0.ex * p1.ex; uptodate.ex = true; } } if (req.mean) {val.mean = myval.mean; req.mean = false;} if (req.var) {val.var = myval.var; req.var = false;} if (req.ex) {val.ex = myval.ex; req.ex = false;} return req.AllFalse();}void Sum2::GradReal(DSSet &val, const Node *ptr){ int ide = ParIdentity(ptr); DSSet grad; ChildGradReal(grad); val.mean += grad.mean; val.var += grad.var; if (grad.ex) { DSSet p1; ParReal(1-ide, p1, DFlags(false, false, true)); val.ex += grad.ex * p1.ex; }}void Sum2::Save(NetSaver *saver){ if (saver->GetSaveFunctionValue()) {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -