📄 net.java
字号:
/* AI ANN Backpropagation
   Copyright (C) 2002-2003 Wim Gillis <d.e.m@gmx.net>
   http://sourceforge.net/projects/crap

   This program is free software; you can redistribute it and/or modify it under
   the terms of the GNU General Public License as published by the Free Software
   Foundation.

   This program is distributed in the hope that it will be useful, but WITHOUT
   ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
   FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

   You should have received a copy of the GNU General Public License along with
   this program (see COPYING); if not, check out
   http://www.gnu.org/licenses/gpl.html or write to the Free Software Foundation,
   Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA */
package cx.ma.ai.ann.backprop;

import java.io.Serializable;

/**
 * A feed-forward artificial neural network trained with error backpropagation.
 *
 * <p>Layer 0 is the input layer, the last layer is the output layer. Training
 * runs a forward phase (propagate inputs to outputs), a backward phase
 * (propagate errors and update weights), and reports the summed squared error.
 */
public class Net implements Serializable {

    private static final long serialVersionUID = 1L;

    /** The network's layers; index 0 = input layer, last index = output layer. */
    private Layer[] layers;

    /** Step size applied when updating weights in the backward phase. */
    private double learningRate;

    /**
     * Pattern presentation order per epoch:
     * 0 = sequential, 1 = backwards sequential, anything else = random.
     * NOTE(review): this is static, so all Net instances share one mode —
     * kept static for interface compatibility.
     */
    private static int trainingMode = 0;

    /**
     * Creates a net with a deep copy of the given layers, so this net does not
     * share mutable state with the caller's array.
     *
     * @param layers       layer topology (null entries are preserved as null)
     * @param learningRate weight-update step size
     */
    public Net(Layer[] layers, double learningRate) {
        this.learningRate = learningRate;
        this.layers = new Layer[layers.length];
        for (int i = 0; i < layers.length; ++i) {
            this.layers[i] = (layers[i] != null) ? new Layer(layers[i]) : null;
        }
    }

    /**
     * Trains one full epoch over the given patterns and returns the largest
     * per-pattern error observed.
     *
     * <p>BUG FIX: the original tested {@code trainingMode++ == 0} and then, in a
     * separate (non-else) {@code if}, {@code trainingMode++ == 1}. That mutated
     * the mode twice per call — silently changing the presentation order from
     * epoch to epoch — and let the sequential branch AND a second branch run in
     * the same call, training every pattern twice. The "backwards" branch also
     * iterated forwards, identical to the sequential one. The mode is now read
     * without side effects and the backwards branch truly iterates last-to-first.
     *
     * @param patterns the training patterns for this epoch
     * @return the maximum single-pattern error seen during the epoch
     *         (0.0 for an empty pattern array)
     */
    public double trainEpoch(Pattern[] patterns) {
        // Errors are non-negative (0.5 * sum of squares), so 0.0 is a safe floor.
        // (Original used Double.MIN_VALUE — the smallest positive double, not a
        // "very negative" sentinel.)
        double maxErr = 0.0;
        double trainErr;
        if (trainingMode == 0) {
            // Sequential: first pattern to last.
            for (int i = 0; i < patterns.length; ++i) {
                trainErr = trainPattern(patterns[i]);
                if (trainErr > maxErr) maxErr = trainErr;
            }
        } else if (trainingMode == 1) {
            // Backwards sequential: last pattern to first.
            for (int i = patterns.length - 1; i >= 0; --i) {
                trainErr = trainPattern(patterns[i]);
                if (trainErr > maxErr) maxErr = trainErr;
            }
        } else {
            // Random: draw patterns.length patterns with replacement.
            for (int i = 0; i < patterns.length; ++i) {
                trainErr = trainPattern(patterns[(int) (Math.random() * patterns.length)]);
                if (trainErr > maxErr) maxErr = trainErr;
            }
        }
        return maxErr;
    }

    /**
     * Trains the net on a single pattern.
     *
     * @param p the pattern providing input and target output vectors
     * @return the summed squared error after the update
     */
    public double trainPattern(Pattern p) {
        return trainPattern(p.getIn(), p.getOut());
    }

