📄 simpleteacher.java
字号:
package com.weighscore.neuro.plugins;
import com.weighscore.neuro.*;
/**
 * The simple teacher with momentum and a constant learn rate.
 * <p>
 * For every weight of every neuron (threshold and input synapses) the applied
 * correction is
 * {@code -learnRate * gradient + momentumCoefficient * lastCorrection},
 * where {@code lastCorrection} is read from the neuron's
 * {@code MomentumStatistic}.
 *
 * @author Fyodor Kravchenko
 * @version 1.0
 */
public class SimpleTeacher extends Teacher {
    /**
     * Momentum coefficient - a multiplier applied to the neuron's previous
     * correction ({@code MomentumStatistic.lastCorrection}). Defaults to 0.2.
     */
    public double momentumCoefficient = 0.2;
    /**
     * Learn rate - a multiplier applied to the negated gradient values.
     * Defaults to 0.4.
     */
    public double learnRate = 0.4;

    /**
     * Computes the gradient members, turns them into weight corrections using
     * the learn rate and the momentum of the previous iteration, and applies
     * the corrections to the network's neurons.
     * <p>
     * The gradient matrix returned by {@code signal.goBack(error)} is
     * overwritten in place with the corrections before they are applied.
     * <p>
     * NOTE(review): the momentum term is a single value per neuron
     * ({@code lastCorrection} is a scalar), so the same momentum is added to
     * every weight of that neuron. Classical momentum keeps one previous delta
     * per weight. Confirm that this is the intended behavior of
     * {@code MomentumStatistic}.
     *
     * @param error  the error vector, passed to {@code Signal.goBack}
     * @param signal the Signal object holding the querying status and the
     *               gradient members
     * @throws IllegalStateException if the number of gradient rows differs
     *         from the number of neurons in the network; in that case no
     *         neuron is taught
     */
    public void teach(double[] error, Signal signal) {
        // gradient: one row per neuron, one entry per neuron weight
        double[][] correction = signal.goBack(error);
        NeuralNetworkLocal network = super.getNeuralNetwork();
        // fetched once and reused by both passes below
        Neuron[] neurons = network.getNeurons();
        // Validate before mutating anything, so that a mismatch cannot leave
        // the network with only some of its neurons updated.
        if (correction.length != neurons.length) {
            throw new IllegalStateException("Gradient has " + correction.length
                    + " rows but the network has " + neurons.length + " neurons");
        }
        for (int i = 0; i < correction.length; i++) {
            double[] row = correction[i];
            if (row.length == 0) {
                // nothing to correct; skip the statistic lookup entirely
                continue;
            }
            // The previous change is kept in the neuron's MomentumStatistic.
            // It depends only on the neuron, so it is read once per row.
            MomentumStatistic momentum = (MomentumStatistic) neurons[i]
                    .getMultiStats().getStatistic("MomentumStatistic");
            double mom = momentum.lastCorrection * this.momentumCoefficient;
            for (int j = 0; j < row.length; j++) {
                // negative gradient scaled by the learn rate, plus momentum
                row[j] = -this.learnRate * row[j] + mom;
            }
        }
        // Apply the corrections. Going through the network's teachNeuron
        // (rather than the neuron directly) is, per the original author, what
        // keeps the neuron's statistics updated.
        for (int i = 0; i < neurons.length; i++) {
            network.teachNeuron(neurons[i], correction[i]);
        }
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -