package com.weighscore.neuro.plugins;


import com.weighscore.neuro.*;

/**
 * A simple teacher that trains the network by gradient descent with momentum
 * and a constant learning rate.
 *
 * @author Fyodor Kravchenko
 * @version 1.0
 */
public class SimpleTeacher extends Teacher {
    /**
     * Momentum coefficient: the multiplier applied to the previous weight correction
     */
    public double momentumCoefficient = 0.2;
    /**
     * Learning rate: the multiplier applied to the computed gradient
     */
    public double learnRate = 0.4;

    /**
     * Computes the gradient by back-propagating the error, derives each weight
     * correction from the learning rate, the previous iteration's correction
     * and the momentum coefficient, and applies the corrections to the network.
     *
     * @param error the network's output error for the current sample
     * @param signal the Signal object holding the query state and the
     *   gradient members
     */
    public void teach(double[] error, Signal signal) {
        // back-propagate the error; the result holds one gradient array
        // per neuron (its threshold and input-synapse weights)
        double[][] correction = signal.goBack(error);

        NeuralNetworkLocal na = super.getNeuralNetwork();
        // get all of the network's neurons
        Neuron[] n = na.getNeurons();

        // iterate neurons
        for (int i = 0; i < correction.length; i++) {
            // iterate the neuron's weights (threshold and input synapses)
            for (int j = 0; j < correction[i].length; j++) {
                // the previous weight change is kept in the neuron's statistic
                // object of the MomentumStatistic class
                MomentumStatistic stat = (MomentumStatistic) n[i].getMultiStats()
                        .getStatistic("MomentumStatistic");
                double mom = stat.lastCorrection * this.momentumCoefficient;
                // correction = negative gradient times the learn rate, plus momentum
                correction[i][j] = -this.learnRate * correction[i][j] + mom;
            }
        }

        // apply the corrections (teaching)
        for (int i = 0; i < n.length; i++) {
            // update the weights through the network's teachNeuron method,
            // which ensures that the neuron's statistics are updated as well
            na.teachNeuron(n[i], correction[i]);
        }
    }
}
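
The rule that teach() applies can be tried in isolation from the weighscore classes. The following is a minimal sketch, not part of the library: the class MomentumUpdateSketch and its corrections method are hypothetical names. For every weight it computes correction = -learnRate * gradient + momentumCoefficient * previousCorrection, assuming (as the sign flip in teach() suggests) that back-propagation returns the derivative of the error with respect to each weight. One deliberate difference: the sketch keeps a separate previous correction for each weight, which is the textbook form of momentum, whereas SimpleTeacher multiplies a single MomentumStatistic.lastCorrection value per neuron.

```java
import java.util.Arrays;

/**
 * Hypothetical, library-independent sketch of gradient descent with momentum
 * and a constant learning rate. Unlike SimpleTeacher, it stores one previous
 * correction per weight rather than one per neuron.
 */
final class MomentumUpdateSketch {
    private final double learnRate;
    private final double momentumCoefficient;
    // previous correction for every weight, same shape as the gradient
    private final double[][] lastCorrection;

    MomentumUpdateSketch(double learnRate, double momentumCoefficient, int[] weightsPerNeuron) {
        this.learnRate = learnRate;
        this.momentumCoefficient = momentumCoefficient;
        this.lastCorrection = new double[weightsPerNeuron.length][];
        for (int i = 0; i < weightsPerNeuron.length; i++) {
            // no previous step yet, so the momentum term starts at zero
            lastCorrection[i] = new double[weightsPerNeuron[i]];
        }
    }

    /** Returns this step's corrections and remembers them for the next step. */
    double[][] corrections(double[][] gradient) {
        double[][] correction = new double[gradient.length][];
        for (int i = 0; i < gradient.length; i++) {
            correction[i] = new double[gradient[i].length];
            for (int j = 0; j < gradient[i].length; j++) {
                // negative gradient scaled by the learning rate, plus momentum
                correction[i][j] = -learnRate * gradient[i][j]
                        + momentumCoefficient * lastCorrection[i][j];
            }
            // copy so later changes by the caller do not alter the stored history
            lastCorrection[i] = correction[i].clone();
        }
        return correction;
    }

    public static void main(String[] args) {
        // one neuron with two weights (e.g. threshold and one input synapse),
        // using the same defaults as SimpleTeacher: learnRate 0.4, momentum 0.2
        MomentumUpdateSketch teacher = new MomentumUpdateSketch(0.4, 0.2, new int[] {2});
        double[][] gradient = {{1.0, -0.5}};
        System.out.println(Arrays.toString(teacher.corrections(gradient)[0]));
        System.out.println(Arrays.toString(teacher.corrections(gradient)[0]));
    }
}
```

With a constant gradient, the second step's corrections are 1.2 times the first (1 + momentumCoefficient), which illustrates how momentum speeds up movement along a consistent gradient direction.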