    /**
     * Trains the net on a single input/target pair: forward phase, backward
     * phase (weight update), then error measurement.
     *
     * @param in  input vector, one value per input-layer neuron
     * @param out target output vector, one value per output-layer neuron
     * @return the summed squared error after the update
     */
    public double trainPattern(double[] in, double[] out) {
        forwardPhase(in);
        backwardPhase(out);
        return calculateError();
    }

    /**
     * Runs a forward pass for the pattern's input and returns the net's output.
     *
     * @param p the pattern whose input vector is applied
     * @return the output-layer activations
     */
    public double[] testPattern(Pattern p) {
        return testPattern(p.getIn());
    }

    /**
     * Runs a forward pass for the given input and returns the output-layer
     * activations. No weights are modified.
     *
     * @param in input vector, one value per input-layer neuron
     * @return the output-layer activations
     */
    public double[] testPattern(double[] in) {
        forwardPhase(in);
        Layer outputLayer = layers[layers.length - 1];
        double[] out = new double[outputLayer.getNoOfNeurons()];
        for (int i = 0; i < out.length; ++i) {
            out[i] = outputLayer.getNeuron(i).getOutput();
        }
        return out;
    }

    /**
     * Forward phase: applies the input vector to the input layer, then
     * propagates activations layer by layer to the output layer.
     *
     * @param in input vector, one value per input-layer neuron
     */
    private void forwardPhase(double[] in) {
        // Step 1: apply input vector to input layer.
        for (int i = 0; i < in.length; ++i) {
            layers[0].getNeuron(i).setOutput(in[i]);
        }
        for (int l = 1; l < layers.length; ++l) {
            // Step 2: calculate net input from the previous layer.
            layers[l].calculateNets(layers[l - 1]);
            // Step 3: calculate outputs via each neuron's activation function.
            for (int i = 0; i < layers[l].getNoOfNeurons(); ++i) {
                Neuron n = layers[l].getNeuron(i);
                n.setOutput(n.activationFunction(n.getNet()));
            }
        }
    }

    /**
     * Backward phase: computes errors for the output layer against the target
     * vector, propagates errors back through the hidden layers, then updates
     * all weights using the current learning rate.
     *
     * @param out target output vector, one value per output-layer neuron
     */
    private void backwardPhase(double[] out) {
        // Step 6: calculate errors for each output unit.
        Layer outputLayer = layers[layers.length - 1];
        for (int i = 0; i < outputLayer.getNoOfNeurons(); ++i) {
            outputLayer.getNeuron(i).calculateError(out[i]);
        }
        // Step 7: calculate errors for hidden layers, last hidden to first
        // (the input layer, index 0, carries no error).
        for (int l = layers.length - 2; l > 0; --l) {
            for (int i = 0; i < layers[l].getNoOfNeurons(); ++i) {
                layers[l].getNeuron(i).calculateError(layers[l + 1], i);
            }
        }
        // Step 8: update weights in every non-input layer.
        for (int l = 1; l < layers.length; ++l) {
            int neurons = layers[l].getNoOfNeurons();
            for (int i = 0; i < neurons; ++i) {
                layers[l].getNeuron(i).updateWeights(learningRate, layers[l - 1]);
            }
        }
    }

    /**
     * Computes the current error as half the sum of squared output-neuron
     * errors (the standard backpropagation error measure).
     *
     * @return 0.5 * sum of squared errors over the output layer
     */
    public double calculateError() {
        double err = 0.0;
        Layer outputLayer = layers[layers.length - 1];
        for (int i = 0; i < outputLayer.getNoOfNeurons(); ++i) {
            double e = outputLayer.getNeuron(i).getError();
            err += e * e; // cheaper and clearer than Math.pow(e, 2)
        }
        return 0.5 * err;
    }

    /** @param learningRate the new weight-update step size */
    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /** @return the current weight-update step size */
    public double getLearningRate() {
        return learningRate;
    }

    /**
     * Returns a human-readable summary of the net's topology and layer details.
     * Uses StringBuilder instead of repeated String concatenation in a loop.
     */
    public String toString() {
        StringBuilder sb = new StringBuilder("NET\n");
        for (int i = 0; i < layers.length; ++i) {
            sb.append("\n\tLayer ").append(i).append(": ")
              .append(layers[i].getNoOfNeurons()).append(" neurons (each ")
              .append(layers[i].getNeuron(0).getNoOfWeights()).append(" weights)\n")
              .append(layers[i].toString());
        }
        return sb.toString();
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -