/***************************************************************************
                          pdeltatrainer.h  -  description
                             -------------------
    begin                : Thu Jul 7 2005
    copyright            : (C) 2005 by Matt Grover
    email                : mgrover@amygdala.org
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/

#ifndef PDELTATRAINER_H
#define PDELTATRAINER_H

#include <amygdala/statictrainer.h>
#include <map>
#include <queue>
#include <string>
#include <vector>

namespace Amygdala {

class StaticSynapse;

enum PDOutputMode { PDBINARY, PDSCALED, PDRATE };

class PDeltaTrainerProperties: public TrainerProperties
{
public:
    PDeltaTrainerProperties();
    PDeltaTrainerProperties(float learningRate, float margin, AmTimeInt outputPeriod, PDOutputMode mode);
    virtual ~PDeltaTrainerProperties() {}

    virtual PDeltaTrainerProperties* Copy();

    virtual void SetProperty(const std::string& name, const std::string& value);
    virtual std::map< std::string, std::string > GetPropertyMap() const;

    void SetEta(float e) { eta = e; }
    float GetEta() const { return eta; }

    void SetGamma(float g) { gamma = g; }
    float GetGamma() const { return gamma; }

    void SetOutputPeriod(AmTimeInt p) { outPeriod = p; }
    AmTimeInt GetOutputPeriod() const { return outPeriod; }

    void SetOutputMode(PDOutputMode m) { outputMode = m; }
    PDOutputMode GetOutputMode() const { return outputMode; }

protected:
    float eta;      // learning rate
    float gamma;    // margin
    AmTimeInt outPeriod;
    PDOutputMode outputMode;
};

class PDeltaTrainer : public StaticTrainer {
    template<class trainer, class trainerProperties>
    friend class TrainerFactory;
public:
    typedef TrainerFactory<PDeltaTrainer,PDeltaTrainerProperties> Factory;

    struct TrainingExample {
        unsigned int step;
        float value;
        // Reversed on purpose: std::priority_queue then yields the smallest step first.
        bool operator<(const TrainingExample& rhs) const {
            return step > rhs.step;
        }
    };

    struct PDeltaNeuron {
        unsigned int activity;              // Number of spikes during this cycle
        std::vector< StaticSynapse* > synapses;
    };

    virtual ~PDeltaTrainer();

    PDeltaTrainerProperties* Properties() { return static_cast<PDeltaTrainerProperties*>(tprops); }

    /** For compatibility with StaticTrainer.  lastHistIdx is not used. */
    virtual void Train(StaticSynapse* syn, AmTimeInt lastTransmitTime, unsigned int lastHistIdx);
    virtual void ReportSpike(SpikingNeuron* nrn);

    /** Called at regular intervals from Network::Run() if supervised training is in use. */
    virtual void PeriodicTrain();

    void SetOutputPeriod(AmTimeInt p) { outPeriod = p; }
    AmTimeInt GetOutputPeriod() const { return outPeriod; }

    void AddTrainingExample(const TrainingExample& te) { trainingExamples.push(te); }
    void AddTrainingVector(std::vector< TrainingExample >& te);

    void PrintWeights();

protected:
    // Protected: instances are created through TrainerFactory, which is declared a friend above.
    PDeltaTrainer(PDeltaTrainerProperties& props, std::string name);

    float CalcError();

    float eta;      // learning rate
    float gamma;    // weight margin
    AmTimeInt outPeriod;
    unsigned int currentTrainingStep;
    unsigned int totalSpikeCount;   // total spikes from all neurons during current cycle
    PDOutputMode outputMode;
    std::priority_queue<TrainingExample> trainingExamples;      // earliest step at top()
    std::map< SpikingNeuron*, PDeltaNeuron > neurons;
    std::map< StaticSynapse*, unsigned int > synapseActivity;   // Number of spikes passed by synapse during this cycle
    std::vector< float > outputHistory;
};

namespace Factory {
    static PDeltaTrainer::Factory MakePDeltaTrainer;
}

}   // namespace Amygdala

#endif // PDELTATRAINER_H
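
The class name and the `eta` (learning rate) / `gamma` (margin) pair suggest that `PDeltaTrainer` follows the p-delta ("parallel delta") rule of Auer, Burgsteiner and Maass for a population of perceptrons. This header does not show how Amygdala maps that rule onto spike counts and output modes, so the following is only a sketch of the published rule on real-valued inputs, not Amygdala's implementation. Assumptions: each perceptron votes +1 when `w.z >= 0` and -1 otherwise, the vote sum `p` is squashed to [-1, 1] with a resolution `rho`, and the accuracy tolerance `epsilon` and margin-term weight `mu` are extra parameters that do not appear in this header. `Dot`, `PopulationOutput` and `PDeltaStep` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of the published p-delta rule on real-valued inputs.
// NOT Amygdala's spiking implementation; epsilon and mu are assumed
// parameters that do not appear in pdeltatrainer.h.
using Vec = std::vector<float>;

float Dot(const Vec& a, const Vec& b) {
    assert(a.size() == b.size());
    float s = 0.0f;
    for (std::size_t k = 0; k < a.size(); ++k) s += a[k] * b[k];
    return s;
}

// Each perceptron votes +1 if w.z >= 0, else -1. The vote sum p is
// squashed to [-1, 1]: -1 below -rho, p / rho in between, +1 above rho.
float PopulationOutput(const std::vector<Vec>& weights, const Vec& z, int rho) {
    assert(rho > 0);
    int p = 0;
    for (const Vec& w : weights) p += (Dot(w, z) >= 0.0f) ? 1 : -1;
    if (p > rho) return 1.0f;
    if (p < -rho) return -1.0f;
    return static_cast<float>(p) / static_cast<float>(rho);
}

// One p-delta update of every weight vector toward `target` (in [-1, 1]).
//   eta         : learning rate (cf. PDeltaTrainerProperties::eta)
//   gamma       : margin        (cf. PDeltaTrainerProperties::gamma)
//   epsilon, mu : assumed accuracy tolerance and margin-term weight
void PDeltaStep(std::vector<Vec>& weights, const Vec& z, float target, int rho,
                float eta, float gamma, float epsilon, float mu) {
    const float o = PopulationOutput(weights, z, rho);  // computed once for all perceptrons
    for (Vec& w : weights) {
        const float a = Dot(w, z);
        float coeff = 0.0f;  // multiple of z to add to w
        if (o > target + epsilon && a >= 0.0f) {
            coeff = -1.0f;   // output too high: push a +1 voter toward -1
        } else if (o < target - epsilon && a < 0.0f) {
            coeff = 1.0f;    // output too low: push a -1 voter toward +1
        } else if (o <= target + epsilon && a >= 0.0f && a < gamma) {
            coeff = mu;      // weak +1 vote that need not flip: widen its margin
        } else if (o >= target - epsilon && a < 0.0f && a > -gamma) {
            coeff = -mu;     // weak -1 vote that need not flip: widen its margin
        }
        // The -(|w|^2 - 1) w term keeps each weight vector near unit length.
        const float norm2 = Dot(w, w);
        for (std::size_t k = 0; k < w.size(); ++k) {
            w[k] += eta * coeff * z[k] - eta * (norm2 - 1.0f) * w[k];
        }
    }
}
```

In this sketch, when three unit-length perceptrons all vote +1 but the target is -1, each weight vector is moved against `z` by `eta`; when the output is already within `epsilon` of the target and every vote clears the margin `gamma`, unit-length weights are left unchanged.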

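`TrainingExample::operator<` deliberately returns `step > rhs.step`. `std::priority_queue` keeps the element that compares greatest at `top()`, so this reversal turns `trainingExamples` into a min-heap keyed on `step`: whatever order the examples are added in, the one scheduled for the earliest step is served first. The standalone sketch below copies the struct from the header; the helper `DrainSteps` is hypothetical and exists only to show the pop order.

```cpp
#include <cassert>
#include <queue>
#include <vector>

// Copied from PDeltaTrainer::TrainingExample in the header above.
struct TrainingExample {
    unsigned int step;
    float value;
    // Reversed comparison: a later step compares as "less", so the
    // max-heap behind std::priority_queue keeps the earliest step on top.
    bool operator<(const TrainingExample& rhs) const {
        return step > rhs.step;
    }
};

// Hypothetical helper: empties a copy of the queue and records the steps
// in the order they are popped.
std::vector<unsigned int> DrainSteps(std::priority_queue<TrainingExample> q) {
    std::vector<unsigned int> order;
    while (!q.empty()) {
        const unsigned int step = q.top().step;
        // Each popped step must be no earlier than the previous one.
        assert(order.empty() || order.back() <= step);
        order.push_back(step);
        q.pop();
    }
    return order;
}
```

Pushing examples with steps 30, 10 and 20 and then draining the queue yields the steps in the order 10, 20, 30. The same ordering applies to examples queued through `AddTrainingExample`, since it simply forwards to `trainingExamples.push`.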
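
`PDeltaTrainer`'s constructor is protected and the class befriends the `TrainerFactory` template, so client code obtains trainers through a factory object such as `Factory::MakePDeltaTrainer` rather than by calling the constructor directly. `TrainerFactory` is defined elsewhere in Amygdala and its interface is not part of this header, so the sketch below does not reproduce it: `MiniFactory`, `MiniTrainer`, `MiniProperties` and the `operator()(props, name)` call signature are hypothetical stand-ins that only illustrate the friend-template pattern.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-in for TrainerFactory: the real Amygdala template's
// interface is not shown in pdeltatrainer.h.
template<class trainer, class trainerProperties>
class MiniFactory {
public:
    // Assumed call signature: build a trainer from properties and a name.
    std::unique_ptr<trainer> operator()(trainerProperties& props,
                                        const std::string& name) const {
        // Calling the protected constructor is allowed only because
        // trainer declares MiniFactory a friend. std::make_unique would
        // not compile here: it is not a friend.
        return std::unique_ptr<trainer>(new trainer(props, name));
    }
};

// Hypothetical properties class, loosely mirroring PDeltaTrainerProperties.
struct MiniProperties {
    float eta = 0.01f;    // learning rate
    float gamma = 0.05f;  // margin
};

class MiniTrainer {
    // Same friend-declaration shape as in PDeltaTrainer.
    template<class trainer, class trainerProperties>
    friend class MiniFactory;
public:
    typedef MiniFactory<MiniTrainer, MiniProperties> Factory;
    const std::string& Name() const { return name; }
    float GetEta() const { return eta; }
protected:
    // Client code cannot write MiniTrainer t(props, "x");
    MiniTrainer(MiniProperties& props, std::string n)
        : name(std::move(n)), eta(props.eta) {
        assert(!name.empty());  // sketch-only check: every trainer is named
    }
private:
    std::string name;
    float eta;
};

namespace Factory {
    // Mirrors `static PDeltaTrainer::Factory MakePDeltaTrainer;`
    static MiniTrainer::Factory MakeMiniTrainer;
}
```

With this arrangement, `MiniTrainer t(props, "x");` written outside the class fails to compile, while `Factory::MakeMiniTrainer(props, "pdelta-1")` returns a trainer whose `Name()` is `"pdelta-1"`. Keeping construction behind the factory gives the library one place to create, name and register trainers.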