⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 main.cc

📁 The library is a C++/Python implementation of the variational building block framework introduced in
💻 CC
字号:
//// This file is a part of the Bayes Blocks library//// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander// Ilin, Tapani Raiko, Harri Valpola and Tomas 謘tman.//// This program is free software; you can redistribute it and/or modify// it under the terms of the GNU General Public License as published by// the Free Software Foundation; either version 2, or (at your option)// any later version.//// This program is distributed in the hope that it will be useful,// but WITHOUT ANY WARRANTY; without even the implied warranty of// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the// GNU General Public License (included in file License.txt in the// program package) for more details.//// $Id: main.cc 7 2006-10-26 10:26:41Z ah $#include <sstream>#include "Net.h"#include "NodeFactory.h"//#include "Node.h"double c_randn(){  double d = -6;  for (int i = 0; i < 12; i++)    d += drand48();  return d;}int main()try {  int xdim = 50, tdim = 100, sdim = 2, ix, it, is, iter = 0, i;  Net *net = new Net(tdim);  NodeFactory *fact = new NodeFactory(net);  Constant *const0 = fact->GetConstant("const0", 0);  Gaussian *mvs = fact->GetGaussian("mvs", const0, const0);  Gaussian *vvs = fact->GetGaussian("vvs", const0, const0);  Gaussian *mvx = fact->GetGaussian("mvx", const0, const0);  Gaussian *vvx = fact->GetGaussian("vvx", const0, const0);  vector<Gaussian *> vs;  vector<Gaussian *> vs0;  vector<Gaussian *> vx;  vector<Gaussian *> mx;  vector<Gaussian *> b0;  vector<Gaussian *> b1;  vector<ProdV *> prb1;  vector<DelayV *> del1;  vector<Proxy *> prx1;  vector<Gaussian *> a;  vector<DelayGaussV *> s;  vector<ProdV *> pr;  vector<Sum2V *> su;  vector<GaussianV *> x;  for (is = 0; is < sdim; is++)    vs.push_back(fact->GetGaussian("vs", mvs, vvs));  for (is = 0; is < sdim; is++)    vs0.push_back(fact->GetGaussian("vs0", const0, const0));  for (ix = 0; ix < xdim; ix++)    vx.push_back(fact->GetGaussian("vx", mvx, vvx));  for (ix = 0; ix < xdim; ix++)    mx.push_back(fact->GetGaussian("mx", const0, const0));  for (is = 0; is < sdim; is++)    b0.push_back(fact->GetGaussian("b0", const0, const0));  for (is = 0; is < sdim; is++)    b1.push_back(fact->GetGaussian("b1", const0, const0));  for (is = 0; is < sdim; is++) {    ostringstream ss;    ss << "prb1(" << ((is + sdim - 1) % sdim) << ")";    prx1.push_back(fact->GetProxy("prx1", ss.str()));    del1.push_back(fact->GetDelayV("del1", const0, prx1[is]));  }  for (is = 0; is < sdim; is++)    for (ix = 0; ix < xdim; ix++)      a.push_back(fact->GetGaussian("a", const0, const0));  for (is = 0; is < sdim; is++) {    ostringstream ss;    ss << "prb1(" << is << ")";    s.push_back(fact->GetDelayGaussV("s", del1[is], vs[is], b0[is], const0, vs0[is]));    prb1.push_back(fact->GetProdV(ss.str(), s[is], b1[is]));  }  net->ConnectProxies();  for (ix = 0; ix < xdim; ix++)    for (is = 0; is < sdim; is++)      pr.push_back(fact->GetProdV("pr", a[is * xdim + ix], s[is]));  for (ix = 0; ix < xdim; ix++) {    vector<Node *> tmp;    for (is = 0; is < sdim; is++)      tmp.push_back(pr[ix * sdim + is]);    tmp.push_back(mx[ix]);    while (tmp.size() > 1) {      Sum2V *tmp2 = fact->GetSum2V("su", tmp[0], tmp[1]);      tmp.erase(tmp.begin());      tmp.erase(tmp.begin());      tmp.push_back(tmp2);      su.push_back(tmp2);    }    x.push_back(fact->GetGaussianV("x", tmp[0], vx[ix]));  }  for (ix = 0; ix < xdim; ix++) {    DV tmp(tdim);    for (it = 0; it < tdim; it++)      tmp[it] = sin((double)it+ix)+c_randn()*0.1;    x[ix]->Clamp(tmp);  }  for (i = 0; i < sdim * xdim; i++)    a[i]->Clamp(drand48() - 0.5);  cout << iter << ": " << net->Cost() << '\n';  while (iter < 3) {    net->UpdateAll();    cout << ++iter << ": " << net->Cost() << '\n';  }  for (i = 0; i < sdim * xdim; i++)    a[i]->Unclamp();  try {    net->SaveToMatFile("mynet.mat", "mynet");  }  catch(MatlabException e) {    cout << "Error in saving: " << e.what();  }  while (iter < 20) {    net->UpdateAll();    cout << ++iter << ": " << net->Cost() << '\n';    cout.flush();  }  /*  cout << "\nVX:\n";  for (ix = 0; ix < xdim; ix++)    cout << 1/vx[ix]->GetExp() << ' ';  cout << "\n\nVS:\n";  for (is = 0; is < sdim; is++)    cout << 1/vs[is]->GetExp() << ' ';  cout << "\n\nA: average var\n";  for (is = 0; is < sdim; is++) {    double d = 0;    for (ix = 0; ix < xdim; ix++)      d += a[is*xdim + ix]->GetVar();    cout << d / xdim << ' ';  }  cout << "\n\nA:\n";  for (is = 0; is < sdim; is++) {    for (ix = 0; ix < xdim; ix++)      cout << a[is*xdim + ix]->GetMean() << ' ';    cout << '\n';  }  */  delete fact;  delete net;  return 0;} catch (std::runtime_error &e) {   cerr << "Runtime error: " << e.what() << endl; } catch (std::logic_error &e) {   cerr << "Logic error: " << e.what() << endl; } catch (std::exception &e) {   cerr << "Exception: " << e.what() << endl; } catch (...) {   cerr << "Unknown exception." << endl; }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -