
📄 net.java

📁 A question about implementing the BP algorithm in a Java program!! I think this will be useful to everyone, please take a look
💻 JAVA
package bp;

import java.io.Serializable;

public class Net implements Serializable{
    private Layer[] layers;
    private double learningRate;
    
    private static int trainingMode = 0; // order in which the patterns of the current epoch are trained: sequential, backwards, or random
    
    
    public Net(Layer[] layers, double learningRate) {
      this.learningRate = learningRate;
      this.layers = new Layer[layers.length];
      for (int i = 0; i < layers.length; ++i) {
            this.layers[i] = (layers[i] != null) ? new Layer(layers[i]) : null; 
      }
    }
    
    public double trainEpoch(Pattern[] patterns){
        double maxErr = 0.0; // squared errors are non-negative, so 0 is a safe initial maximum
        double trainErr = 0;
        int mode = trainingMode++ % 3; // cycle through the three training modes, one per epoch
        if (mode == 0) {                            //sequential
            for (int i = 0; i < patterns.length; ++i){
                System.out.println("pattern mode0 " + patterns[i].toString());
                trainErr = trainPattern (patterns[i]);
                System.out.println("mode0 trainerror is " + trainErr);
                if (trainErr > maxErr)
                    maxErr = trainErr;
            }
        }
        else if (mode == 1) {                       //backwards sequential
            for (int i = patterns.length - 1; i >= 0; --i){
                System.out.println("pattern mode1 " + patterns[i].toString());
                trainErr = trainPattern (patterns[i]);
                System.out.println("mode1 trainerror is " + trainErr);
                if (trainErr > maxErr)
                    maxErr = trainErr;
            }
        }
        else {                                      //random, with replacement
            for (int i = 0; i < patterns.length; ++i) {
                int r = (int) (Math.random() * patterns.length);
                System.out.println("pattern mode2 " + patterns[r].toString());
                trainErr = trainPattern (patterns[r]);
                System.out.println("mode2 trainerror is " + trainErr);
                if (trainErr > maxErr)
                    maxErr = trainErr;
            }
        }
        return maxErr;
    }
    
    public double trainPattern(Pattern p){
        return trainPattern (p.getIn(), p.getOut());
    }
    
    public double trainPattern(double[] in, double[] out){
        forwardPhase (in);
        backwardPhase (out);
        return calculateError();
    }
    
    public double[] testPattern(Pattern p){
        return testPattern (p.getIn());
    }
    public double[] testPattern(double[] in){
        forwardPhase (in);
        double[] out = new double[layers[(layers.length-1)].getNoOfNeurons()];
        for(int i = 0; i < out.length; ++i)
            out[i] = layers[(layers.length - 1)].getNeuron(i).getOutput();
        return out;
    }
    
    private void forwardPhase(double[] in){
        //step 1: apply the input vector to the input layer
        for (int i = 0; i < in.length; ++i)
            layers[0].getNeuron (i).setOutput (in[i]);
        
        for(int l = 1; l < layers.length; ++l){
            
            //step 2: calculate each neuron's net input from the previous layer
            layers[l].calculateNets (layers[l - 1]);
        
            //step 3: pass each net input through the activation function
            for(int i = 0; i < layers[l].getNoOfNeurons(); ++i)
                layers[l].getNeuron (i).setOutput (
                    layers[l].getNeuron (i).activationFunction (layers[l].getNeuron (i).getNet()));
        }
        
    }
    
    private void backwardPhase(double[] out){
        //step 1: calculate the error term for each output-layer neuron
        for (int i = 0; i < layers[layers.length - 1].getNoOfNeurons(); ++i)
            layers[layers.length - 1].getNeuron (i).calculateError (out[i]);
        
        //step 2: back-propagate the error terms through the hidden layers
        for (int l = layers.length - 2; l > 0; --l){
            for (int i = 0; i < layers[l].getNoOfNeurons(); ++i)
                layers[l].getNeuron (i).calculateError (layers[l + 1], i);
        }
        
        //step 3: update the weights of every non-input layer
        for (int l = 1; l < layers.length; ++l) {
            int neurons = layers[l].getNoOfNeurons();
            for (int i = 0; i < neurons; ++i)
                layers[l].getNeuron (i).updateWeights (learningRate, layers[l - 1]);
        }
    }
    
    public double calculateError(){
        double err = 0.0;
        for (int i = 0; i < layers[layers.length - 1].getNoOfNeurons(); ++i){
            err += Math.pow (layers[layers.length - 1].getNeuron (i).getError(), 2);
        }
        return (0.5 * err);
    }
    
    public void setLearningRate(double learningRate){
        this.learningRate = learningRate;
    }
    
    public double getLearningRate(){
        return learningRate;
    }
    public Layer[] getLayers(){
        return layers;
    }
    
    @Override
    public String toString(){
        StringBuilder out = new StringBuilder("NET\n");
        for (int i = 0; i < layers.length; ++i)
            out.append("\n\tLayer ").append(i).append(": ")
               .append(layers[i].getNoOfNeurons()).append(" neurons (each ")
               .append(layers[i].getNeuron(0).getNoOfWeights()).append(" weights)\n")
               .append(layers[i].toString());
        return out.toString();
    } 
}
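
The helper classes used above (Layer, Neuron, Pattern) are not included in this post, so the listing cannot be compiled on its own. As a self-contained illustration of the same forward phase, backward phase, and weight update, here is a minimal sketch: a 2-2-1 sigmoid network trained by gradient descent on the OR function. All names (TinyBp, wH, wO), the OR task, the seed, and the learning rate are my own choices for the sketch, not taken from the original code.

```java
import java.util.Random;

// A minimal self-contained sketch of one BP step: forward phase,
// output/hidden error terms, and weight updates, as in Net above.
public class TinyBp {
    static double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

    // wH[j][k]: weight from input k (k=2 is the bias) to hidden neuron j
    // wO[j]:    weight from hidden neuron j (j=2 is the bias) to the output
    static double[][] wH = new double[2][3];
    static double[] wO = new double[3];

    // forward phase: returns {hidden0, hidden1, output}
    static double[] forward(double[] in) {
        double[] h = new double[2];
        for (int j = 0; j < 2; ++j)
            h[j] = sigmoid(wH[j][0] * in[0] + wH[j][1] * in[1] + wH[j][2]);
        double o = sigmoid(wO[0] * h[0] + wO[1] * h[1] + wO[2]);
        return new double[]{h[0], h[1], o};
    }

    // one training step; returns the squared error 0.5*(target - output)^2
    static double trainPattern(double[] in, double target, double lr) {
        double[] act = forward(in);
        double o = act[2];
        double deltaO = (target - o) * o * (1 - o);       // output error term
        for (int j = 0; j < 2; ++j) {
            double h = act[j];
            double deltaH = deltaO * wO[j] * h * (1 - h); // hidden error term
            wH[j][0] += lr * deltaH * in[0];
            wH[j][1] += lr * deltaH * in[1];
            wH[j][2] += lr * deltaH;                      // bias input is 1
        }
        for (int j = 0; j < 2; ++j)
            wO[j] += lr * deltaO * act[j];
        wO[2] += lr * deltaO;                             // output bias
        return 0.5 * (target - o) * (target - o);
    }

    public static void main(String[] args) {
        Random rnd = new Random(42);
        for (int j = 0; j < 2; ++j)
            for (int k = 0; k < 3; ++k)
                wH[j][k] = rnd.nextDouble() - 0.5;
        for (int j = 0; j < 3; ++j)
            wO[j] = rnd.nextDouble() - 0.5;

        double[][] ins = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        double[] outs = {0, 1, 1, 1};                     // OR truth table
        double maxErr = 0;
        for (int epoch = 0; epoch < 20000; ++epoch) {
            maxErr = 0;
            for (int p = 0; p < ins.length; ++p)
                maxErr = Math.max(maxErr, trainPattern(ins[p], outs[p], 0.5));
        }
        System.out.println("max pattern error after training: " + maxErr);
    }
}
```

Like trainEpoch above, main tracks the worst per-pattern error of the epoch; after enough epochs it drops close to zero, which is the usual stopping signal for this kind of training loop.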
