⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 backpropagation.java

📁 用java实现的神经网络bp的核心算法
💻 JAVA
字号:
/*
 * BackPropagation.java
 *
 * Created on November 15, 2007, 4:01 AM
 *
 */

package neuralNetwork;
import myUtil.*;
import java.util.*;
import java.io.*;
/**
 * Class for back propagation computing.
 * @author yuhui_bear
 */
public class BackPropagation {
    /**
     * The network being trained / evaluated. Its neuron array is cached in {@link #nntA}.
     */
    public NeuralNetwork network;
    /**
     * All neurons of {@link #network}, indexed as {@code [layer][neuron]}
     * (the same array object as {@code network.neuralNet_A}).
     */
    private Neuron[][] nntA;
    /**
     * Number of dendrites per neuron that carry signal inputs: the dendrite count of
     * neuron {@code [0][0]} minus one. Dendrite 0 is never filled by the forward pass
     * (presumably a bias connection; TODO confirm in Neuron). All neurons are assumed
     * to have this same fan-in.
     */
    private int den = 0;
    /**
     * Half the absolute output range of the network. Only meaningful for a scaled-tanh
     * gradient formula that is currently disabled; the active gradient formulas use the
     * logistic derivative {@code y * (1 - y)}.
     */
    private final float a = 1.716F;
    /**
     * Second parameter of the same disabled scaled-tanh gradient formula.
     */
    private final float b = 0.6F;
    /** Index of the middle ("code") layer: {@code nettop.length / 2}. */
    private int midlay = 0;
    /** Neuron count of the middle layer, read from {@code nettop[midlay]}. */
    private int midsize = 0;

    /**
     * Creates a new network with the given topology and wraps it for back-propagation.
     * @param nettop number of neurons in each layer.
     * @param branch number of dendrites of each neuron (passed through to NeuralNetwork).
     */
    public BackPropagation(int[] nettop, int branch) {
        network = new NeuralNetwork(nettop, branch);
        initDerivedFields();
    }

    /**
     * Loads a neural network from a file and wraps it for back-propagation.
     * @param filename the neural network file to load.
     * @throws IllegalArgumentException if the file cannot be read; the underlying
     *         IOException is attached as the cause.
     */
    public BackPropagation(File filename) {
        try {
            network = new NeuralNetwork(filename);
        } catch (IOException ex) {
            // Report the failure to the caller instead of terminating the whole JVM.
            throw new IllegalArgumentException("cannot load neural network from " + filename, ex);
        }
        initDerivedFields();
    }

    /** Caches the neuron array and derives den, midlay and midsize from {@link #network}. */
    private void initDerivedFields() {
        nntA = network.neuralNet_A;
        den = nntA[0][0].dendrites.length - 1;
        midlay = network.getNetInformation().nettop.length / 2;
        midsize = network.getNetInformation().nettop[midlay];
    }

    /**
     * Runs a forward pass on {@code indata} and returns the middle-layer outputs.
     * <p>
     * The middle-layer values are copied into a fresh array. Writing them into the
     * output-layer array returned by forwardCalculate would leave stale output values
     * past index {@code midsize - 1}. It would also overflow when the middle layer is
     * wider than the output layer.
     * @param indata source data to be encoded.
     * @return a new array of length {@code midsize} holding the encoded data.
     */
    public double[] encode(double[] indata) {
        forwardCalculate(indata);
        return encode();
    }

    /**
     * Returns the current outputs of the middle layer (from the most recent forward pass).
     * @return a new array of length {@code midsize} holding the encoded data.
     */
    public double[] encode() {
        return layerOutputs(nntA[midlay], midsize);
    }

    /**
     * Decodes {@code sourceData} by feeding it into the layer right after the middle layer
     * and propagating forward to the output layer.
     * <p>
     * Each neuron of layer {@code midlay + 1} receives {@code den} values picked from
     * {@code sourceData} at the indices returned by {@code Fetcher.nextGroup(den)}. Later
     * layers read their inputs through their dendrite links, as in forwardCalculate.
     * @param sourceData data to be decoded. Its length must be non-zero, and either its
     *        length must be a multiple of the per-neuron dendrite count, or the dendrite
     *        count must be a multiple of its length.
     * @return a new array with the output-layer values (the decoded data).
     * @throws IllegalArgumentException if {@code sourceData} is empty or its length is
     *         incompatible with the dendrite count.
     */
    public double[] decode(double[] sourceData) {
        int n = sourceData.length;
        // An empty array would pass the divisibility test (0 % den == 0) and then fail on the
        // first index lookup, so it is rejected here together with incompatible lengths.
        if (n == 0 || (n % den != 0 && den % n != 0)) {
            throw new IllegalArgumentException("decode error: source data length " + n
                    + " does not match the " + den + " dendrites of each neuron"
                    + " (one must be a multiple of the other).");
        }
        Fetcher linker = new Fetcher(n);
        double[] input = new double[den];
        for (Neuron neuron : nntA[midlay + 1]) {
            int[] links = linker.nextGroup(den);
            for (int k = 0; k < den; k++) {
                input[k] = sourceData[links[k]];
            }
            neuron.input(input);
        }
        // Remaining layers after midlay + 1 (none if that is already the output layer).
        feedForward(midlay + 2, input);
        Neuron[] outLayer = nntA[nntA.length - 1];
        return layerOutputs(outLayer, outLayer.length);
    }

