📄 net.cc
字号:
//// This file is a part of the Bayes Blocks library//// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander// Ilin, Tapani Raiko, Harri Valpola and Tomas 謘tman.//// This program is free software; you can redistribute it and/or modify// it under the terms of the GNU General Public License as published by// the Free Software Foundation; either version 2, or (at your option)// any later version.//// This program is distributed in the hope that it will be useful,// but WITHOUT ANY WARRANTY; without even the implied warranty of// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the// GNU General Public License (included in file License.txt in the// program package) for more details.//// $Id: Net.cc 7 2006-10-26 10:26:41Z ah $#include "config.h"#ifdef WITH_PYTHON#include <Python.h>#define __PYTHON_H_INCLUDED__#endif#include "Templates.h"#include "Net.h"#include "Node.h"#include "Saver.h"#include "XMLSaver.h"#include "Loader.h"#include "DecayCounter.h"#include <iostream>#include <sstream>#include <stdlib.h>#include <math.h>#include <set>#include <stack>#include <utility>#include <algorithm>#include <functional>Net::Net(size_t ti){ t = ti; labelconst = 2; oldcost = 0; debuglevel = 0; dc = new TraditionalDecayCounter(); activetimeindexgroup = NULL;}Net::~Net(){ for (int i = nodes.size() - 1; i >= 0; i--) { delete nodes[i]; } delete dc; map<Label, VariableVector *>::iterator iter; for (iter = variablegroups.begin(); iter != variablegroups.end(); iter++) { VariableVector *vars = iter->second; delete vars; } map<Label, IntV *>::iterator iter2; for (iter2 = timeindexgroups.begin(); iter2 != timeindexgroups.end(); iter2++) { IntV *group = iter2->second; delete group; }}double Net::Cost(){ double c = oldcost; for (size_t i = 0; i < variables.size(); i++) { c += variables[i]->Cost(); } return c;}void Net::CleanUp(){ Node *ptr; while (!deadnodes.empty()) { ptr = deadnodes.back(); RemovePtr(ptr); // ptr->~Node(); delete ptr; }}void Net::AddVariable(Variable *ptr, Label label) { variables.push_back(ptr); variableindex[label] = ptr; AddVariableToGroups(ptr);}void Net::RemovePtr(Node *ptr){ nodes.erase(remove(nodes.begin(), nodes.end(), ptr), nodes.end()); variables.erase(remove(variables.begin(), variables.end(), ptr), variables.end()); deadnodes.erase(remove(deadnodes.begin(), deadnodes.end(), ptr), deadnodes.end()); nodeindex.erase(ptr->GetLabel()); variableindex.erase(ptr->GetLabel()); Decayer *d = dynamic_cast<Decayer *>(ptr); if (d) { UnregisterDecay(d); } Variable *v = dynamic_cast<Variable *>(ptr); if (v) { RemoveVariableFromGroups(v); }}void Net::UpdateAll(){ CleanUp(); if (activetimeindexgroup == NULL) { VariableRIterator i = variables.rbegin(); while (i != variables.rend()) { (*i++)->Update(); } } else { VariableRIterator i = variables.rbegin(); while (i != variables.rend()) { (*i++)->PartialUpdate(activetimeindexgroup); } } ProcessDecayHook("UpdateAll");}void Net::UpdateTimeInd(){ CleanUp(); for (size_t i = variables.size(); i > 0; i--) if (variables[i-1]->TimeType() == 0) variables[i-1]->Update(); ProcessDecayHook("UpdateTimeInd");}bool Net::HasVariableGroup(Label group){ return variablegroups.find(group) != variablegroups.end();}void Net::DefineVariableGroup(Label group){ if (HasVariableGroup(group)) { throw StructureException("Group allready exists"); } VariableVector *vars = new VariableVector(); VariableIterator iter; for (iter = variables.begin(); iter != variables.end(); iter++) {#ifdef BROKEN_STRING if ((*iter)->GetLabel().compare(group, 0, group.size()) == 0) { vars->push_back(*iter); }#else if ((*iter)->GetLabel().compare(0, group.size(), group) == 0) { vars->push_back(*iter); }#endif } variablegroups[group] = vars;}void Net::RemoveVariableFromGroups(Variable *ptr){ map<Label, VariableVector *>::iterator iter; for (iter = variablegroups.begin(); iter != variablegroups.end(); iter++) { VariableVector *vars = iter->second; vars->erase(remove(vars->begin(), vars->end(), ptr), vars->end()); }}void Net::AddVariableToGroups(Variable *ptr){ map<Label, VariableVector *>::iterator iter; for (iter = variablegroups.begin(); iter != variablegroups.end(); iter++) { const Label &group = iter->first; VariableVector *vars = iter->second;#ifdef BROKEN_STRING if (ptr->GetLabel().compare(group, 0, group.length()) == 0) { vars->push_back(ptr); }#else if (ptr->GetLabel().compare(0, group.length(), group) == 0) { vars->push_back(ptr); }#endif }}size_t Net::NumGroupVariables(Label group){ if (HasVariableGroup(group)) { return variablegroups[group]->size(); } else { throw StructureException("No such group."); }}Variable *Net::GetGroupVariable(Label group, size_t index){ if (index < NumGroupVariables(group)) { return (*variablegroups[group])[index]; } else { throw StructureException("Index out of range."); }}/* Updates the variables in the group. */void Net::UpdateGroup(Label group){ if (variablegroups.find(group) == variablegroups.end()) { throw StructureException("Group doesn't exist"); } CleanUp(); VariableVector *g = variablegroups[group]; VariableIterator iter; if (activetimeindexgroup == NULL) { for (iter = g->begin(); iter != g->end(); iter++) { (*iter)->Update(); } } else { for (iter = g->begin(); iter != g->end(); iter++) { (*iter)->PartialUpdate(activetimeindexgroup); } }}/* Calculates the cost arising from the group. */double Net::CostGroup(Label group){ if (variablegroups.find(group) == variablegroups.end()) { throw StructureException("Group doesn't exist"); } VariableVector *g = variablegroups[group]; double c = oldcost; for (VariableIterator iter = g->begin(); iter != g->end(); iter++) { c += (*iter)->Cost(); } return c;}bool Net::HasTimeIndexGroup(Label group){ return timeindexgroups.find(group) != timeindexgroups.end();}void Net::DefineTimeIndexGroup(Label group, IntV &indices){ if (HasTimeIndexGroup(group)) { throw StructureException("Time index group allready exists"); } // We have to copy indices because it might get freed by Python's GC IntV *ind = new IntV(indices.size()); copy(indices.begin(), indices.end(), ind->begin()); timeindexgroups[group] = ind;}void Net::EnableTimeIndexGroup(Label group){ if (activetimeindexgroup != NULL) { throw StructureException("Some time index group is allready active"); } if (!HasTimeIndexGroup(group)) { throw StructureException("No such group"); } activetimeindexgroup = timeindexgroups[group];}void Net::DisableTimeIndexGroup(Label group){ if (activetimeindexgroup == NULL || !HasTimeIndexGroup(group) || timeindexgroups[group] != activetimeindexgroup) { throw StructureException("Time index group is not active"); } activetimeindexgroup = NULL;}void Net::UpdateTimeDep(){ CleanUp(); for (size_t i = variables.size(); i > 0; i--) if (variables[i-1]->TimeType() == 1) variables[i-1]->Update(); ProcessDecayHook("UpdateTimeDep");}void Net::SaveAllStates() { for (size_t i = variables.size(); i > 0; i--) variables[i-1]->SaveState();}void Net::SaveAllSteps() { for (size_t i = variables.size(); i > 0; i--) variables[i-1]->SaveStep();}void Net::RepeatAllSteps(double alpha) { for (size_t i = variables.size(); i > 0; i--) variables[i-1]->RepeatStep(alpha);}void Net::ClearAllStatesAndSteps() { for (size_t i = variables.size(); i > 0; i--) variables[i-1]->ClearStateAndStep();}void Net::StepTime(){ CleanUp(); dc->StepTime(); for (size_t i = variables.size(); i > 0; i--) switch (variables[i-1]->TimeType()) { case 0: break; case 1: oldcost += variables[i-1]->Cost(); break; case 2: variables[i-1]->Update(); } oldcost *= Decay(); for (size_t i = oldelays.size(); i > 0; i--) oldelays[i-1]->StepTime(); ProcessDecayHook("StepTime");}void Net::ResetTime(){ dc->StepTime(); for (size_t i = variables.size(); i > 0; i--) switch (variables[i-1]->TimeType()) { case 0: break; case 1: oldcost += variables[i-1]->Cost(); break; case 2: variables[i-1]->Update(); } oldcost *= Decay(); for (size_t i = oldelays.size(); i > 0; i--) oldelays[i-1]->ResetTime(); ProcessDecayHook("ResetTime");}void Net::SetDecayCounter(DecayCounter *d){ if (dc) free(dc); dc = d;}double Net::Decay(){ return dc->GetDecay();}// Compare the first element of a pair against a fixed valuetemplate <class _Pair>struct Compare1st : public unary_function<_Pair, bool>{ Compare1st(typename _Pair::first_type __x): myval(__x) { } bool operator()(const _Pair& __x) const { return myval == __x.first; };private: typename _Pair::first_type myval;};// Topologically sort the nodesvoid Net::SortNodes(){ vector<Node *> sortednodes; std::set<Node *> addednodes; vector< pair<Node *, size_t> > dfs_stack; Node *curnode; Node *parent = 0; size_t parind; // Reserve space for sorted sequence of nodes sortednodes.reserve(nodes.size()); // Go through the list of nodes in the order they are in the vector // (in most cases they should already be sorted...) for (size_t i = 0; i < nodes.size(); i++) { curnode = nodes[i]; // This node has already found its place if (addednodes.count(curnode)) continue; // Add the node to the stack of things to process dfs_stack.push_back(pair<Node *, size_t>(curnode, 0)); // While there are nodes in the depth first search stack... while (!dfs_stack.empty()) { // Check the current node and index of the parent to process curnode = dfs_stack.back().first; parind = dfs_stack.back().second; // Proxies have no real parents and so we skip the checking of parents // While there are more parents if ((curnode->GetType() != "Proxy") && (parent = curnode->GetParent(parind)) != 0) { // Increment the index of parent to process next dfs_stack.back().second++; // If parent has not already been processed if (!(addednodes.count(parent))) { // Check that we haven't already passed through the parent // and push it to the stack to be processed next if ((find_if(dfs_stack.begin(), dfs_stack.end(), Compare1st< pair<Node *, size_t> >(parent)) == dfs_stack.end())) { dfs_stack.push_back(pair<Node *, size_t>(parent, 0)); } // We've been here before. This really shouldn't happen // if the graph is a DAG. else { throw StructureException(); } } } // All parents have been accounted for, so this is the place for // our brave little node. else { sortednodes.push_back(curnode); (void)addednodes.insert(curnode); dfs_stack.pop_back(); } } } // And we're done. nodes = sortednodes;}void Net::CheckStructure(){ std::set<Node *> checkednodes;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -