📄 main.cc
字号:
//// This file is a part of the Bayes Blocks library//// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander// Ilin, Tapani Raiko, Harri Valpola and Tomas 謘tman.//// This program is free software; you can redistribute it and/or modify// it under the terms of the GNU General Public License as published by// the Free Software Foundation; either version 2, or (at your option)// any later version.//// This program is distributed in the hope that it will be useful,// but WITHOUT ANY WARRANTY; without even the implied warranty of// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the// GNU General Public License (included in file License.txt in the// program package) for more details.//// $Id: main.cc 7 2006-10-26 10:26:41Z ah $#include <sstream>#include "Net.h"#include "NodeFactory.h"//#include "Node.h"double c_randn(){ double d = -6; for (int i = 0; i < 12; i++) d += drand48(); return d;}int main()try { int xdim = 50, tdim = 100, sdim = 2, ix, it, is, iter = 0, i; Net *net = new Net(tdim); NodeFactory *fact = new NodeFactory(net); Constant *const0 = fact->GetConstant("const0", 0); Gaussian *mvs = fact->GetGaussian("mvs", const0, const0); Gaussian *vvs = fact->GetGaussian("vvs", const0, const0); Gaussian *mvx = fact->GetGaussian("mvx", const0, const0); Gaussian *vvx = fact->GetGaussian("vvx", const0, const0); vector<Gaussian *> vs; vector<Gaussian *> vs0; vector<Gaussian *> vx; vector<Gaussian *> mx; vector<Gaussian *> b0; vector<Gaussian *> b1; vector<ProdV *> prb1; vector<DelayV *> del1; vector<Proxy *> prx1; vector<Gaussian *> a; vector<DelayGaussV *> s; vector<ProdV *> pr; vector<Sum2V *> su; vector<GaussianV *> x; for (is = 0; is < sdim; is++) vs.push_back(fact->GetGaussian("vs", mvs, vvs)); for (is = 0; is < sdim; is++) vs0.push_back(fact->GetGaussian("vs0", const0, const0)); for (ix = 0; ix < xdim; ix++) vx.push_back(fact->GetGaussian("vx", mvx, vvx)); for (ix = 0; ix < xdim; ix++) mx.push_back(fact->GetGaussian("mx", const0, const0)); for (is = 0; is < sdim; is++) b0.push_back(fact->GetGaussian("b0", const0, const0)); for (is = 0; is < sdim; is++) b1.push_back(fact->GetGaussian("b1", const0, const0)); for (is = 0; is < sdim; is++) { ostringstream ss; ss << "prb1(" << ((is + sdim - 1) % sdim) << ")"; prx1.push_back(fact->GetProxy("prx1", ss.str())); del1.push_back(fact->GetDelayV("del1", const0, prx1[is])); } for (is = 0; is < sdim; is++) for (ix = 0; ix < xdim; ix++) a.push_back(fact->GetGaussian("a", const0, const0)); for (is = 0; is < sdim; is++) { ostringstream ss; ss << "prb1(" << is << ")"; s.push_back(fact->GetDelayGaussV("s", del1[is], vs[is], b0[is], const0, vs0[is])); prb1.push_back(fact->GetProdV(ss.str(), s[is], b1[is])); } net->ConnectProxies(); for (ix = 0; ix < xdim; ix++) for (is = 0; is < sdim; is++) pr.push_back(fact->GetProdV("pr", a[is * xdim + ix], s[is])); for (ix = 0; ix < xdim; ix++) { vector<Node *> tmp; for (is = 0; is < sdim; is++) tmp.push_back(pr[ix * sdim + is]); tmp.push_back(mx[ix]); while (tmp.size() > 1) { Sum2V *tmp2 = fact->GetSum2V("su", tmp[0], tmp[1]); tmp.erase(tmp.begin()); tmp.erase(tmp.begin()); tmp.push_back(tmp2); su.push_back(tmp2); } x.push_back(fact->GetGaussianV("x", tmp[0], vx[ix])); } for (ix = 0; ix < xdim; ix++) { DV tmp(tdim); for (it = 0; it < tdim; it++) tmp[it] = sin((double)it+ix)+c_randn()*0.1; x[ix]->Clamp(tmp); } for (i = 0; i < sdim * xdim; i++) a[i]->Clamp(drand48() - 0.5); cout << iter << ": " << net->Cost() << '\n'; while (iter < 3) { net->UpdateAll(); cout << ++iter << ": " << net->Cost() << '\n'; } for (i = 0; i < sdim * xdim; i++) a[i]->Unclamp(); try { net->SaveToMatFile("mynet.mat", "mynet"); } catch(MatlabException e) { cout << "Error in saving: " << e.what(); } while (iter < 20) { net->UpdateAll(); cout << ++iter << ": " << net->Cost() << '\n'; cout.flush(); } /* cout << "\nVX:\n"; for (ix = 0; ix < xdim; ix++) cout << 1/vx[ix]->GetExp() << ' '; cout << "\n\nVS:\n"; for (is = 0; is < sdim; is++) cout << 1/vs[is]->GetExp() << ' '; cout << "\n\nA: average var\n"; for (is = 0; is < sdim; is++) { double d = 0; for (ix = 0; ix < xdim; ix++) d += a[is*xdim + ix]->GetVar(); cout << d / xdim << ' '; } cout << "\n\nA:\n"; for (is = 0; is < sdim; is++) { for (ix = 0; ix < xdim; ix++) cout << a[is*xdim + ix]->GetMean() << ' '; cout << '\n'; } */ delete fact; delete net; return 0;} catch (std::runtime_error &e) { cerr << "Runtime error: " << e.what() << endl; } catch (std::logic_error &e) { cerr << "Logic error: " << e.what() << endl; } catch (std::exception &e) { cerr << "Exception: " << e.what() << endl; } catch (...) { cerr << "Unknown exception." << endl; }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -