⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 net.cc

📁 The library is a C++/Python implementation of the variational building block framework introduced in
💻 CC
📖 第 1 页 / 共 2 页
字号:
//// This file is a part of the Bayes Blocks library//// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander// Ilin, Tapani Raiko, Harri Valpola and Tomas 謘tman.//// This program is free software; you can redistribute it and/or modify// it under the terms of the GNU General Public License as published by// the Free Software Foundation; either version 2, or (at your option)// any later version.//// This program is distributed in the hope that it will be useful,// but WITHOUT ANY WARRANTY; without even the implied warranty of// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the// GNU General Public License (included in file License.txt in the// program package) for more details.//// $Id: Net.cc 7 2006-10-26 10:26:41Z ah $#include "config.h"#ifdef WITH_PYTHON#include <Python.h>#define __PYTHON_H_INCLUDED__#endif#include "Templates.h"#include "Net.h"#include "Node.h"#include "Saver.h"#include "XMLSaver.h"#include "Loader.h"#include "DecayCounter.h"#include <iostream>#include <sstream>#include <stdlib.h>#include <math.h>#include <set>#include <stack>#include <utility>#include <algorithm>#include <functional>Net::Net(size_t ti){  t = ti;  labelconst = 2;  oldcost = 0;  debuglevel = 0;  dc = new TraditionalDecayCounter();  activetimeindexgroup = NULL;}Net::~Net(){  for (int i = nodes.size() - 1; i >= 0; i--) {    delete nodes[i];  }  delete dc;  map<Label, VariableVector *>::iterator iter;  for (iter = variablegroups.begin(); iter != variablegroups.end(); iter++) {    VariableVector *vars = iter->second;    delete vars;  }  map<Label, IntV *>::iterator iter2;  for (iter2 = timeindexgroups.begin();        iter2 != timeindexgroups.end(); iter2++) {    IntV *group = iter2->second;    delete group;  }}double Net::Cost(){  double c = oldcost;  for (size_t i = 0; i < variables.size(); i++) {    c += variables[i]->Cost();  }  return c;}void Net::CleanUp(){  Node *ptr;  while (!deadnodes.empty()) {    ptr = deadnodes.back();    RemovePtr(ptr);    // ptr->~Node();    delete ptr;  }}void Net::AddVariable(Variable *ptr, Label label) {  variables.push_back(ptr);  variableindex[label] = ptr;  AddVariableToGroups(ptr);}void Net::RemovePtr(Node *ptr){  nodes.erase(remove(nodes.begin(), nodes.end(), ptr),	      nodes.end());  variables.erase(remove(variables.begin(), variables.end(), ptr),		  variables.end());  deadnodes.erase(remove(deadnodes.begin(), deadnodes.end(), ptr),		  deadnodes.end());  nodeindex.erase(ptr->GetLabel());  variableindex.erase(ptr->GetLabel());  Decayer *d = dynamic_cast<Decayer *>(ptr);  if (d) {    UnregisterDecay(d);  }  Variable *v = dynamic_cast<Variable *>(ptr);  if (v) {    RemoveVariableFromGroups(v);  }}void Net::UpdateAll(){  CleanUp();  if (activetimeindexgroup == NULL) {    VariableRIterator i = variables.rbegin();    while (i != variables.rend()) {      (*i++)->Update();    }  } else {    VariableRIterator i = variables.rbegin();    while (i != variables.rend()) {      (*i++)->PartialUpdate(activetimeindexgroup);    }  }  ProcessDecayHook("UpdateAll");}void Net::UpdateTimeInd(){  CleanUp();  for (size_t i = variables.size(); i > 0; i--)    if (variables[i-1]->TimeType() == 0)      variables[i-1]->Update();  ProcessDecayHook("UpdateTimeInd");}bool Net::HasVariableGroup(Label group){  return variablegroups.find(group) != variablegroups.end();}void Net::DefineVariableGroup(Label group){  if (HasVariableGroup(group)) {    throw StructureException("Group allready exists");  }  VariableVector *vars = new VariableVector();  VariableIterator iter;  for (iter = variables.begin(); iter != variables.end(); iter++) {#ifdef BROKEN_STRING    if ((*iter)->GetLabel().compare(group, 0, group.size()) == 0) {      vars->push_back(*iter);    }#else    if ((*iter)->GetLabel().compare(0, group.size(), group) == 0) {      vars->push_back(*iter);    }#endif  }  variablegroups[group] = vars;}void Net::RemoveVariableFromGroups(Variable *ptr){  map<Label, VariableVector *>::iterator iter;  for (iter = variablegroups.begin(); iter != variablegroups.end(); iter++) {    VariableVector *vars = iter->second;    vars->erase(remove(vars->begin(), vars->end(), ptr), vars->end());  }}void Net::AddVariableToGroups(Variable *ptr){  map<Label, VariableVector *>::iterator iter;  for (iter = variablegroups.begin(); iter != variablegroups.end(); iter++) {    const Label &group = iter->first;    VariableVector *vars = iter->second;#ifdef BROKEN_STRING    if (ptr->GetLabel().compare(group, 0, group.length()) == 0) {      vars->push_back(ptr);    }#else    if (ptr->GetLabel().compare(0, group.length(), group) == 0) {      vars->push_back(ptr);    }#endif  }}size_t Net::NumGroupVariables(Label group){  if (HasVariableGroup(group)) {    return variablegroups[group]->size();  } else {    throw StructureException("No such group.");  }}Variable *Net::GetGroupVariable(Label group, size_t index){  if (index < NumGroupVariables(group)) {    return (*variablegroups[group])[index];  } else {    throw StructureException("Index out of range.");  }}/* Updates the variables in the group. */void Net::UpdateGroup(Label group){  if (variablegroups.find(group) == variablegroups.end()) {    throw StructureException("Group doesn't exist");  }  CleanUp();  VariableVector *g = variablegroups[group];  VariableIterator iter;  if (activetimeindexgroup == NULL) {    for (iter = g->begin(); iter != g->end(); iter++) {      (*iter)->Update();    }  } else {    for (iter = g->begin(); iter != g->end(); iter++) {      (*iter)->PartialUpdate(activetimeindexgroup);    }  }}/* Calculates the cost arising from the group. */double Net::CostGroup(Label group){  if (variablegroups.find(group) == variablegroups.end()) {    throw StructureException("Group doesn't exist");  }  VariableVector *g = variablegroups[group];  double c = oldcost;  for (VariableIterator iter = g->begin();        iter != g->end(); iter++) {    c += (*iter)->Cost();  }  return c;}bool Net::HasTimeIndexGroup(Label group){  return timeindexgroups.find(group) != timeindexgroups.end();}void Net::DefineTimeIndexGroup(Label group, IntV &indices){  if (HasTimeIndexGroup(group)) {    throw StructureException("Time index group allready exists");  }  // We have to copy indices because it might get freed by Python's GC  IntV *ind = new IntV(indices.size());  copy(indices.begin(), indices.end(), ind->begin());  timeindexgroups[group] = ind;}void Net::EnableTimeIndexGroup(Label group){  if (activetimeindexgroup != NULL) {    throw StructureException("Some time index group is allready active");  }  if (!HasTimeIndexGroup(group)) {    throw StructureException("No such group");  }    activetimeindexgroup = timeindexgroups[group];}void Net::DisableTimeIndexGroup(Label group){  if (activetimeindexgroup == NULL       || !HasTimeIndexGroup(group)      || timeindexgroups[group] != activetimeindexgroup) {    throw StructureException("Time index group is not active");  }  activetimeindexgroup = NULL;}void Net::UpdateTimeDep(){  CleanUp();  for (size_t i = variables.size(); i > 0; i--)    if (variables[i-1]->TimeType() == 1)      variables[i-1]->Update();  ProcessDecayHook("UpdateTimeDep");}void Net::SaveAllStates() {  for (size_t i = variables.size(); i > 0; i--)    variables[i-1]->SaveState();}void Net::SaveAllSteps() {  for (size_t i = variables.size(); i > 0; i--)    variables[i-1]->SaveStep();}void Net::RepeatAllSteps(double alpha) {  for (size_t i = variables.size(); i > 0; i--)    variables[i-1]->RepeatStep(alpha);}void Net::ClearAllStatesAndSteps() {  for (size_t i = variables.size(); i > 0; i--)    variables[i-1]->ClearStateAndStep();}void Net::StepTime(){  CleanUp();  dc->StepTime();  for (size_t i = variables.size(); i > 0; i--)    switch (variables[i-1]->TimeType()) {    case 0:      break;    case 1:      oldcost += variables[i-1]->Cost();      break;    case 2:      variables[i-1]->Update();    }  oldcost *= Decay();  for (size_t i = oldelays.size(); i > 0; i--)    oldelays[i-1]->StepTime();  ProcessDecayHook("StepTime");}void Net::ResetTime(){  dc->StepTime();  for (size_t i = variables.size(); i > 0; i--)    switch (variables[i-1]->TimeType()) {    case 0:      break;    case 1:      oldcost += variables[i-1]->Cost();      break;    case 2:      variables[i-1]->Update();    }  oldcost *= Decay();  for (size_t i = oldelays.size(); i > 0; i--)    oldelays[i-1]->ResetTime();  ProcessDecayHook("ResetTime");}void Net::SetDecayCounter(DecayCounter *d){  if (dc)    free(dc);  dc = d;}double Net::Decay(){  return dc->GetDecay();}// Compare the first element of a pair against a fixed valuetemplate <class _Pair>struct Compare1st : public unary_function<_Pair, bool>{  Compare1st(typename _Pair::first_type __x): myval(__x) { }  bool operator()(const _Pair& __x) const {    return myval == __x.first;  };private:  typename _Pair::first_type myval;};// Topologically sort the nodesvoid Net::SortNodes(){  vector<Node *> sortednodes;  std::set<Node *> addednodes;  vector< pair<Node *, size_t> > dfs_stack;  Node *curnode;  Node *parent = 0;  size_t parind;  // Reserve space for sorted sequence of nodes  sortednodes.reserve(nodes.size());  // Go through the list of nodes in the order they are in the vector  // (in most cases they should already be sorted...)  for (size_t i = 0; i < nodes.size(); i++) {    curnode = nodes[i];    // This node has already found its place    if (addednodes.count(curnode))      continue;    // Add the node to the stack of things to process    dfs_stack.push_back(pair<Node *, size_t>(curnode, 0));    // While there are nodes in the depth first search stack...    while (!dfs_stack.empty()) {      // Check the current node and index of the parent to process      curnode = dfs_stack.back().first;      parind = dfs_stack.back().second;      // Proxies have no real parents and so we skip the checking of parents      // While there are more parents      if ((curnode->GetType() != "Proxy") &&	  (parent = curnode->GetParent(parind)) != 0) {	// Increment the index of parent to process next	dfs_stack.back().second++;	// If parent has not already been processed	if (!(addednodes.count(parent))) {	  // Check that we haven't already passed through the parent	  // and push it to the stack to be processed next	  if ((find_if(dfs_stack.begin(), dfs_stack.end(),		       Compare1st< pair<Node *, size_t> >(parent))	       == dfs_stack.end())) {	    dfs_stack.push_back(pair<Node *, size_t>(parent, 0));	  }	  // We've been here before.  This really shouldn't happen	  // if the graph is a DAG.	  else {	    throw StructureException();	  }	}      }      // All parents have been accounted for, so this is the place for      // our brave little node.      else {	sortednodes.push_back(curnode);	(void)addednodes.insert(curnode);	dfs_stack.pop_back();      }    }  }  // And we're done.  nodes = sortednodes;}void Net::CheckStructure(){  std::set<Node *> checkednodes;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -