📄 train.cpp
字号:
/** Train a MLP by a supplied GREN error network $Id: train.cpp,v 1.2 2004/06/16 12:32:21 opx Exp $ @author OP*/#include <ctime> #include "annie/args.h"#include "annie/random.h"#include "annie/sys.h"#include "annie/listeners.h"#include "annie/MultiLayerNetwork.h"#include <iomanip>using namespace annie;using namespace std;static NumberParameter npars[] = { { "hidden1", "count of hidden neurons (0 = no first hidden layer) in the trained MLP", 0. }, { "hidden2", "count of hidden neurons in second hidden layer (0=no layer) in the trained MLP", 0. }, DEFAULT_MLP_NPARS { "verbose", "be a bit verbose", 0. }, { "filter", "filter frequent changes", 0. }, { "verboseX", "examples verbose", 0. }, { NULL } };enum { GREN, INPUTS, TRAINED };static StringParameter spars[] = { { "gren", "file to load the trained GREN from", "../data/fromPolar.gren.net" }, { "inputs", "file of training inputs", "../data/fromPolar.inputs.ts" }, { "trained", "file to save the trained net to", "../data/fromPolar.trained.net" }, { NULL } };class TrainByGREN {public: enum { TRAINED_OFFSET = 5 }; /// @param ts contains also the output pairs, but these are only used to compute real error TrainByGREN(MultiLayerNetwork &trainer, TrainingSet &ts, uint hidden1, uint hidden2) : net(ts.getInputSize(), TRAINED_OFFSET), gren(trainer), _ts(ts) { ASSERT(!hidden2 || hidden1); if(hidden1) net.addLayer(hidden1); if(hidden2) net.addLayer(hidden2); uint outs = gren.getInputCount() - _ts.getInputSize(); ASSERT(outs); if(_ts.getOutputSize()) ASSERT(outs == _ts.getOutputSize()); net.addLayer(outs); //out //full connection net.connectLayer(0); if(hidden1) net.connectLayer(1); if(hidden2) net.connectLayer(2); //cout << "initial net: " << net.verbose(); } void train() { net.getErrorGREN(_ts); net.trainGREN(gren, _ts, defaultControl["epochs"] , defaultControl["learningRate"] , defaultControl["momentum"] ); defaultControl["real epoch error"] = -1; defaultControl["normalized real epoch error"] = -1; cout << " resulting error: \n"; net.getErrorGREN(_ts); } void save(string &file) { net.save(file); } const MultiLayerNetwork &getNet() const { return net; };private: MultiLayerNetwork net, &gren; TrainingSet &_ts;};//// init/run stuffint main(int argc, char *argv[]) { try { sysInit(); parseArgs(argc, argv, npars, spars); //srand((unsigned)time(NULL)); srand(1); //const //srand(2); //arith defaultControl["momentum"] = 0; SimpleVisualiser sv; RigidVisualiser rv; ValueUpdateListener *visualiser; if(defaultControl["filter"] ) visualiser = &sv; else visualiser = &rv; defaultControl.addListener(visualiser); cerr.setf(ios::fixed, ios::floatfield); cerr << std::setprecision(6); cout << "preparing" << endl; TrainingSet inputs(spars[INPUTS].value); cout << "loaded inputs: " << (string) inputs << endl; MultiLayerNetwork gren(spars[GREN].value); cout << "loaded GREN: " << (string) gren << endl; cout << gren.getLayer(1)[0] << endl; TrainByGREN tr(gren, inputs, defaultControl["hidden1"] , defaultControl["hidden2"] ); cout << "MLP to train:: " << (string) tr.getNet() << endl; cout << "Training inputs: " << (string) inputs << endl; cout << "training " << endl; tr.train(); cout << "saving to " << spars[TRAINED].value << endl; tr.save(spars[TRAINED].value); defaultControl.triggerAll(); } catch (Exception &e) { cout << ">>> heh?\n" << e.what() << endl; exit(-1); } catch (...) { cout << ">>> heh? unknown exception\n"; exit(-1); } return 0;}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -