📄 pdeltatrainer.cpp
字号:
/*************************************************************************** pdeltatrainer.cpp - description ------------------- begin : Thu Jul 7 2005 copyright : (C) 2005 by Matt Grover email : mgrover@amygdala.org ***************************************************************************//*************************************************************************** * * * This program is free software; you can redistribute it and/or modify * * it under the terms of the GNU General Public License as published by * * the Free Software Foundation; either version 2 of the License, or * * (at your option) any later version. * * * ***************************************************************************/#include <map>#include "pdeltatrainer.h"#include "utilities.h"#include "synapse.h"using namespace Amygdala;using namespace Utilities;using namespace std;PDeltaTrainerProperties::PDeltaTrainerProperties(): TrainerProperties(), eta(0.1), gamma(0.01), outPeriod(30000), outputMode(PDBINARY){}PDeltaTrainerProperties::PDeltaTrainerProperties(float learningRate, float margin, AmTimeInt outputPeriod, PDOutputMode mode): TrainerProperties(), eta(learningRate), gamma(margin), outPeriod(outputPeriod), outputMode(mode){}void PDeltaTrainerProperties::SetProperty(const std::string& name, const std::string& value){ if (name == "eta") { eta = atof(value.c_str()); } else if (name == "gamma") { gamma = atof(value.c_str()); } else if (name == "outputMode") { outputMode = (PDOutputMode)(atoi(value.c_str())); } else { //TrainerProperties::SetProperty(name, value); }}std::map< std::string, std::string > PDeltaTrainerProperties::GetPropertyMap() const{ std::map< std::string, std::string > props = TrainerProperties::GetPropertyMap(); props["eta"] = ftostr(eta); props["gamma"] = ftostr(gamma); props["outputMode"] = itostr(outputMode); return props;}PDeltaTrainerProperties* PDeltaTrainerProperties::Copy(){ return new PDeltaTrainerProperties(*this);}PDeltaTrainer::PDeltaTrainer(PDeltaTrainerProperties& props, std::string name): StaticTrainer(props, name), eta(props.GetEta()), gamma(props.GetGamma()), outPeriod(props.GetOutputPeriod()), currentTrainingStep(0), totalSpikeCount(0), outputMode(props.GetOutputMode()){ Network::GetNetworkRef()->SetTrainerCallback(this, Network::GetNetworkRef()->SimTime()+outPeriod);}PDeltaTrainer::~PDeltaTrainer(){}void PDeltaTrainer::Train(StaticSynapse* syn, AmTimeInt lastTransmitTime, unsigned int lastHistIdx){ // Try to find match in synapseActivity and increment the value. If a match is not found, add the synapse // to the PDeltaNeuron.synapses vector and then add the synapse to synapseActivity and initialize the value to 1 map< StaticSynapse*, unsigned int >::iterator synItr = synapseActivity.find(syn); if (synItr == synapseActivity.end()) { //cout << "Adding synapse to trainer\n"; neurons[ syn->GetDendrite()->GetPostNeuron() ].synapses.push_back(syn); synapseActivity[syn] = 1; } else { //cout << "Incrementing synapse activity\n"; synItr->second++; }}void PDeltaTrainer::ReportSpike(SpikingNeuron* nrn){ neurons[nrn].activity++; ++totalSpikeCount;}float PDeltaTrainer::CalcError(){ float err=0.; float out=0.; const TrainingExample& te = trainingExamples.top(); if (te.step == currentTrainingStep) { float activeCount = 0; map< SpikingNeuron*, PDeltaNeuron >::iterator itr; // calculate error depending on output mode switch (outputMode) { case PDBINARY: /*if ((float)totalSpikeCount > (float)(neurons.size())/2.) { out=1.; }*/ for (itr=neurons.begin(); itr!=neurons.end(); itr++) { if (itr->second.activity>0) { activeCount += 1.; } } out = activeCount/((float)(neurons.size())); err=te.value-out; if (out > 0.5) { out = 1; if (te.value == 1) { err = 0; } } else { out = 0; if (te.value == 0) { err=0; } } //cout << "error = " << te.value << " - " << out << endl; cout << te.value << "\t" << out << "\t" << err << endl; break; case PDSCALED: for (itr=neurons.begin(); itr!=neurons.end(); itr++) { if (itr->second.activity>0) { activeCount += 1.; } } out = activeCount/((float)(neurons.size())); err=te.value-out; if (abs(err) < 0.05) err = 0; cout << te.value << "\t" << out << "\t" << err << "\t" << totalSpikeCount << endl; break; case PDRATE: err = (te.value - (float)totalSpikeCount)/30; cout << te.value << "\t" << totalSpikeCount << "\t" << err << endl; break; } trainingExamples.pop(); } return err;}void PDeltaTrainer::PeriodicTrain(){ //cout << "Training!\n"; // Get the error level for this step float err = CalcError(); //cout << "Error " << err << endl; // TESTING CODE! // add a small value to each synapse for a test map< SpikingNeuron*, PDeltaNeuron >::iterator itr; for (itr=neurons.begin(); itr!=neurons.end(); itr++) { bool t=false; // train only the neurons that contributed to the error if (err>0.) { if (itr->second.activity>0) { t=false; } else { t=true; } } else if (err<0.) { if (itr->second.activity>0) { t=true; } else { t=false; } } if (outputMode == PDRATE && err != 0) { t = true; } //cout << "Iterating neurons\n"; if (t) { for (unsigned int i=0; i<itr->second.synapses.size(); ++i) { StaticSynapse* syn = itr->second.synapses[i]; if (synapseActivity[syn] > 0) { // for now, apply a simple perceptron learning rule to active synapses float w = syn->GetWeight(); w += err*eta; //w += eta; if (w<0.95 && w>-0.95) { // keep weights bounded syn->SetWeight(w); } } } } itr->second.activity = 0; } // Reset all of the synapseActivity values to 0 map< StaticSynapse*, unsigned int >::iterator synItr; for (synItr=synapseActivity.begin(); synItr!=synapseActivity.end(); synItr++) { synItr->second = 0; } // schedule the next training time if (outPeriod) { Network::GetNetworkRef()->SetTrainerCallback(this, Network::GetNetworkRef()->SimTime()+outPeriod); } totalSpikeCount = 0; ++currentTrainingStep;}void PDeltaTrainer::AddTrainingVector(std::vector< TrainingExample >& te){ for (unsigned int i=0; i<te.size(); ++i) { trainingExamples.push(te[i]); }}void PDeltaTrainer::PrintWeights(){ map< StaticSynapse*, unsigned int >::iterator synItr; for (synItr=synapseActivity.begin(); synItr!=synapseActivity.end(); synItr++) { cout << synItr->first->GetWeight() << endl; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -