// File: net.java
package bp;
import java.io.Serializable;
/**
 * A layered feed-forward network trained by back-propagation.
 *
 * <p>{@code layers[0]} is treated as the input layer: the forward phase writes the input vector
 * directly into its neurons' outputs. Every later layer computes its nets from the layer before it,
 * and the last layer is the output layer whose neuron outputs {@link #testPattern(double[])}
 * returns.
 */
public class Net implements Serializable {
    /** Copies (made with {@code new Layer(Layer)}) of the constructor's layers; index 0 is the input layer. */
    private Layer[] layers;
    /** Step size passed to every neuron's {@code updateWeights} call. */
    private double learningRate;
    /**
     * Pattern ordering used by the next {@link #trainEpoch(Pattern[])} call: 0 = sequential,
     * 1 = backwards, 2 = random; advanced cyclically after every epoch.
     *
     * <p>NOTE(review): this is static, so the cycle is shared by every {@code Net} in the JVM and is
     * updated without synchronization. Turning it into an instance field would change this class's
     * default (computed) serialVersionUID, since no explicit one is declared.
     */
    private static int trainingMode = 0;

    /**
     * Creates a network from copies of the given layers.
     *
     * @param layers       layers ordered from input ({@code layers[0]}) to output; each non-null
     *                     entry is copied with {@code new Layer(Layer)}, null entries stay null
     * @param learningRate learning rate used for weight updates
     */
    public Net(Layer[] layers, double learningRate) {
        this.learningRate = learningRate;
        this.layers = new Layer[layers.length];
        for (int i = 0; i < layers.length; ++i) {
            this.layers[i] = (layers[i] != null) ? new Layer(layers[i]) : null;
        }
    }

    /**
     * Trains one epoch of {@code patterns.length} training steps.
     *
     * <p>Successive calls cycle through three orderings: sequential, backwards (last pattern first),
     * and random, where each step picks a pattern uniformly at random with replacement.
     *
     * @param patterns training patterns for this epoch
     * @return the largest per-pattern error seen during the epoch, or {@code 0.0} if
     *         {@code patterns} is empty
     */
    public double trainEpoch(Pattern[] patterns) {
        int mode = trainingMode;
        // Advance the cycle once per epoch (0 -> 1 -> 2 -> 0 ...).
        trainingMode = (trainingMode + 1) % 3;

        // Errors are half sums of squares and therefore never negative, so 0.0 is a safe starting
        // maximum (Double.MIN_VALUE is the smallest *positive* double, not the most negative one).
        double maxErr = 0.0;
        for (int k = 0; k < patterns.length; ++k) {
            double err = trainPattern(selectPattern(patterns, k, mode));
            if (err > maxErr) {
                maxErr = err;
            }
        }
        return maxErr;
    }

    /**
     * Returns the pattern to train at step {@code k} of an epoch that uses ordering {@code mode}.
     *
     * @param patterns the epoch's patterns (non-empty whenever this is called)
     * @param k        step index, {@code 0 <= k < patterns.length}
     * @param mode     0 = sequential, 1 = backwards, anything else = random with replacement
     */
    private static Pattern selectPattern(Pattern[] patterns, int k, int mode) {
        switch (mode) {
            case 0: // sequential
                return patterns[k];
            case 1: // backwards sequential
                return patterns[patterns.length - 1 - k];
            default: // random, with replacement
                return patterns[(int) (Math.random() * patterns.length)];
        }
    }

    /**
     * Trains the network on one pattern.
     *
     * @param p pattern supplying the input vector ({@code getIn()}) and target vector ({@code getOut()})
     * @return the network error after this step, see {@link #calculateError()}
     */
    public double trainPattern(Pattern p) {
        return trainPattern(p.getIn(), p.getOut());
    }

    /**
     * Runs a forward pass on {@code in}, back-propagates against {@code out}, and updates weights.
     *
     * @param in  input vector applied to the first {@code in.length} input-layer neurons
     * @param out target vector, one value per output-layer neuron
     * @return the network error after this step, see {@link #calculateError()}
     */
    public double trainPattern(double[] in, double[] out) {
        forwardPhase(in);
        backwardPhase(out);
        return calculateError();
    }

    /**
     * Runs the network on a pattern's input without training.
     *
     * @param p pattern whose {@code getIn()} vector is evaluated
     * @return the outputs of the output-layer neurons
     */
    public double[] testPattern(Pattern p) {
        return testPattern(p.getIn());
    }

    /**
     * Runs the network on an input vector without training.
     *
     * @param in input vector applied to the first {@code in.length} input-layer neurons
     * @return a new array holding the output of each output-layer neuron, in neuron order
     */
    public double[] testPattern(double[] in) {
        forwardPhase(in);
        Layer outputLayer = outputLayer();
        double[] out = new double[outputLayer.getNoOfNeurons()];
        for (int i = 0; i < out.length; ++i) {
            out[i] = outputLayer.getNeuron(i).getOutput();
        }
        return out;
    }

    /** Returns the last layer, i.e. the output layer. */
    private Layer outputLayer() {
        return layers[layers.length - 1];
    }

    /** Propagates {@code in} from the input layer through every later layer. */
    private void forwardPhase(double[] in) {
        // Step 1: apply the input vector to the input layer; only the first in.length neurons are set.
        for (int i = 0; i < in.length; ++i) {
            layers[0].getNeuron(i).setOutput(in[i]);
        }
        for (int l = 1; l < layers.length; ++l) {
            Layer layer = layers[l];
            // Step 2: calculate this layer's nets from the previous layer.
            layer.calculateNets(layers[l - 1]);
            // Step 3: set each neuron's output to its activation function applied to its net.
            for (int i = 0; i < layer.getNoOfNeurons(); ++i) {
                layer.getNeuron(i).setOutput(
                        layer.getNeuron(i).activationFunction(layer.getNeuron(i).getNet()));
            }
        }
    }

    /** Computes neuron errors from the output layer backwards, then updates all weights. */
    private void backwardPhase(double[] out) {
        Layer outputLayer = outputLayer();
        // Step 6: calculate the error of each output neuron against its target value.
        for (int i = 0; i < outputLayer.getNoOfNeurons(); ++i) {
            outputLayer.getNeuron(i).calculateError(out[i]);
        }
        // Step 7: calculate errors for the hidden layers, from the last hidden layer down to
        // layers[1], passing each neuron the layer that follows it.
        for (int l = layers.length - 2; l > 0; --l) {
            for (int i = 0; i < layers[l].getNoOfNeurons(); ++i) {
                layers[l].getNeuron(i).calculateError(layers[l + 1], i);
            }
        }
        // Step 8: update the weights of every non-input layer. This happens only after the errors
        // of all layers have been computed.
        for (int l = 1; l < layers.length; ++l) {
            int neurons = layers[l].getNoOfNeurons();
            for (int i = 0; i < neurons; ++i) {
                layers[l].getNeuron(i).updateWeights(learningRate, layers[l - 1]);
            }
        }
    }

    /**
     * Returns half the sum of the squared {@code getError()} values of the output-layer neurons.
     * After {@code trainPattern} this reflects the errors set by its backward phase (presumably the
     * per-neuron target deviation; see {@code Neuron.calculateError}).
     *
     * @return {@code 0.5 * sum(error_i^2)} over the output neurons
     */
    public double calculateError() {
        Layer outputLayer = outputLayer();
        double sumSq = 0.0;
        for (int i = 0; i < outputLayer.getNoOfNeurons(); ++i) {
            double e = outputLayer.getNeuron(i).getError();
            sumSq += e * e;
        }
        return 0.5 * sumSq;
    }

    /** @param learningRate new learning rate used by subsequent weight updates */
    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /** @return the current learning rate */
    public double getLearningRate() {
        return learningRate;
    }

    /**
     * Returns the network's internal layer array (not a copy); modifying it modifies this network.
     *
     * @return the layers, input layer first
     */
    public Layer[] getLayers() {
        return layers;
    }

    /**
     * Describes each layer: its neuron count, the weight count of its first neuron, and the layer's
     * own {@code toString()}. Null layers and empty layers are reported instead of causing an exception.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("NET\n");
        for (int i = 0; i < layers.length; ++i) {
            Layer layer = layers[i];
            sb.append("\n\tLayer ").append(i).append(": ");
            if (layer == null) {
                sb.append("(null)\n");
                continue;
            }
            int neurons = layer.getNoOfNeurons();
            sb.append(neurons).append(" neurons (each ")
                    .append(neurons > 0 ? layer.getNeuron(0).getNoOfWeights() : 0)
                    .append(" weights)\n")
                    .append(layer);
        }
        return sb.toString();
    }
}
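/*
 * Code-viewer page residue — keyboard shortcut guide (translated from Chinese):
 *   Copy code:           Ctrl + C
 *   Search code:         Ctrl + F
 *   Full-screen mode:    F11
 *   Toggle theme:        Ctrl + Shift + D
 *   Show shortcuts:      ?
 *   Increase font size:  Ctrl + =
 *   Decrease font size:  Ctrl + -
 */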