    /**
     * Performs one back-propagation step (debug overload).
     * <p>
     * This is the same as {@link #backwardCalculate(double[], double[], double, double)},
     * except that the output layer is updated with {@code step * 1.25} instead of
     * {@code step * 1.1}.
     * NOTE(review): the two overloads use different output-layer rate multipliers;
     * confirm whether that difference is intended.
     * @param sourceData input signal of the network.
     * @param outputData output signal of the network.
     * @param step learning-step length used for the hidden layers.
     * @param aOfw value of momentum.
     * @param fout output stream for debugging; currently unused (its tracing was disabled).
     */
    public void backwardCalculate(double[] sourceData, double[] outputData, double step, double aOfw, PrintWriter fout) {
        backPropagate(sourceData, outputData, step, step * 1.25, aOfw);
    }

    /**
     * Performs one back-propagation step.
     * The output layer is updated with {@code step * 1.1}; all other layers use {@code step}.
     * @param sourceData input signal of the network.
     * @param outputData output signal of the network.
     * @param step learning-step length used for the hidden layers.
     * @param aOfw value of momentum.
     */
    public void backwardCalculate(double[] sourceData, double[] outputData, double step, double aOfw) {
        backPropagate(sourceData, outputData, step, step * 1.1, aOfw);
    }

    /**
     * Shared back-propagation step used by both backwardCalculate overloads.
     * <p>
     * Pass 1 computes the local gradient {@code grad} of every neuron, output layer first,
     * using the logistic derivative {@code y * (1 - y)}. Output-layer gradients come from
     * {@code Evaluate.difference(sourceData, outputData)}. Hidden-layer gradients come from
     * the weighted sum of the gradients of the neurons they feed ({@code subDen}).
     * Pass 2 then calls setWeight on every neuron, top layer first.
     * <p>
     * Every gradient is computed before any setWeight call, so hidden-layer gradients use the
     * weights of the forward pass, as standard back-propagation requires. Interleaving the
     * two let the layer below read weights that had already been updated. (This assumes
     * setWeight rewrites the dendrite weights reachable through {@code subDen}; confirm in
     * Neuron. If it does not, both orderings give the same result.)
     * @param outputStep learning step for the output layer.
     * @param hiddenStep learning step for all other layers.
     */
    private void backPropagate(double[] sourceData, double[] outputData, double hiddenStep, double outputStep, double aOfw) {
        double[] dif = Evaluate.difference(sourceData, outputData);
        int lastlay = nntA.length - 1;

        // Pass 1a: output-layer gradients.
        Neuron[] outLayer = nntA[lastlay];
        for (int i = 0; i < outLayer.length; i++) {
            Neuron neuron = outLayer[i];
            neuron.grad = dif[i] * neuron.output() * (1 - neuron.output());
        }
        // Pass 1b: hidden-layer gradients, top-down, so each layer sees its successors' grads.
        for (int layer = lastlay - 1; layer >= 0; layer--) {
            for (Neuron neuron : nntA[layer]) {
                double sum = 0;
                for (int k = 0; k < neuron.subDen.length; k++) {
                    sum += neuron.subDen[k].weight * neuron.subDen[k].out.grad;
                }
                neuron.grad = sum * neuron.output() * (1 - neuron.output());
            }
        }

        // Pass 2: weight updates, in the same top-down order as before.
        for (Neuron neuron : outLayer) {
            neuron.setWeight(outputStep, aOfw);
        }
        for (int layer = lastlay - 1; layer >= 0; layer--) {
            for (Neuron neuron : nntA[layer]) {
                neuron.setWeight(hiddenStep, aOfw);
            }
        }
    }

    /**
     * Forward calculation (debug overload).
     * This is the same as {@link #forwardCalculate(double[])}.
     * @param sourceData input data for the neural network.
     * @param fout stream for debug output; currently unused (its tracing was disabled).
     * @return a new array with the output-layer values.
     */
    public double[] forwardCalculate(double[] sourceData, PrintWriter fout) {
        return forwardCalculate(sourceData);
    }

    /**
     * Forward calculation.
     * It hands {@code sourceData} to {@code network.inputAdapte}, then recomputes every
     * layer from the first to the last.
     * @param sourceData input data.
     * @return a new array with the output-layer values.
     */
    public double[] forwardCalculate(double[] sourceData) {
        network.inputAdapte(sourceData);
        feedForward(0, new double[den]);
        Neuron[] outLayer = nntA[nntA.length - 1];
        return layerOutputs(outLayer, outLayer.length);
    }

    /**
     * Recomputes every neuron in layers {@code firstLayer .. nntA.length - 1}, in order.
     * Each neuron's input vector is filled from the outputs of the neurons linked to its
     * dendrites {@code 1..den} (dendrite 0 is skipped) and passed to {@code Neuron.input}.
     * @param firstLayer first layer to recompute; nothing happens if it is past the last layer.
     * @param input scratch buffer of length {@code den}, reused for every neuron as before
     *        (assumes Neuron.input copies what it needs — TODO confirm).
     */
    private void feedForward(int firstLayer, double[] input) {
        for (int layer = firstLayer; layer < nntA.length; layer++) {
            for (Neuron neuron : nntA[layer]) {
                for (int d = 0; d < den; d++) {
                    input[d] = neuron.dendrites[d + 1].in.output();
                }
                neuron.input(input);
            }
        }
    }

    /** Returns a new array holding output() of the first {@code count} neurons of {@code layer}. */
    private static double[] layerOutputs(Neuron[] layer, int count) {
        double[] out = new double[count];
        for (int i = 0; i < count; i++) {
            out[i] = layer[i].output();
        }
        return out;
    }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -