statictrainer.cpp
/***************************************************************************
                          statictrainer.cpp
                          -----------------
    begin                : Sat Mar 19 2005
    copyright            : (C) 2005 by Matt Grover
    email                : mgrover@amygdala.org
 ***************************************************************************/
/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify *
 *   it under the terms of the GNU General Public License as published by *
 *   the Free Software Foundation; either version 2 of the License, or    *
 *   (at your option) any later version.                                  *
 *                                                                         *
 ***************************************************************************/

#include "statictrainer.h"
#include "synapse.h"
#include "utilities.h"
#include "network.h"

#include <cmath>    // exp()
#include <cstdlib>  // atof(), atoi()

using namespace Amygdala;
using namespace Utilities;
using namespace std;

StaticHebbianTrainerProperties::StaticHebbianTrainerProperties()
    : TrainerProperties(),
      synTimeConst(6.),
      synPotConst(0.4),
      synDepConst(0.25),
      learningMax(-0.5),
      posLearnTimeConst(0.5),
      negLearnTimeConst(5.0),
      learningConst(1e-2),
      historySize(20)
{
}

StaticHebbianTrainerProperties::StaticHebbianTrainerProperties(
        float synapticTimeConst, float synapticPotentiationConst,
        float synapticDepressionConst, float learningMaximum,
        float posLearningTimeConst, float negLearningTimeConst,
        float learningConstant, unsigned int histSize)
    : TrainerProperties(),
      synTimeConst(synapticTimeConst),
      synPotConst(synapticPotentiationConst),
      synDepConst(synapticDepressionConst),
      learningMax(learningMaximum),
      posLearnTimeConst(posLearningTimeConst),
      negLearnTimeConst(negLearningTimeConst),
      learningConst(learningConstant),
      historySize(histSize)
{
}

void StaticHebbianTrainerProperties::SetProperty(const std::string& name,
                                                 const std::string& value)
{
    if (name == "learningConst") {
        learningConst = atof(value.c_str());
    } else if (name == "synTimeConst") {
        synTimeConst = atof(value.c_str());
    } else if (name == "synPotentiationConst") {
        synPotConst = atof(value.c_str());
    } else if (name == "synDepressionConst") {
        synDepConst = atof(value.c_str());
    } else if (name == "learningMax") {
        learningMax = atof(value.c_str());
    } else if (name == "posLearnTimeConst") {
        posLearnTimeConst = atof(value.c_str());
    } else if (name == "negLearnTimeConst") {
        negLearnTimeConst = atof(value.c_str());
    } else if (name == "historySize") {
        historySize = atoi(value.c_str());
    } else {
        //TrainerProperties::SetProperty(name, value);
    }
}

std::map<std::string, std::string> StaticHebbianTrainerProperties::GetPropertyMap() const
{
    std::map<std::string, std::string> props = TrainerProperties::GetPropertyMap();
    props["learningConst"] = ftostr(learningConst);
    props["synTimeConst"] = ftostr(synTimeConst);  // keeps the map in sync with SetProperty()
    props["synPotentiationConst"] = ftostr(synPotConst);
    props["synDepressionConst"] = ftostr(synDepConst);
    props["learningMax"] = ftostr(learningMax);
    props["posLearnTimeConst"] = ftostr(posLearnTimeConst);
    props["negLearnTimeConst"] = ftostr(negLearnTimeConst);
    props["historySize"] = itostr(historySize);
    return props;
}

StaticHebbianTrainerProperties* StaticHebbianTrainerProperties::Copy()
{
    return new StaticHebbianTrainerProperties(*this);
}

StaticTrainer::~StaticTrainer()
{
}

TrainerProperties* StaticTrainer::Properties()
{
    return 0;
}

StaticHebbianTrainer::StaticHebbianTrainer(StaticHebbianTrainerProperties& props,
                                           std::string name)
    : StaticTrainer(props, name),
      // Scaled with the same *1000. factor used in MakeLookupTables(), so the
      // comparisons in Train() see the same time units as the spike times
      // (assumes the 0.1 ms resolution noted in MakeLookupTables()).
      learningMax(props.GetLearningMax() * 1000.f),
      histIdx(0),
      lastSpikeTime(0),
      windowWidth(100),
      weightDiffPreLookup(0),
      weightDiffPostLookup(0)
{
    spikeHistory.resize(props.GetHistorySize());
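    // spikeHistory is used as a fixed-size ring buffer of recent postsynaptic
    // spike times: ReportSpike() below advances histIdx on every spike and
    // wraps it back to zero once GetHistorySize() entries have been written.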
    InitLookup();
}

StaticHebbianTrainer::~StaticHebbianTrainer()
{
}

void StaticHebbianTrainer::Train(StaticSynapse* syn, AmTimeInt lastTransmitTime,
                                 unsigned int lastHistIdx)
{
    float timeDiff = (float)lastTransmitTime - (float)spikeHistory[lastHistIdx];
    // Find the lookup table index based on timeDiff and check that it is
    // within the acceptable range.
    if (timeDiff <= learningMax) {
        // learningMax is always negative, so idx is guaranteed to be
        // non-negative here.
        unsigned int idx = (unsigned int)((learningMax - timeDiff)*0.01);
        if (idx >= windowWidth) {
            // index out of range -- don't do anything
            return;
        }
        float weightDiff = weightDiffPreLookup[idx];
        float weight = syn->GetWeight();
        syn->SetWeight(weight + weightDiff);
    } else {
        unsigned int idx = (unsigned int)((timeDiff - learningMax)*0.01);
        if (idx >= windowWidth) {
            return;
        }
        float weightDiff = weightDiffPostLookup[idx];
        float weight = syn->GetWeight();
        syn->SetWeight(weight + weightDiff);
    }
}

void StaticHebbianTrainer::ReportSpike(SpikingNeuron* nrn)
{
    lastSpikeTime = Network::GetNetworkRef()->SimTime();
    spikeHistory[histIdx++] = lastSpikeTime;
    if (histIdx >= spikeHistory.size()) {
        histIdx = 0;
    }
}

void StaticHebbianTrainer::InitLookup()
{
    FunctionLookup* fl = Network::GetNetworkRef()->GetFunctionLookup();
    try {
        TableProperties t0 = GetTableProps(0);
        TableProperties t1 = GetTableProps(1);
        weightDiffPreLookup = fl->GetTableData(t0);
        weightDiffPostLookup = fl->GetTableData(t1);
    } catch (TableNotFoundException& e) {
        MakeLookupTables();
    }
}

TableProperties StaticHebbianTrainer::GetTableProps(unsigned int index)
{
    TableProperties props;
    props.SetClassName("StaticHebbianTrainer");
    props.SetTableSize(windowWidth);
    StaticHebbianTrainerProperties* shProps =
        dynamic_cast<StaticHebbianTrainerProperties*>(tprops);
    props.AddParam((AmTimeInt)index);
    props.AddParam(shProps->GetLearningConstant());
    props.AddParam(shProps->GetSynapticTimeConst());
    props.AddParam(shProps->GetSynapticPotentiationConst());
    props.AddParam(shProps->GetSynapticDepressionConst());
    props.AddParam(shProps->GetLearningMax());
    props.AddParam(shProps->GetPositiveLearningConst());
    props.AddParam(shProps->GetNegativeLearningConst());
    return props;
}

void StaticHebbianTrainer::MakeLookupTables()
{
    FunctionLookup* fl = Network::GetNetworkRef()->GetFunctionLookup();
    TableProperties t0 = GetTableProps(0);
    TableProperties t1 = GetTableProps(1);
    LOGGER(6, "Making lookup tables")
    weightDiffPreLookup = fl->MakeLookupTable(t0);
    weightDiffPostLookup = fl->MakeLookupTable(t1);
    LOGGER(6, "Lookups made")
    StaticHebbianTrainerProperties* shProps =
        dynamic_cast<StaticHebbianTrainerProperties*>(tprops);

    // Fill weightDiffPreLookup for spikes that arrive at the postsynaptic
    // neuron earlier than t=learningMax.
    for (unsigned int i = 0; i < windowWidth; ++i) {
        float lMax = shProps->GetLearningMax()*1000.;
        float timeDiff = ((float)i*100.0) - lMax;  // Assumes time resolution is 0.1 ms
        timeDiff *= -1.;
        float Apos = shProps->GetSynapticPotentiationConst();
        float Aneg = shProps->GetSynapticDepressionConst();
        float tauSyn = shProps->GetSynapticTimeConst()*1000.;
        float expTerm = exp(-(lMax - timeDiff)/tauSyn);
        weightDiffPreLookup[i] = (Apos - Aneg)*expTerm;
    }

    // Fill weightDiffPostLookup for spikes that arrive at the postsynaptic
    // neuron later than t=learningMax.
    for (unsigned int i = 0; i < windowWidth; ++i) {
        float lMax = shProps->GetLearningMax()*1000.;
        float timeDiff = ((float)i*100.0) + lMax;
        float Apos = shProps->GetSynapticPotentiationConst();
        float Aneg = shProps->GetSynapticDepressionConst();
        float tauPos = shProps->GetPositiveLearningConst()*1000.;
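        // Difference-of-exponentials STDP-style window for this branch:
        //   dw(i) = Apos*exp(-(t_i - lMax)/tauPos) - Aneg*exp(-(t_i - lMax)/tauNeg)
        // with t_i = 100*i + lMax. With the default constants the fast Apos
        // term wins near t_i = lMax, and the update turns depressive once the
        // slower-decaying Aneg term dominates at larger delays.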
        float tauNeg = shProps->GetNegativeLearningConst()*1000.;
        float expTerm1 = exp(-(timeDiff - learningMax)/tauPos);
        float expTerm2 = exp(-(timeDiff - learningMax)/tauNeg);
        weightDiffPostLookup[i] = (Apos*expTerm1) - (Aneg*expTerm2);
    }
}
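A minimal usage sketch (hypothetical, not part of statictrainer.cpp; it assumes an initialized Amygdala Network plus an existing SpikingNeuron* post and StaticSynapse* syn, and uses only the interfaces defined above):

    StaticHebbianTrainerProperties props;          // defaults: A+ = 0.4, A- = 0.25, ...
    props.SetProperty("learningConst", "0.02");    // tune through the string interface
    StaticHebbianTrainer trainer(props, "hebb0");  // loads or builds the lookup tables

    trainer.ReportSpike(post);                     // record a postsynaptic spike
    AmTimeInt transmitTime = /* presynaptic transmit time */ 0;
    trainer.Train(syn, transmitTime, 0);           // apply the STDP weight update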