// create.cpp
// (scraped code-viewer label "字号:" = "Font size:"; not part of the program)
/**
 * Train an example error MLP network (GREN network).
 *
 * A GREN network takes an [input, output] pair of the trained function and
 * estimates the error of that output with respect to the correct output for
 * that input (see CommonExampleGenerator::calculateError for the error used).
 *
 * Some useful parameter settings, with the Arith generator:
 *   hidden layers 4 x 0 --> error 0.033
 *   hidden layers 4 x 2 --> error 0.0220
 *
 * $Id: create.cpp,v 1.2 2004/06/16 12:32:21 opx Exp $
 * @author OP
 */
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <string>

#include "annie/args.h"
#include "annie/random.h"
#include "annie/sys.h"
#include "annie/examples.h"
#include "annie/listeners.h"
#include "annie/MultiLayerNetwork.h"

using namespace annie;
using namespace std;

/// Numeric command-line parameters (passed to parseArgs); terminated by { NULL }.
static NumberParameter npars[] = {
	{ "hidden1", "count of hidden neurons (0 = no first hidden layer)", 4. },
	{ "hidden2", "count of hidden neurons in second hidden layer (0=no layer)", 0. },
	DEFAULT_MLP_NPARS
	{ "randomExamples", "number of randomly sampled i-o pairs", 1000. },
	{ "closeExamples", "number of i-o pairs with low error", 1000. },
	{ "trainabilityTest", "try to train a conventional MLP first", 0. },
	{ "verbose", "", 0. },
	{ NULL }
};

/// String command-line parameters (passed to parseArgs); terminated by { NULL }.
static StringParameter spars[] = {
	{ "gren", "file to save the trained GREN to", "../data/fromPolar.gren.net" },
	{ "inputs", "examples to be trained by the GREN (also contains correct outputs for verification)", "../data/fromPolar.inputs.ts" },
	{ "usedExamples", "file to save the examples that have been used to", "../data/fromPolar.used.ts" },
	{ "grenTmp", "temporary file for GREN (saved every 30s during training)", "../data/fromPolar.gren.net~" },
	{ "resume", "resume trainig by loading this network file", "" },
	{ "resumeConv", "resume conv MLP trainig by loading this network file", "" },
	{ NULL }
};

/// Indices into spars[]; must follow the order of the entries above.
enum { GREN, INPUTS, USED_EXAMPLES, TEMP_NET, RESUME, RESUME_CONV };

/// Value range of the trained function's inputs/outputs and of the error.
/// Make sure that the neurons' output function you use fits these numbers.
const real MIN = 0, MAX = 1;

/// DIMI = inputs of the trained function, DIMO = its outputs,
/// DIME = outputs of the GREN (the error estimate).
enum { DIMI=2, DIMO=1, DIME=1 };

/// A function to be learned together with the error measure of its outputs.
struct ExampleGenerator { //TODO: put DIMs here
	virtual ~ExampleGenerator() {}

	/// @return the correct output (DIMO components) for input @p i (DIMI components)
	virtual Vector correctOutput(const Vector &i) const = 0;

	/// @param i concatenated [input, output] vector (DIMI + DIMO components)
	/// @return the error of the output part (DIME components)
	virtual Vector calculateError(const Vector &i) const = 0;
};

/// Sum of the absolute component-wise differences (L1 distance) of two vectors.
static real absSumDistance(const Vector &v1, const Vector &v2) {
	Vector diff = v1 - v2;
	real sum = 0;
	for(uint k = 0; k < diff.size(); k++)
		sum += fabs(diff[k]);	// fabs: never resolves to the truncating C abs(int)
	return sum;
}

/**
 * Error measure shared by the example functions.
 *
 * For DIMO > 1 the error is the L1 distance between the given and the correct
 * output, i.e. 0 for a correct output.
 * For DIMO == 1 the active formula is the logistic 1 / (1 + exp(correct - given)):
 * 0.5 for a correct output and monotonic in the signed difference.
 * NOTE(review): that contradicts the "lowest error is 0" contract of the
 * multi-output branch (fabs(diff) would match it). It is kept because it
 * defines the training targets -- confirm which formula is intended.
 */
struct CommonExampleGenerator : public ExampleGenerator {
	virtual Vector calculateError(const Vector &i) const {
		Vector r(DIME);
		Vector in = i.subset(0, DIMI);
		Vector out = i.subset(DIMI, DIMO);
		Vector co = correctOutput(in);
		if(DIMO == 1) {
			real diff = co[0] - out[0];
			// other formulas tried: abs(diff), 1 / exp(diff * diff)
			r[0] = 1 / (1 + exp(diff));
		} else
			r[0] = absSumDistance(out, co);
		// NOTE(review): for DIMO > 1 the L1 distance may exceed MAX and trip this
		ASSERT(r[0] >= MIN && r[0] <= MAX);
		return r;
	}
};

/**
 * The 2D fromPolar transformation (r, f) --> (x, y):
 *   x = r * cos(f)
 *   y = r * sin(f)
 * where f \in [0, 1] is rescaled to [0, pi/2].
 * With DIMO == 1 only x is produced.
 */
struct FromPolar : public CommonExampleGenerator {
	virtual Vector correctOutput(const Vector &i) const {
		ASSERT(DIMI == 2);
		ASSERT(DIMO <= 2);
		Vector r(DIMO);
		const real radius = i[0];
		const real angle = i[1] * M_PI / 2;
		r[0] = radius * cos(angle);
		if(DIMO == 2)
			r[1] = radius * sin(angle);	// was i[1] * sin(i[0] * pi/2): r and f swapped
		return r;
	}
};

/// Constant function 0.5 on every output component (very hard indeed).
/// The input is ignored, so any DIMI works.
struct Constant : public CommonExampleGenerator {
	virtual Vector correctOutput(const Vector &) const {
		Vector r(DIMO);
		for(uint k = 0; k < DIMO; k++)
			r[k] = 0.5;
		return r;
	}
};

/// r = (a^2 - b + a*b + 0.5) / 2 with a = i[0], b = i[1] / 2.
/// For inputs in [0, 1] the result stays within [0, 0.75].
struct Arith : public CommonExampleGenerator {
	virtual Vector correctOutput(const Vector &i) const {
		ASSERT(DIMI == 2);
		ASSERT(DIMO == 1);
		Vector r(DIMO);
		real a = i[0];
		real b = i[1] / 2;
		r[0] = (a*a - b + a*b + 0.5) / 2;
		return r;
	}
};

/// r = cos(i[0]); one input, one output.
struct Trigon : public CommonExampleGenerator {
	virtual Vector correctOutput(const Vector &i) const {
		ASSERT(DIMI == 1);
		ASSERT(DIMO == 1);
		Vector r(DIMO);
		r[0] = cos(i[0]);
		return r;
	}
};

/// Generates training sets using an ExampleGenerator.
struct TSGenerator {
	explicit TSGenerator(ExampleGenerator &eg) : _eg(eg) {}

	/// @return a copy of @p ios whose outputs are replaced by the correct outputs
	TrainingSet correct(TrainingSet &ios) {
		return ios.mixedXform(Corrector(_eg));
	}

	/**
	 * From a TS consisting of (i, o) pairs make a GREN-TS consisting of
	 * ([i, o], err) pairs.
	 */
	TrainingSet attachError(TrainingSet &ios) {
		TrainingSet ts(DIMI + DIMO, 1);
		ios.initialize();
		Vector i(DIMI), o(DIMO), inout(DIMI + DIMO);
		while(!ios.epochOver()) {
			ios.getNextPair(i, o);
			inout.setMore(i, 0);
			inout.setMore(o, DIMI);
			ts.addIOpair(inout, _eg.calculateError(inout));
		}
		return ts;
	}

	/**
	 * Make examples with uniformly random inputs and slightly distorted
	 * correct outputs.
	 * @param count   number of examples
	 * @param distort maximal perturbation: every output component is offset by
	 *                a value drawn uniformly from [-distort, distort] and then
	 *                clamped to [MIN, MAX]
	 */
	TrainingSet close(uint count, real distort) {
		TrainingSet ts(DIMI, DIMO);
		Vector in(DIMI), out(DIMO);
		for(uint k = 0; k < count; k++) {
			_makeCloseExample(in, out, uniformRandomVector(-distort, distort, DIMO));
			ts.addIOpair(in, out);
		}
		return ts;
	}

	/// Make examples with uniformly random inputs and outputs in [MIN, MAX].
	TrainingSet random(uint count) {
		TrainingSet ts(DIMI, DIMO);
		Vector in(DIMI), out(DIMO);
		for(uint k = 0; k < count; k++) {
			in = uniformRandomVector(MIN, MAX, DIMI);
			out = uniformRandomVector(MIN, MAX, DIMO);
			ts.addIOpair(in, out);
		}
		return ts;
	}

protected:
	/// Transformer that keeps the input and replaces the output by the correct one.
	struct Corrector : TSTransformer {
		explicit Corrector(ExampleGenerator &eg) : TSTransformer(DIMI, DIMO), _eg(eg) {}

		virtual void xform(const Vector &in1, const Vector &, Vector &in2, Vector &out2) const {
			in2 = in1;
			out2 = _eg.correctOutput(in1);
		}
	protected:
		ExampleGenerator &_eg;
	};

	/// Random input in [MIN, MAX]; output = correct output + @p distort, clamped.
	void _makeCloseExample(Vector &in, Vector &out, const Vector &distort) {
		in = uniformRandomVector(MIN, MAX, DIMI);
		out = _eg.correctOutput(in);
		for(uint k = 0; k < out.size(); k++) {
			ASSERT(out[k] >= MIN);
			ASSERT(out[k] <= MAX);
		}
		out += distort;
		out.clamp(MIN, MAX);
	}

	ExampleGenerator &_eg;
};

/// Builds (or loads) the GREN network and trains it on a GREN training set.
/// While training, it listens to "epoch" updates to save the net to a temp file.
class TrainGREN : public ValueUpdateListener {
public:
	/**
	 * Create a new fully connected net DIMI+DIMO -> [hidden1] -> [hidden2] -> DIME.
	 * @param ts       GREN training set (([i, o], error) pairs); copied
	 * @param hidden1  neurons in the first hidden layer (0 = no hidden layer)
	 * @param hidden2  neurons in the second hidden layer (0 = none; requires hidden1)
	 * @param tempFile file to save the net to during training ("" = never)
	 */
	TrainGREN(TrainingSet &ts, uint hidden1, uint hidden2, string tempFile="")
		: net(DIMI + DIMO), _ts(ts), _control(defaultControl), _tempSaveFile(tempFile) {
		ASSERT(!hidden2 || hidden1);
		if(hidden1) net.addLayer(hidden1);
		if(hidden2) net.addLayer(hidden2);
		net.addLayer(DIME);	// output layer

		// full connection between consecutive layers
		net.connectLayer(0);
		if(hidden1) net.connectLayer(1);
		if(hidden2) net.connectLayer(2);
	}

	/**
	 * Resume training of a net loaded from @p initialNetFile; it must have
	 * DIMI + DIMO inputs and DIME outputs.
	 */
	TrainGREN(TrainingSet &examples, const std::string &initialNetFile, std::string tempFile="")
		: net(initialNetFile), _ts(examples), _control(defaultControl), _tempSaveFile(tempFile) {
		ASSERT(net.getInputCount() == DIMI + DIMO);
		ASSERT(net.getOutputCount() == DIME);
	}

	/// Train the net on the GREN training set using the parameters in
	/// defaultControl, with a Shuffler listener ("_ts shuffle period", default 100).
	void train() {
		if(_tempSaveFile != "")
			_control.addListener(*this);
		// initial evaluation on the training set; the returned value is not used here
		net.getError(_ts);
		Shuffler shuffler(_ts, defaultControl.init("_ts shuffle period", 100));
		defaultControl.addListener(&shuffler);
		try {
			net.train(_ts, defaultControl);
		} catch(...) {
			// shuffler dies with this frame and *this may die during unwinding:
			// never leave them registered as listeners
			_unregister(shuffler);
			throw;
		}
		_unregister(shuffler);
		cout << "resulting error: " << _control["epoch error"]
		     << ", normalized: " << _control["normalized epoch error"] << endl;
	}

	/* get used inputs (disabled: refers to a `ts` pointer member that no longer exists)
	TrainingSet getInputs() {
		ASSERT(ts);
		TrainingSet is = ts->mixedXform(Shrinker(DIMI, 0));
		ASSERT(is.getInputSize() == DIMI);
		ASSERT(!is.getOutputSize());
		ASSERT(is.getSize() == ts->getSize());
		return is;
	}
	*/

	void save(const string &file) { net.save(file); }

	operator std::string () const { return net; }

protected:
	/// On every "epoch" update, save the net to the temp file when the
	/// "tmpSaves" value was last changed more than 30 time units ago
	/// (presumably seconds, matching the grenTmp help text), then bump it.
	virtual void valueChanged(const Value &val) {
		if(val.name() != "epoch") return;
		Creal save = _control.get("tmpSaves");
		if(save.timeFromLastChange() > 30) {
			net.save(_tempSaveFile);
			++save;
			cout << "tmp save\n";
		}
	}

	/// Remove the listeners registered by train().
	void _unregister(Shuffler &shuffler) {
		defaultControl.removeListener(&shuffler);
		if(_tempSaveFile != "")
			_control.removeListener(*this);
	}

	MultiLayerNetwork net;
	TrainingSet _ts;
	PublicValues &_control;
	std::string _tempSaveFile;
};

/*
 * Plan for the "trainabilityTest" option (not compiled; see main).
 * NOTE(review): it calls a 2-argument _makeCloseExample that no longer exists.
 *
void trainConventional(TrainingSet &conv) {
	Vector z(DIMI), inout(DIMI + DIMO);
	z.setAll(0);
	for(uint i=0; i<defaultControl["closeExamples"]; i++) {
		_makeCloseExample(inout, z);
		conv.addIOpair(inout.subset(0, DIMI), inout.subset(DIMI, DIMO));
	}
	cout << (string) conv << endl;
	MultiLayerNetwork * convNet;
	if(spars[RESUME_CONV].value == "") {
		convNet = new MultiLayerNetwork (DIMI);
		uint hidden1 = defaultControl["hidden1"];
		uint hidden2 = defaultControl["hidden2"];
		if(defaultControl["hidden1"]) convNet->addLayer(defaultControl["hidden1"]);
		if(defaultControl["hidden2"]) convNet->addLayer(defaultControl["hidden2"]);
		convNet->addLayer(DIMO); //out
		//full connection
		convNet->connectLayer(0);
		if(hidden1) convNet->connectLayer(1);
		if(hidden2) convNet->connectLayer(2);
	} else
		convNet = new MultiLayerNetwork (spars[RESUME_CONV].value);
	cout << "conventional MLP " << (string) *convNet << endl;
	convNet->getError(conv);
	cout << "training >>>>>>>> " << endl;
	convNet->train(conv, defaultControl.get("epochs"));
	convNet->save(spars[GREN].value);
	delete(convNet);
}
*/

/**
 * Build the GREN training set.
 *
 * For uniformly distributed inputs we create
 *  a. completely random outputs,
 *  b. slightly distorted correct outputs,
 *  c. (with HYPE_CORRECT_PAIRS) the correct outputs for all of the above.
 * The inputs together with their correct outputs are saved to the "inputs"
 * file; the former are used in training, the latter let results be tested.
 *
 * @return ([i, o], error) pairs
 */
TrainingSet makeExamples() {
	//FromPolar gener;
	//Constant gener;
	Arith gener;
	//Trigon gener;
	TSGenerator tsg(gener);

	TrainingSet io = tsg.random(defaultControl["randomExamples"]);
	io += tsg.close(defaultControl["closeExamples"], (MAX-MIN)/4);
#define HYPE_CORRECT_PAIRS
#ifdef HYPE_CORRECT_PAIRS
	TrainingSet correct = tsg.correct(io);
	io += correct;
#endif
	TrainingSet correctIO = tsg.correct(io);
	cout << " Saving i/o examples " << static_cast<string>(correctIO) << " to " << spars[INPUTS].value << endl;
	correctIO.save(spars[INPUTS].value);
	return tsg.attachError(io);
}

/**
 * Train @p gren and save it to the "gren" file.
 * @return false if nothing was trained because the (not yet implemented)
 *         trainability test was requested
 */
static bool trainAndSave(TrainGREN &gren) {
	if(defaultControl["trainabilityTest"]) {
		cout << "trainabilityTest" << endl;
		cout << "TODO..." << endl;
		//trainConventional()
		return false;
	}
	cout << "training " << endl;
	gren.train();
	cout << "saving net " << static_cast<string>(gren) << " to " << spars[GREN].value << endl;
	gren.save(spars[GREN].value);
	return true;
}

//// init/run stuff
int main(int argc, char *argv[]) {
	try {
		sysInit();
		parseArgs(argc, argv, npars, spars);
		SimpleVisualiser visualiser;
		defaultControl.addListener(visualiser);

		cout << "preparing" << endl;
		TrainingSet ts = makeExamples();
		if(spars[USED_EXAMPLES].value != "")
			ts.save(spars[USED_EXAMPLES].value);

		// the net lives on the stack in either branch: no leak on any exit path
		bool trained;
		if(spars[RESUME].value == "") {
			TrainGREN gren(ts, defaultControl["hidden1"], defaultControl["hidden2"], spars[TEMP_NET].value);
			trained = trainAndSave(gren);
		} else {
			cout << "resuming trainig from file " << spars[RESUME].value << endl;
			TrainGREN gren(ts, spars[RESUME].value, spars[TEMP_NET].value);
			trained = trainAndSave(gren);
		}
		if(!trained)
			return 0;

		// print the used values once again
		defaultControl.triggerAll();
	} catch (Exception &e) {
		cout << ">>> heh?\n" << e.what() << endl;
		exit(-1);
	} catch (...) {
		cout << ">>> heh? unknown exception\n";
		exit(-1);
	}
	return 0;
}
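/*
 * Scraped code-viewer residue, not part of the program -- keyboard shortcuts:
 *   Copy code           Ctrl + C
 *   Search code         Ctrl + F
 *   Full-screen mode    F11
 *   Toggle theme        Ctrl + Shift + D
 *   Show shortcuts      ?
 *   Increase font size  Ctrl + =
 *   Decrease font size  Ctrl + -
 */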