⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 net.cc

📁 The library is a C++/Python implementation of the variational building block framework introduced in
💻 CC
📖 第 1 页 / 共 2 页
字号:
  std::set<Node *> tocheck;  std::set<Node *>::iterator curnode;  vector<Node *> stack;  Node *node;  CleanUp();  SortNodes();  for (size_t i = 0; i < variables.size(); i++) {    tocheck.insert((Node *)variables[i]);  }  for (size_t i = 0; i < nodes.size(); i++) {    if (nodes[i]->GetType() == "DelayV") {      tocheck.insert(nodes[i]);    }  }  for (curnode = tocheck.begin(); curnode != tocheck.end(); ++curnode) {    stack.clear();    stack.push_back(*curnode);    checkednodes.clear();    checkednodes.insert(*curnode);    while (!stack.empty()) {      node = stack.back();      stack.pop_back();      for (size_t i = 0; i < node->NumChildren(); i++) {	if (checkednodes.find(node->GetChild(i)) != checkednodes.end()) {	  throw StructureException("Invalid structure of the network. More than one route from " +				   (*curnode)->GetLabel() + " to " +				   node->GetChild(i)->GetLabel() + ".");	}	checkednodes.insert(node->GetChild(i));	if (tocheck.find(node->GetChild(i)) == tocheck.end()) {	  stack.push_back(node->GetChild(i));	}      }    }  }}bool Net::RegisterDecay(Decayer *d, string hook) {  multimap<string, Decayer *>::iterator i;  for (i = decay_hooks.lower_bound(hook);       i != decay_hooks.upper_bound(hook); i++) {    if ((*i).second == d)      return false;  }  if (GetDebugLevel() > 10)    cout << "Registering decay " << d << " for hook " << hook << endl;  decay_hooks.insert(multimap<string, Decayer *>::value_type(hook, d));  d->RegisterToHook(hook);  return true;}bool Net::UnregisterDecay(Decayer *d) {  multimap<string, Decayer *>::iterator i;  bool retval = false;  i = decay_hooks.begin();  while (i != decay_hooks.end()) {    if ((*i).second == d) {      if (GetDebugLevel() > 10)	cout << "Removing decay " << d << " from hook " << (*i).first << endl;      d->EraseHook((*i).first);      decay_hooks.erase(i);      retval = true;      i = decay_hooks.begin();    }    else      i++;  }  return retval;}bool Net::UnregisterDecayFromHook(Decayer *d, string hook) {  multimap<string, Decayer *>::iterator i;  for (i = decay_hooks.lower_bound(hook);       i != decay_hooks.upper_bound(hook); i++) {    if ((*i).second == d) {      if (GetDebugLevel() > 10)	cout << "Removing decay " << d << " from hook " << hook << endl;      d->EraseHook(hook);      decay_hooks.erase(i);      return true;    }  }  return false;}bool Net::ProcessDecayHook(string hook) {  multimap<string, Decayer *>::iterator i;  std::set<Decayer *> killlist;  bool val, rval = true;  if (GetDebugLevel() > 10)    cout << "Processing hook " << hook << endl;  for (i = decay_hooks.lower_bound(hook);       i != decay_hooks.upper_bound(hook); i++) {    if (GetDebugLevel() > 10)      cout << "Calling it for " << (*i).second << endl;    val = ((*i).second)->DoDecay(hook);    if (!val) {      killlist.insert((*i).second);      rval = false;    }  }  if (!rval) {    std::set<Decayer *>::iterator i;    for (i = killlist.begin(); i != killlist.end(); i++) {      UnregisterDecay(*i);    }  }  CleanUp();  return rval;}void Net::SetSumNKeepUpdated(bool keepupdated){  string type;  sumnkeepupdated=keepupdated;  for (size_t i=0; i<nodes.size(); i++) {    type = nodes[i]->GetType();    //if type.compare("SumNV",0,5) ((SumNV *)nodes[i])->SetKeepUpdated(keepupdated);    //else if type.compare("SumN",0,4) ((SumN *)nodes[i])->SetKeepUpdated(keepupdated);    if (type=="SumN") ((SumN *)nodes[i])->SetKeepUpdated(keepupdated);    if (type=="SumNV") ((SumNV *)nodes[i])->SetKeepUpdated(keepupdated);  }}void Net::Save(NetSaver *saver){  CleanUp();  SortNodes();  saver->StartNet(nodes.size() + 1, "net");  saver->StartNamedCont("header");  saver->StartEnumCont(variables.size(), "variables");  for (size_t i=0; i<variables.size(); i++)    saver->SetLabel(variables[i]->GetLabel());  saver->FinishEnumCont("variables");  saver->SetNamedInt("t", (int)t);  saver->SetNamedInt("labelconst", labelconst);  saver->SetNamedInt("debuglevel", debuglevel);  saver->SetNamedInt("node_num", (int)nodes.size());  saver->SetNamedDouble("oldcost", oldcost);  saver->StartNamedCont("decaycounter");  dc->Save(saver);  saver->FinishNamedCont("decaycounter");  saver->FinishNamedCont("header");  for (size_t i=0; i<nodes.size(); i++) {    //cout << i << ": " << nodes[i]->GetIdent() << endl;    saver->StartNode(nodes[i]->GetType());    nodes[i]->Save(saver);    saver->FinishNode(nodes[i]->GetType());  }  saver->FinishNet("net");}#ifdef WITH_MATLAB#include "MatlabSaver.h"void Net::SaveToMatFile(string fname, string varname, bool debugsave){  NetSaver *saver = new NetSaver( new MatlabSaver(fname, varname), debugsave);  this->Save(saver);  saver->SaveIt();  delete saver;}#elsevoid Net::SaveToMatFile(string fname, string varname, bool debugsave){  throw MatlabException("No Matlab support in this library version");}#endif  // WITH_MATLABvoid Net::SaveToXMLFile(string fname, bool debugsave){  NetSaver *saver = new NetSaver( new XMLSaver(fname), debugsave );  this->Save(saver);  saver->SaveIt();  delete saver;}void Net::SaveNodeToXMLFile(string fname, Node *node, bool debugsave){  NetSaver *saver = new NetSaver( new XMLSaver(fname), debugsave );  node->Save(saver);  saver->SaveIt();  delete saver;}#ifdef WITH_PYTHON#include "PythonSaver.h"PyObject *Net::SaveToPyObject(bool debugsave){  PyObject *res;  PythonSaver *psaver = new PythonSaver();  NetSaver *saver = new NetSaver(psaver, debugsave);  this->Save(saver);  res = psaver->GetTheNet();  delete saver;  //delete psaver;  return res;}#endif  // WITH_PYTHONLabel Net::GetNextLabel(Label label){  map<Label, Node *>::iterator it = nodeindex.find(label);  if (it == nodeindex.end())    return label;  int i=0;  Label newlabel = label;  ostringstream ss;  ss << label << '_' << i;  newlabel = ss.str();  while (nodeindex.find(newlabel) != nodeindex.end()) {    i = labelconst*i + 1 + (int)(((double)labelconst)*rand()/(RAND_MAX+1.0));    ostringstream ss;    ss << label << '_' << i;    newlabel = ss.str();  }  return newlabel;}bool Net::ConnectProxies(){  bool retval = true;  for (size_t i=proxies.size(); i>0; i--)    retval &= proxies[i-1]->CheckRef();  return retval;}#ifdef WITH_MATLAB#include "MatlabLoader.h"Net * LoadFromMatFile( string fname, string varname ){  NetLoader * loader = 0;  try {    loader = new NetLoader(new MatlabLoader( fname, varname ) );    loader->LoadIt();  }  catch (...) {    if (loader)      delete loader;    throw;  }  Net       * mynet = new Net( loader );  delete loader;  return mynet;}#elseNet * LoadFromMatFile( string fname, string varname ){  throw MatlabException("No Matlab support in this library version");}#endif // WITH_MATLAB/************************************  Load from file*/Net::Net(NetLoader *loader){  size_t            node_num = 0;  string            type;  map<string,int>   typeNum;          // Map of node types  Node            * newnode;  int temp;  ostringstream msg;  // Initialize the node type map  typeNum[ "Constant" ]     = cn_Constant;  typeNum[ "ConstantV" ]    = cn_ConstantV;  typeNum[ "Prod" ]         = cn_Prod;  typeNum[ "ProdV" ]        = cn_ProdV;  typeNum[ "Sum2" ]         = cn_Sum2;  typeNum[ "Sum2V" ]        = cn_Sum2V;  typeNum[ "SumN" ]         = cn_SumN;  typeNum[ "SumNV" ]        = cn_SumNV;  typeNum[ "Rectification" ] = cn_Rectification;  typeNum[ "RectificationV" ] = cn_RectificationV;  typeNum[ "DelayV" ]       = cn_DelayV;  typeNum[ "Gaussian" ]     = cn_Gaussian;  typeNum[ "GaussianV" ]    = cn_GaussianV;  typeNum[ "DelayGaussV" ]  = cn_DelayGaussV;  typeNum[ "SparseGaussV" ] = cn_SparseGaussV;  typeNum[ "RectifiedGaussian" ] = cn_RectifiedGaussian;  typeNum[ "RectifiedGaussianV" ] = cn_RectifiedGaussianV;  typeNum[ "GaussRect" ] = cn_GaussRect;  typeNum[ "GaussRectV" ] = cn_GaussRectV;  typeNum[ "GaussNonlin" ]  = cn_GaussNonlin;  typeNum[ "GaussNonlinV" ] = cn_GaussNonlinV;  typeNum[ "MoG" ] = cn_MoG;  typeNum[ "MoGV" ] = cn_MoGV;  typeNum[ "Discrete" ]     = cn_Discrete;  typeNum[ "DiscreteV" ]    = cn_DiscreteV;  typeNum[ "DiscreteDirichlet" ] = cn_DiscreteDirichlet;  typeNum[ "DiscreteDirichletV" ] = cn_DiscreteDirichletV;  typeNum[ "Dirichlet" ] = cn_Dirichlet;  typeNum[ "Proxy" ]        = cn_Proxy;  typeNum[ "Relay" ]        = cn_Relay;  typeNum[ "Evidence" ]     = cn_Evidence;  typeNum[ "EvidenceV" ]    = cn_EvidenceV;  typeNum[ "Memory" ]       = cn_Memory;  typeNum[ "OLDelayS" ]     = cn_OLDelayS;  typeNum[ "OLDelayD" ]     = cn_OLDelayD;  loader->StartNet("net");  loader->StartNamedCont("header");  loader->StartEnumCont("variables");  loader->FinishEnumCont("variables");  loader->GetNamedInt("t", temp);  t = temp;  loader->GetNamedInt("labelconst", labelconst);  loader->GetNamedInt("debuglevel", debuglevel);  loader->GetNamedInt("node_num", temp);  node_num = temp;  loader->GetNamedDouble("oldcost", oldcost);  loader->StartNamedCont("decaycounter");  dc = DecayCounter::GlobalLoader(loader);  loader->FinishNamedCont("decaycounter");  loader->FinishNamedCont("header");  activetimeindexgroup = NULL;  for (size_t i=0; i<node_num; i++) {    loader->StartNode(type);    switch( typeNum[ type ] ) {    case  0: // Doesn't exist      msg << "Unknown node in save file, claims to be " << type;      throw TypeException(msg.str());      break;    case cn_Constant:      newnode = new Constant( this, loader ); break;    case cn_ConstantV:      newnode = new ConstantV( this, loader ); break;    case cn_Prod:      newnode = new Prod( this, loader ); break;    case cn_ProdV:      newnode = new ProdV( this, loader ); break;    case cn_Sum2:      newnode = new Sum2( this, loader ); break;    case cn_Sum2V:      newnode = new Sum2V( this, loader ); break;    case cn_SumN:      newnode = new SumN( this, loader ); break;    case cn_SumNV:      newnode = new SumNV( this, loader ); break;    case cn_Rectification:      newnode = new Rectification( this, loader ); break;    case cn_RectificationV:      newnode = new RectificationV( this, loader ); break;    case cn_DelayV:      newnode = new DelayV( this, loader ); break;    case cn_Gaussian:      newnode = new Gaussian( this, loader ); break;    case cn_GaussianV:      newnode = new GaussianV( this, loader ); break;    case cn_DelayGaussV:      newnode = new DelayGaussV( this, loader ); break;    case cn_SparseGaussV:      newnode = new SparseGaussV( this, loader ); break;    case cn_RectifiedGaussian:      newnode = new RectifiedGaussian( this, loader ); break;    case cn_RectifiedGaussianV:      newnode = new RectifiedGaussianV( this, loader ); break;    case cn_GaussRect:      newnode = new GaussRect( this, loader ); break;    case cn_GaussRectV:      newnode = new GaussRectV( this, loader ); break;    case cn_GaussNonlin:      newnode = new GaussNonlin( this, loader ); break;    case cn_GaussNonlinV:      newnode = new GaussNonlinV( this, loader ); break;    case cn_MoG:      newnode = new MoG( this, loader ); break;    case cn_MoGV:      newnode = new MoGV( this, loader ); break;    case cn_Discrete:      newnode = new Discrete( this, loader ); break;    case cn_DiscreteV:      newnode = new DiscreteV( this, loader ); break;    case cn_DiscreteDirichlet:      newnode = new DiscreteDirichlet( this, loader); break;    case cn_DiscreteDirichletV:      newnode = new DiscreteDirichletV( this, loader); break;    case cn_Dirichlet:      newnode = new Dirichlet( this, loader); break;    case cn_Proxy:      newnode = new Proxy( this, loader ); break;    case cn_Relay:      newnode = new Relay( this, loader ); break;    case cn_Evidence:      newnode = new Evidence( this, loader ); break;    case cn_EvidenceV:      newnode = new EvidenceV( this, loader ); break;    case cn_Memory:      newnode = new Memory( this, loader ); break;    case cn_OLDelayS:      newnode = new OLDelayS( this, loader ); break;    case cn_OLDelayD:      newnode = new OLDelayD( this, loader ); break;    default:      msg << "Net::Net(loader): No handler for node type " << type;      throw TypeException(msg.str());    }        loader->FinishNode(type, bind1st(mem_fun(&Node::ReallyAddParent), newnode));    //cout << i << ": " << newnode->GetIdent() << endl;  }  loader->FinishNet("net");  this->ConnectProxies();} // End of "Net::Net"#ifdef WITH_PYTHON#include "PythonLoader.h"Net *CreateNetFromPyObject(PyObject *obj){  NetLoader *loader = new NetLoader(new PythonLoader(obj));  Net *mynet = new Net(loader);  delete loader;  return mynet;}#endif  // WITH_PYTHON

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -