⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 node.cc

📁 The library is a C++/Python implementation of the variational building block framework introduced in
💻 CC
📖 第 1 页 / 共 5 页
字号:
//// This file is a part of the Bayes Blocks library//// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander// Ilin, Tapani Raiko, Harri Valpola and Tomas 謘tman.//// This program is free software; you can redistribute it and/or modify// it under the terms of the GNU General Public License as published by// the Free Software Foundation; either version 2, or (at your option)// any later version.//// This program is distributed in the hope that it will be useful,// but WITHOUT ANY WARRANTY; without even the implied warranty of// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the// GNU General Public License (included in file License.txt in the// program package) for more details.//// $Id: Node.cc 7 2006-10-26 10:26:41Z ah $#include "config.h"#ifdef WITH_PYTHON#include <Python.h>#define __PYTHON_H_INCLUDED__#endif#include "Templates.h"#include "Net.h"#include "Node.h"#include "Saver.h"#include <algorithm>//#include <assert.h>#include "SpecFun.h"const double MINSTEP = 1e-4;const double MAXSTEP = 4;const double EPSILON = 1.5e-8;  // ~ sqrt(eps) = 1.4901e-08const double NL_EPSILON = 1e-10;const double PI = 3.14159265358979323846;const double _5LOG2PI = 0.5 * log(2*PI);const double RECTLIMIT = -30;#define RECTIFIED_BETTER_APPROX 0const double CATEGLIMIT = 1e-40;const double GAUSSRECTLIMIT = 1e-40;inline double sign(double d){  if (d < 0)    return -1;  if (d > 0)    return 1;  else    return 0;}// abstract class NodeNode::Node(Net *ptr, Label mylabel){  persist = 0;  net = ptr;  if (net->GetNode(mylabel)) {    label = net->GetNextLabel(mylabel);  }  else    label = mylabel;  dying = false;  net->AddNode(this, label);  timetype = 0;}void Node::Die(int verbose){  size_t i;  if (verbose)    cout << "Node " << GetLabel() << " of type " << GetType() <<      " is dying" << endl;  dying = true;  net->NotifyDeath(this);  for (i = 0; i < NumParents(); i++)    GetParent(i)->NotifyDeath(this, verbose);  for (i = 0; i < children.size(); i++)    children[i]->NotifyDeath(this, verbose);}void Node::NotifyDeath(Node *ptr, int verbose){  int par, child;  NodeIterator it;  if (dying) return;  par = RemoveParent(ptr);  it = remove(children.begin(), children.end(), ptr);  child = children.end() - it;  children.erase(it, children.end());  if (par || child)    if ((1 & persist) && par)      Die(verbose);    else if ((2 & persist) && NumParents() == 0)      Die(verbose);    else if ((4 & persist) && children.empty())      Die(verbose);    else if ((8 & persist) && NumParents() == 1) {      while (children.size()) {	children[0]->ReplacePtr(this, GetParent(0));	GetParent(0)->AddChild(children[0]);	children.erase(children.begin());      }      Die(verbose);    }  if (par && !dying)    Outdate(ptr);}void Node::ReplacePtr(Node *oldptr, Node *newptr){  size_t i;  if (ParReplacePtr(oldptr, newptr))    Outdate(newptr);  for (i = 0; i < children.size(); i++)    if (children[i] == oldptr)      children[i] = newptr;}void Node::NotifyTimeType(int tt, int verbose){  size_t i;  if (timetype || !tt)    return;  timetype = tt;  if (verbose)    cout << "Node " << GetLabel() << " of type " << GetType() <<      ": time type = " << tt << endl;  for (i = 0; i < NumParents(); i++)    GetParent(i)->NotifyTimeType(tt, verbose);  for (i = 0; i < children.size(); i++)    children[i]->NotifyTimeType(tt, verbose);}void Node::AddParent(Node *ptr, bool really) {  if (ptr->GetDying()) {    ostringstream msg;    msg << "Parent " << ptr->GetLabel() << " for Node "	<< GetLabel() << " is dying";    throw StructureException(msg.str());  }  if (really)    ReallyAddParent(ptr);  ptr->AddChild(this);  if (ptr->TimeType() && !timetype)    NotifyTimeType(1);  if (!ptr->TimeType() && timetype) {    if (net->GetDebugLevel() > -1) {      cerr << "Warning: changing parent " << ptr->GetLabel() << " time type due to"	   << endl;      cerr << "addition of a new child " << label << endl;    }    ptr->NotifyTimeType(1);  }}int UniParNode::RemoveParent(const Node *ptr){  if (parent == ptr) {    parent = 0;    return 1;  }  else {    return 0;  }}bool BiParNode::ParReplacePtr(const Node *oldptr, Node *newptr){  bool ret = false;  size_t i;  for (i = 0; i < NumParents(); i++)    if (parents[i] == oldptr) {      parents[i] = newptr;      ret = true;    }  return ret;}int BiParNode::ParIdentity(const Node *ptr){  if (parents[0] == ptr)    return 0;  if (parents[1] == ptr)    return 1;  return -1;}void BiParNode::ReallyAddParent(Node *ptr){  parents[0] == 0 ? parents[0] = ptr : parents[1] = ptr;}int BiParNode::RemoveParent(const Node *ptr){  int par = 0;  if (parents[1] == ptr) {    parents[1] = 0;    par++;  }  if (parents[0] == ptr) {    parents[0] = parents[1];    parents[1] = 0;    par++;  }  return par;}bool NParNode::ParReplacePtr(const Node *oldptr, Node *newptr){  bool ret = false;  size_t i;  for (i = 0; i < NumParents(); i++)    if (parents[i] == oldptr) {      parents[i] = newptr;      parent_inds.erase(oldptr);      parent_inds[newptr] = i;      ret = true;    }  return ret;}int NParNode::ParIdentity(const Node *ptr){  map<const Node *, int>::iterator p = parent_inds.find(ptr);  // The node was found  if (p != parent_inds.end())    return p->second;  else {    return -1;  }}int NParNode::RemoveParent(const Node *ptr){  int par;  NodeIterator it;  it = remove(parents.begin(), parents.end(), ptr);  par = parents.end() - it;  parents.erase(it, parents.end());  parent_inds.clear();  for (size_t i=0; i < parents.size(); i++)    parent_inds[parents[i]] = i;  return par;}void Node::CheckParent(size_t parnum, partype_e partype){  DVH tmp_dvh;  DSSet tmp_dss;  DD *tmp_dd;  VDDH tmp_vddh;  Node *ptr = GetParent(parnum);  ostringstream msg;  switch(partype) {  case REAL_MV:    if (!ParReal(parnum, tmp_dss, DFlags(true, true))) {      msg << "Wrong type of parents in " << GetType() << " Node "	  << label << endl;      msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":"	  << ptr->GetType() << endl;      msg << " (Expected a scalar parent giving mean and variance)";      throw StructureException(msg.str());    }    break;  case REAL_ME:    if (!ParReal(parnum, tmp_dss, DFlags(true, false, true))) {      msg << "Wrong type of parents in " << GetType() << " Node "	  << label << endl;      msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":"	  << ptr->GetType() << endl;      msg << " (Expected a scalar parent giving mean and exp)";      throw StructureException(msg.str());    }    break;  case REAL_M:    if (!ParReal(parnum, tmp_dss, DFlags(true))) {      msg << "Wrong type of parents in " << GetType() << " Node "	  << label << endl;      msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":"	  << ptr->GetType() << endl;      msg << " (Expected a scalar parent giving mean)";      throw StructureException(msg.str());    }    break;  case REALV_MV:    if (!ParRealV(parnum, tmp_dvh, DFlags(true, true))) {      msg << "Wrong type of parents in " << GetType() << " Node "	  << label << endl;      msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":"	  << ptr->GetType() << endl;      msg << " (Expected a vector parent giving mean and variance)";      throw StructureException(msg.str());    }    break;  case REALV_ME:    if (!ParRealV(parnum, tmp_dvh, DFlags(true, false, true))) {      msg << "Wrong type of parents in " << GetType() << " Node "	  << label << endl;      msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":"	  << ptr->GetType() << endl;      msg << " (Expected a vector parent giving mean and exp)";      throw StructureException(msg.str());    }    break;  case REALV_M:    if (!ParRealV(parnum, tmp_dvh, DFlags(true))) {      msg << "Wrong type of parents in " << GetType() << " Node "	  << label << endl;      msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":"	  << ptr->GetType() << endl;      msg << " (Expected a vector parent giving mean)";      throw StructureException(msg.str());    }    break;  case DISCRETE:    if (!ParDiscrete(parnum, tmp_dd)) {      msg << "Wrong type of parents in " << GetType() << " Node "	  << label << endl;      msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":"	  << ptr->GetType() << endl;      msg << " (Expected a discrete parent)";      throw StructureException(msg.str());    }    break;  case DISCRETEV:    if (!ParDiscreteV(parnum, tmp_vddh)) {      msg << "Wrong type of parents in " << GetType() << " Node "	  << label << endl;      msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":"	  << ptr->GetType() << endl;      msg << " (Expected a discrete vector parent)";      throw StructureException(msg.str());    }    break;  }}void Node::ChildGradReal(DSSet &val){  for (size_t i = 0; i < children.size(); i++)    children[i]->GradReal(val, this);}void Node::ChildGradRealV(DVSet &val){  for (size_t i = 0; i < children.size(); i++)    children[i]->GradRealV(val, this);}void Node::ChildGradDiscrete(DD &val){  for (size_t i = 0; i < children.size(); i++)    children[i]->GradDiscrete(val, this);}void Node::ChildGradDiscreteV(VDD &val){  for (size_t i = 0; i < children.size(); i++)    children[i]->GradDiscreteV(val, this);}void Node::OutdateChild(){  for (size_t i = 0; i < children.size(); i++)    children[i]->Outdate(this);}void Node::Save(NetSaver *saver){  size_t i;  saver->SetNamedLabel("label", label);  saver->StartEnumCont(NumParents(), "parents");  for (i=0; i<NumParents(); i++)    saver->SetLabel(GetParent(i)->GetLabel());  saver->FinishEnumCont("parents");  saver->StartEnumCont(children.size(), "children");  for (i=0; i<children.size(); i++)    saver->SetLabel(children[i]->GetLabel());  saver->FinishEnumCont("children");  saver->SetNamedInt("persist", persist);  saver->SetNamedInt("timetype", timetype);  saver->SetNamedBool("dying", dying);}// class Constant : public Nodevoid Constant::Save(NetSaver *saver){  saver->SetNamedDouble("cval", cval);  Node::Save(saver);}// class ConstantV : public NodeConstantV::ConstantV(Net *net, Label label, DV v) : Node(net, label){  myval.mean = v;  myval.var.resize(v.size());  myval.ex.resize(v.size());  for (size_t i = 0; i < v.size(); i++) {    myval.ex[i] = exp(v[i]);  }}void ConstantV::Save(NetSaver *saver){  saver->SetNamedDVSet("myval", myval);  Node::Save(saver);}// abstract class Function : public NodeFunction::Function(Net *ptr, Label label, Node *n1, Node *n2) :  Node(ptr, label){  if (n1)    AddParent(n1, false);  if (n2)    AddParent(n2, false);  uptodate = DFlags(false,false,false);  persist = 1 | 4; // Functions usually need all parents and at least one child}void Function::Save(NetSaver *saver){  if (saver->GetSaveFunctionValue()) {    saver->SetNamedDFlags("uptodate", uptodate);  }  Node::Save(saver);}// class Prod : public Function : public Nodebool Prod::GetReal(DSSet &val, DFlags req){  bool needm = req.mean && !uptodate.mean;  bool needv = req.var && !uptodate.var;  if (needm || needv) {    DSSet p0, p1;    if (!ParReal(0, p0, DFlags(true, needv)) ||	!ParReal(1, p1, DFlags(true, needv)))      return false;    if (needm) {      mean = p0.mean * p1.mean;      uptodate.mean = true;    }    if (needv) {      var = (Sqr(p0.mean) + p0.var) * p1.var + p0.var * Sqr(p1.mean);      uptodate.var = true;    }  }  if (req.mean) {val.mean = mean; req.mean = false;}  if (req.var) {val.var = var; req.var = false;}  return req.AllFalse();}void Prod::GradReal(DSSet &val, const Node *ptr){  int ide = ParIdentity(ptr);  DSSet grad, p0, p1;  ChildGradReal(grad);  ParReal(ide, p0, DFlags(true));  ParReal(1-ide, p1, DFlags(true, true));  val.mean += grad.mean * p1.mean + 2 * grad.var * p0.mean * p1.var;  val.var += grad.var * (Sqr(p1.mean) + p1.var);}void Prod::Save(NetSaver *saver){  if (saver->GetSaveFunctionValue()) {    saver->SetNamedDouble("mean", mean);    saver->SetNamedDouble("var", var);  }  Function::Save(saver);}// class Sum2 : public Function : public Nodebool Sum2::GetReal(DSSet &val, DFlags req){  bool needm = req.mean && !uptodate.mean;  bool needv = req.var && !uptodate.var;  bool neede = req.ex && !uptodate.ex;  if (needm || needv || neede) {    DSSet p0, p1;    if (!ParReal(0, p0, DFlags(needm, needv, neede)) ||	!ParReal(1, p1, DFlags(needm, needv, neede)))      return false;    if (needm) {      myval.mean = p0.mean + p1.mean;      uptodate.mean = true;    }    if (needv) {      myval.var = p0.var + p1.var;      uptodate.var = true;    }    if (neede) {      myval.ex = p0.ex * p1.ex;      uptodate.ex = true;    }  }  if (req.mean) {val.mean = myval.mean; req.mean = false;}  if (req.var) {val.var = myval.var; req.var = false;}  if (req.ex) {val.ex = myval.ex; req.ex = false;}  return req.AllFalse();}void Sum2::GradReal(DSSet &val, const Node *ptr){  int ide = ParIdentity(ptr);  DSSet grad;  ChildGradReal(grad);  val.mean += grad.mean;  val.var += grad.var;  if (grad.ex) {    DSSet p1;    ParReal(1-ide, p1, DFlags(false, false, true));    val.ex += grad.ex * p1.ex;  }}void Sum2::Save(NetSaver *saver){  if (saver->GetSaveFunctionValue()) {

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -