📄 net.cc
字号:
std::set<Node *> tocheck; std::set<Node *>::iterator curnode; vector<Node *> stack; Node *node; CleanUp(); SortNodes(); for (size_t i = 0; i < variables.size(); i++) { tocheck.insert((Node *)variables[i]); } for (size_t i = 0; i < nodes.size(); i++) { if (nodes[i]->GetType() == "DelayV") { tocheck.insert(nodes[i]); } } for (curnode = tocheck.begin(); curnode != tocheck.end(); ++curnode) { stack.clear(); stack.push_back(*curnode); checkednodes.clear(); checkednodes.insert(*curnode); while (!stack.empty()) { node = stack.back(); stack.pop_back(); for (size_t i = 0; i < node->NumChildren(); i++) { if (checkednodes.find(node->GetChild(i)) != checkednodes.end()) { throw StructureException("Invalid structure of the network. More than one route from " + (*curnode)->GetLabel() + " to " + node->GetChild(i)->GetLabel() + "."); } checkednodes.insert(node->GetChild(i)); if (tocheck.find(node->GetChild(i)) == tocheck.end()) { stack.push_back(node->GetChild(i)); } } } }}bool Net::RegisterDecay(Decayer *d, string hook) { multimap<string, Decayer *>::iterator i; for (i = decay_hooks.lower_bound(hook); i != decay_hooks.upper_bound(hook); i++) { if ((*i).second == d) return false; } if (GetDebugLevel() > 10) cout << "Registering decay " << d << " for hook " << hook << endl; decay_hooks.insert(multimap<string, Decayer *>::value_type(hook, d)); d->RegisterToHook(hook); return true;}bool Net::UnregisterDecay(Decayer *d) { multimap<string, Decayer *>::iterator i; bool retval = false; i = decay_hooks.begin(); while (i != decay_hooks.end()) { if ((*i).second == d) { if (GetDebugLevel() > 10) cout << "Removing decay " << d << " from hook " << (*i).first << endl; d->EraseHook((*i).first); decay_hooks.erase(i); retval = true; i = decay_hooks.begin(); } else i++; } return retval;}bool Net::UnregisterDecayFromHook(Decayer *d, string hook) { multimap<string, Decayer *>::iterator i; for (i = decay_hooks.lower_bound(hook); i != decay_hooks.upper_bound(hook); i++) { if ((*i).second == d) { if (GetDebugLevel() > 10) cout << "Removing decay " << d << " from hook " << hook << endl; d->EraseHook(hook); decay_hooks.erase(i); return true; } } return false;}bool Net::ProcessDecayHook(string hook) { multimap<string, Decayer *>::iterator i; std::set<Decayer *> killlist; bool val, rval = true; if (GetDebugLevel() > 10) cout << "Processing hook " << hook << endl; for (i = decay_hooks.lower_bound(hook); i != decay_hooks.upper_bound(hook); i++) { if (GetDebugLevel() > 10) cout << "Calling it for " << (*i).second << endl; val = ((*i).second)->DoDecay(hook); if (!val) { killlist.insert((*i).second); rval = false; } } if (!rval) { std::set<Decayer *>::iterator i; for (i = killlist.begin(); i != killlist.end(); i++) { UnregisterDecay(*i); } } CleanUp(); return rval;}void Net::SetSumNKeepUpdated(bool keepupdated){ string type; sumnkeepupdated=keepupdated; for (size_t i=0; i<nodes.size(); i++) { type = nodes[i]->GetType(); //if type.compare("SumNV",0,5) ((SumNV *)nodes[i])->SetKeepUpdated(keepupdated); //else if type.compare("SumN",0,4) ((SumN *)nodes[i])->SetKeepUpdated(keepupdated); if (type=="SumN") ((SumN *)nodes[i])->SetKeepUpdated(keepupdated); if (type=="SumNV") ((SumNV *)nodes[i])->SetKeepUpdated(keepupdated); }}void Net::Save(NetSaver *saver){ CleanUp(); SortNodes(); saver->StartNet(nodes.size() + 1, "net"); saver->StartNamedCont("header"); saver->StartEnumCont(variables.size(), "variables"); for (size_t i=0; i<variables.size(); i++) saver->SetLabel(variables[i]->GetLabel()); saver->FinishEnumCont("variables"); saver->SetNamedInt("t", (int)t); saver->SetNamedInt("labelconst", labelconst); saver->SetNamedInt("debuglevel", debuglevel); saver->SetNamedInt("node_num", (int)nodes.size()); saver->SetNamedDouble("oldcost", oldcost); saver->StartNamedCont("decaycounter"); dc->Save(saver); saver->FinishNamedCont("decaycounter"); saver->FinishNamedCont("header"); for (size_t i=0; i<nodes.size(); i++) { //cout << i << ": " << nodes[i]->GetIdent() << endl; saver->StartNode(nodes[i]->GetType()); nodes[i]->Save(saver); saver->FinishNode(nodes[i]->GetType()); } saver->FinishNet("net");}#ifdef WITH_MATLAB#include "MatlabSaver.h"void Net::SaveToMatFile(string fname, string varname, bool debugsave){ NetSaver *saver = new NetSaver( new MatlabSaver(fname, varname), debugsave); this->Save(saver); saver->SaveIt(); delete saver;}#elsevoid Net::SaveToMatFile(string fname, string varname, bool debugsave){ throw MatlabException("No Matlab support in this library version");}#endif // WITH_MATLABvoid Net::SaveToXMLFile(string fname, bool debugsave){ NetSaver *saver = new NetSaver( new XMLSaver(fname), debugsave ); this->Save(saver); saver->SaveIt(); delete saver;}void Net::SaveNodeToXMLFile(string fname, Node *node, bool debugsave){ NetSaver *saver = new NetSaver( new XMLSaver(fname), debugsave ); node->Save(saver); saver->SaveIt(); delete saver;}#ifdef WITH_PYTHON#include "PythonSaver.h"PyObject *Net::SaveToPyObject(bool debugsave){ PyObject *res; PythonSaver *psaver = new PythonSaver(); NetSaver *saver = new NetSaver(psaver, debugsave); this->Save(saver); res = psaver->GetTheNet(); delete saver; //delete psaver; return res;}#endif // WITH_PYTHONLabel Net::GetNextLabel(Label label){ map<Label, Node *>::iterator it = nodeindex.find(label); if (it == nodeindex.end()) return label; int i=0; Label newlabel = label; ostringstream ss; ss << label << '_' << i; newlabel = ss.str(); while (nodeindex.find(newlabel) != nodeindex.end()) { i = labelconst*i + 1 + (int)(((double)labelconst)*rand()/(RAND_MAX+1.0)); ostringstream ss; ss << label << '_' << i; newlabel = ss.str(); } return newlabel;}bool Net::ConnectProxies(){ bool retval = true; for (size_t i=proxies.size(); i>0; i--) retval &= proxies[i-1]->CheckRef(); return retval;}#ifdef WITH_MATLAB#include "MatlabLoader.h"Net * LoadFromMatFile( string fname, string varname ){ NetLoader * loader = 0; try { loader = new NetLoader(new MatlabLoader( fname, varname ) ); loader->LoadIt(); } catch (...) { if (loader) delete loader; throw; } Net * mynet = new Net( loader ); delete loader; return mynet;}#elseNet * LoadFromMatFile( string fname, string varname ){ throw MatlabException("No Matlab support in this library version");}#endif // WITH_MATLAB/************************************ Load from file*/Net::Net(NetLoader *loader){ size_t node_num = 0; string type; map<string,int> typeNum; // Map of node types Node * newnode; int temp; ostringstream msg; // Initialize the node type map typeNum[ "Constant" ] = cn_Constant; typeNum[ "ConstantV" ] = cn_ConstantV; typeNum[ "Prod" ] = cn_Prod; typeNum[ "ProdV" ] = cn_ProdV; typeNum[ "Sum2" ] = cn_Sum2; typeNum[ "Sum2V" ] = cn_Sum2V; typeNum[ "SumN" ] = cn_SumN; typeNum[ "SumNV" ] = cn_SumNV; typeNum[ "Rectification" ] = cn_Rectification; typeNum[ "RectificationV" ] = cn_RectificationV; typeNum[ "DelayV" ] = cn_DelayV; typeNum[ "Gaussian" ] = cn_Gaussian; typeNum[ "GaussianV" ] = cn_GaussianV; typeNum[ "DelayGaussV" ] = cn_DelayGaussV; typeNum[ "SparseGaussV" ] = cn_SparseGaussV; typeNum[ "RectifiedGaussian" ] = cn_RectifiedGaussian; typeNum[ "RectifiedGaussianV" ] = cn_RectifiedGaussianV; typeNum[ "GaussRect" ] = cn_GaussRect; typeNum[ "GaussRectV" ] = cn_GaussRectV; typeNum[ "GaussNonlin" ] = cn_GaussNonlin; typeNum[ "GaussNonlinV" ] = cn_GaussNonlinV; typeNum[ "MoG" ] = cn_MoG; typeNum[ "MoGV" ] = cn_MoGV; typeNum[ "Discrete" ] = cn_Discrete; typeNum[ "DiscreteV" ] = cn_DiscreteV; typeNum[ "DiscreteDirichlet" ] = cn_DiscreteDirichlet; typeNum[ "DiscreteDirichletV" ] = cn_DiscreteDirichletV; typeNum[ "Dirichlet" ] = cn_Dirichlet; typeNum[ "Proxy" ] = cn_Proxy; typeNum[ "Relay" ] = cn_Relay; typeNum[ "Evidence" ] = cn_Evidence; typeNum[ "EvidenceV" ] = cn_EvidenceV; typeNum[ "Memory" ] = cn_Memory; typeNum[ "OLDelayS" ] = cn_OLDelayS; typeNum[ "OLDelayD" ] = cn_OLDelayD; loader->StartNet("net"); loader->StartNamedCont("header"); loader->StartEnumCont("variables"); loader->FinishEnumCont("variables"); loader->GetNamedInt("t", temp); t = temp; loader->GetNamedInt("labelconst", labelconst); loader->GetNamedInt("debuglevel", debuglevel); loader->GetNamedInt("node_num", temp); node_num = temp; loader->GetNamedDouble("oldcost", oldcost); loader->StartNamedCont("decaycounter"); dc = DecayCounter::GlobalLoader(loader); loader->FinishNamedCont("decaycounter"); loader->FinishNamedCont("header"); activetimeindexgroup = NULL; for (size_t i=0; i<node_num; i++) { loader->StartNode(type); switch( typeNum[ type ] ) { case 0: // Doesn't exist msg << "Unknown node in save file, claims to be " << type; throw TypeException(msg.str()); break; case cn_Constant: newnode = new Constant( this, loader ); break; case cn_ConstantV: newnode = new ConstantV( this, loader ); break; case cn_Prod: newnode = new Prod( this, loader ); break; case cn_ProdV: newnode = new ProdV( this, loader ); break; case cn_Sum2: newnode = new Sum2( this, loader ); break; case cn_Sum2V: newnode = new Sum2V( this, loader ); break; case cn_SumN: newnode = new SumN( this, loader ); break; case cn_SumNV: newnode = new SumNV( this, loader ); break; case cn_Rectification: newnode = new Rectification( this, loader ); break; case cn_RectificationV: newnode = new RectificationV( this, loader ); break; case cn_DelayV: newnode = new DelayV( this, loader ); break; case cn_Gaussian: newnode = new Gaussian( this, loader ); break; case cn_GaussianV: newnode = new GaussianV( this, loader ); break; case cn_DelayGaussV: newnode = new DelayGaussV( this, loader ); break; case cn_SparseGaussV: newnode = new SparseGaussV( this, loader ); break; case cn_RectifiedGaussian: newnode = new RectifiedGaussian( this, loader ); break; case cn_RectifiedGaussianV: newnode = new RectifiedGaussianV( this, loader ); break; case cn_GaussRect: newnode = new GaussRect( this, loader ); break; case cn_GaussRectV: newnode = new GaussRectV( this, loader ); break; case cn_GaussNonlin: newnode = new GaussNonlin( this, loader ); break; case cn_GaussNonlinV: newnode = new GaussNonlinV( this, loader ); break; case cn_MoG: newnode = new MoG( this, loader ); break; case cn_MoGV: newnode = new MoGV( this, loader ); break; case cn_Discrete: newnode = new Discrete( this, loader ); break; case cn_DiscreteV: newnode = new DiscreteV( this, loader ); break; case cn_DiscreteDirichlet: newnode = new DiscreteDirichlet( this, loader); break; case cn_DiscreteDirichletV: newnode = new DiscreteDirichletV( this, loader); break; case cn_Dirichlet: newnode = new Dirichlet( this, loader); break; case cn_Proxy: newnode = new Proxy( this, loader ); break; case cn_Relay: newnode = new Relay( this, loader ); break; case cn_Evidence: newnode = new Evidence( this, loader ); break; case cn_EvidenceV: newnode = new EvidenceV( this, loader ); break; case cn_Memory: newnode = new Memory( this, loader ); break; case cn_OLDelayS: newnode = new OLDelayS( this, loader ); break; case cn_OLDelayD: newnode = new OLDelayD( this, loader ); break; default: msg << "Net::Net(loader): No handler for node type " << type; throw TypeException(msg.str()); } loader->FinishNode(type, bind1st(mem_fun(&Node::ReallyAddParent), newnode)); //cout << i << ": " << newnode->GetIdent() << endl; } loader->FinishNet("net"); this->ConnectProxies();} // End of "Net::Net"#ifdef WITH_PYTHON#include "PythonLoader.h"Net *CreateNetFromPyObject(PyObject *obj){ NetLoader *loader = new NetLoader(new PythonLoader(obj)); Net *mynet = new Net(loader); delete loader; return mynet;}#endif // WITH_PYTHON
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -