
📄 neuraltrainer.java

📁 Core back-propagation (BP) neural network algorithm implemented in Java
💻 JAVA
/*
 * NeuralTrainer.java
 *
 * Created on November 29, 2007, at 11:46 AM
 *
 * To change this template, choose Tools | Template Manager
 * and open the template in the editor.
 */

package Compressor;
import neuralNetwork.*;
import myUtil.*;
import java.io.*;
import java.util.*;
import javax.swing.*;
import java.awt.*;
/**
 * Trainer.
 * The number of input ports is defined by the number of neurons in the first layer.
 * @author yuhui_bear
 */
public class NeuralTrainer implements Runnable{
    private BackPropagation network;
    private int[] netTop;
    private int branch;
    private float STEP =0.01F , MOMENTUM = 0.3F;
    private final double ADJUST =-0.000001D;
    private int TOTAL_VALUES = 16;        // BATCH_SWATCH_SIZE = TOTAL_VALUES ^ netTop[0] / subGROUP
    private int subGROUP ;
    private int BATCH_SWATCH_SIZE ;
    private double goal = 0.00001;
    private String[] SwatchInf;
    private String netinf;
    private JProgressBar turnProgressBar;
    private Diagram graph;
    private int recordInterval ;
    
    /**
     * Creates a new instance of NeuralTrainer.
     * @param top array holding the number of neurons in each layer.
     * @param branches number of dendrites per neuron.
     * @param aim target MSE of training.
     * @param numbersOfValues number of distinct standard values.
     * @param ri interval (in turns) between log records.
     * @param show diagram used to plot the MSE.
     * @param stpb progress bar for the current batch.
     */
    public NeuralTrainer(int[] top,int branches ,double aim ,int numbersOfValues,int ri,Diagram show,JProgressBar stpb) {
        if(aim !=0){
            goal = aim;
        }
        recordInterval = ri;
        TOTAL_VALUES = numbersOfValues;
        //BATCH_SWATCH_SIZE = TOTAL_VALUES ^ top[0] / subGROUP
        subGROUP = (int)Math.pow(TOTAL_VALUES,top[0] /2);
        BATCH_SWATCH_SIZE =(int)(Math.pow(TOTAL_VALUES,top[0]) / subGROUP);
        
        turnProgressBar = stpb;
        netTop = top;
        network= new BackPropagation(top,branches);
        
        branch =branches;
        graph = show;
        netinf="["+String.valueOf(branch)+"]";
        for(int i =0;i<netTop.length;i++){
            netinf = netinf + "-" + String.valueOf(netTop[i]);
        }
    }

    /**
     * Training loop; implements the Runnable interface.
     */
    public void run(){
        Swatch trainLoader = new Swatch(TOTAL_VALUES,subGROUP,netTop[0]);
        SwatchInf = trainLoader.printStandardValues();
        double[] output = new double[netTop[netTop.length /2]];
        double[] mse = new double[subGROUP];
        double[][] inputdata = new double[subGROUP][netTop[0]];
        double mseOfBatch =0 , aimMSE =goal;
        float istep =STEP, imomentum = MOMENTUM;
        int trainCNT =0 ,modifyCNT=0;
        int[] trainSequence ;
        try{
            String filename ="log_"+netinf;
            File logfile = new File(filename + ".csv");
            int fcnt= 0;
            while(logfile.exists()){
                filename = "log_"+netinf+"_" + fcnt;
                fcnt ++;
                logfile = new File(filename +".csv");
            }
            PrintWriter outm = new PrintWriter(new BufferedWriter(new FileWriter(logfile)),true);
            MyRand randSqeuence = new MyRand(subGROUP);
            MyRand randGroup = new MyRand(BATCH_SWATCH_SIZE);

            double[][][] databuf = new double[BATCH_SWATCH_SIZE][][];
            for (int curG =0; curG< BATCH_SWATCH_SIZE;curG++){
                databuf[curG] = trainLoader.nextGroup();
            }
            do{                         
                //turn begin.
                mseOfBatch =0;
                int gCNT=0;
                for (int curG :randGroup.randSequence(BATCH_SWATCH_SIZE)){
                    inputdata = databuf[curG];
                    trainSequence = randSqeuence.randSequence(subGROUP);
                    //sub group
                    gCNT++;
                    if(gCNT >BATCH_SWATCH_SIZE){
                        gCNT=0;
                    }
                    for (int t: trainSequence){                    
                        output=network.forwardCalculate(inputdata[t]);
                        network.backwardCalculate(inputdata[t],output,istep,imomentum);
                        mse[t] = Evaluate.averageSquareError(inputdata[t],output); 
                    }//sub group done.
                    turnProgressBar.setValue((int)(100 * gCNT / BATCH_SWATCH_SIZE));
                    mseOfBatch += Evaluate.mean(mse);                    
                }// complete turn.   
                if( trainCNT % recordInterval==0){
                        outm.print("TrainCNT:,"+trainCNT+",mrse:,"+mseOfBatch+",In/Out,");
                        // log sample 6 of the last sub-group (assumes subGROUP > 6)
                        for(int i = 0 ; i< netTop[0];i++){
                            outm.print(inputdata[6][i]+","+output[i]+",");
                        }
                        outm.println("");
                }
                trainCNT++;
                modifyCNT ++;
                mseOfBatch = mseOfBatch / BATCH_SWATCH_SIZE;                
                if( modifyCNT > 128){
                    double d=  Math.pow(2, ADJUST* trainCNT);             
                    istep =(float) (STEP * d);
                    imomentum = (float) (MOMENTUM * d);
                    modifyCNT=0;
                }                                 
                graph.redraw(mseOfBatch ,"Aim = "+goal+"/CNT="+trainCNT,"Step="+istep+"/MOMENTUM="+imomentum);
            }while (mseOfBatch > goal);
            turnProgressBar.setValue(100);
            outm.close();
            saveResult(network.network.neuralNet_A ,aimMSE, mseOfBatch ,trainCNT);
            graph.redraw(mseOfBatch ,"CNT="+trainCNT,"Training done!");
        }catch(IOException ex){
            graph.redraw(mseOfBatch ,"CNT="+trainCNT,"No log file !!!");
        }        
        
    }
    
    /**
     * Stops training by raising the goal above any reachable MSE.
     */
    public void stopTrain(){
        goal =100;
    }
    /**
     * Sets a new target MSE.
     */
    public void setAim(double am){
        goal=am;
    }    
    /**
     * Sets a new learning step.
     */
    public void setStep(float st){
        STEP=st;
    }
    private void saveResult(Neuron[][] net,double aimMSE, double mseLast, int cnt){
        PrintWriter wout;        
        try {
            //prepare the record file.            
            String filename ="NeuralNet_"+netinf+"_layers";
            File outfile = new File(filename + ".xml");
            int fcnt= 0;
            while(outfile.exists()){
                filename = "NeuralNet_"+netinf+"_layers_" + fcnt;
                fcnt ++;
                outfile = new File(filename +".xml");
            }
            
            wout = new PrintWriter(new BufferedWriter(new FileWriter(outfile)), true);                
            wout.println("<NeuralNetwork>");
            wout.println("<AimMSE>"+aimMSE+"</AimMSE>");
            wout.println("<LastMeanSquareError>"+ mseLast+"</LastMeanSquareError>");
            wout.println("<SwatchSpace>" +Math.pow(TOTAL_VALUES ,netTop[0]) + "</SwatchSpace>");
            wout.println("<TrainedTurns>"+ cnt+"</TrainedTurns>");
            wout.println("<TrainingStandardValues>");
            for (int i = 0;i<SwatchInf.length;i++){
                wout.println("<value index=\""+i+"\">"+SwatchInf[i]+"</value>");
            }
            wout.println("</TrainingStandardValues>");
            wout.println("<Title>weights of the neural network</Title>");
            for (int lay =0;lay<net.length;lay++){
                wout.println("<lay>" +lay);
                for(int cur =0; cur<net[lay].length; cur++){
                    wout.print("<Neuron>"+cur +"<weight>" );
                    for(int iw =0 ; iw<branch +1;iw++){
                        wout.print(net[lay][cur].dendrites[iw].weight+",");
                    }
                    wout.print("</weight></Neuron>\n");
                }
                wout.println("</lay>");
            }
            wout.println("</NeuralNetwork>");
            wout.close();
            } catch (IOException ex) {
            ex.printStackTrace();
            System.exit(-1);
            }
    }
}
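The `modifyCNT > 128` branch in `run()` rescales both the learning step and the momentum by `2^(ADJUST * trainCNT)` with `ADJUST = -0.000001`, i.e. a very slow exponential decay. A minimal standalone sketch of that schedule (the class name `DecaySketch` is hypothetical, not part of the original code):

```java
public class DecaySketch {
    // Mirrors the schedule in NeuralTrainer.run(): every 128 turns the
    // learning step and momentum are rescaled by 2^(ADJUST * trainCNT).
    static final double ADJUST = -0.000001D;

    public static float decayed(float base, int trainCNT) {
        return (float) (base * Math.pow(2, ADJUST * trainCNT));
    }

    public static void main(String[] args) {
        float step = 0.01F;
        // at trainCNT = 0 the factor is 2^0 = 1, so the step is unchanged
        System.out.println(decayed(step, 0));
        // after one million turns the factor is 2^-1 = 0.5
        System.out.println(decayed(step, 1_000_000));
    }
}
```

With these constants the rate halves only once per million turns, so in practice the schedule is close to a constant step with a gentle long-run decay.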
/**
 * Generates swatches (training samples) for a network with a customized number
 * of input-layer neurons; 4 and 8 inputs are supported.
 */
class Swatch {
    private int numbersOfValue;
    private int swatchSize;
    private int dendrities;
    private int[] Indexb;
    private double[] standardValue ;
    private double DIF;
    
    /**
     * Creates a swatch generator.
     * @param group total number of distinct values in the range.
     * @param swatchsize size of one swatch group.
     * @param branch number of dendrites of the input layer.
     */
    public Swatch (int group ,int swatchsize,int branch){
        numbersOfValue = group;
        swatchSize = swatchsize;
        dendrities = branch;
        standardValue = new double[numbersOfValue];
        DIF = 0.8D / (double)numbersOfValue;
        Indexb = new int[dendrities];
        Arrays.fill(Indexb,0);

        for ( int ig=0 ; ig < numbersOfValue;ig++){
            standardValue[ig] =ig * DIF  +0.1258;
        }        
    }
    public String[] printStandardValues(){
        String[] outs = new String[standardValue.length];
        for (int i = 0 ; i < standardValue.length;i++){
            outs[i] =String.valueOf( standardValue[i]);
        }
        return outs;
    }
    
    public double[][] nextGroup(){
        int iswatch =0;
        double [][] outdata= new double[swatchSize][dendrities];
        switch (dendrities){
            case 4:{
                while (iswatch < swatchSize){
                    for (int b1 =Indexb[0];b1<numbersOfValue;b1++){
                        for (int b2 =Indexb[1];b2<numbersOfValue;b2++){
                            for (int b3 =Indexb[2];b3<numbersOfValue;b3++){
                                for (int b4 =Indexb[3];b4<numbersOfValue;b4++){
                                    outdata[iswatch][0]=standardValue[b1];
                                    outdata[iswatch][1]=standardValue[b2];
                                    outdata[iswatch][2]=standardValue[b3];
                                    outdata[iswatch][3]=standardValue[b4];
                                    iswatch++;
                                    if(iswatch >=swatchSize ){
                                        Indexb[0] = b1;
                                        Indexb[1] = b2;
                                        Indexb[2] = b3;
                                        Indexb[3] = b4;
                                        return outdata;
                                    }
                                }   // b4
                                Indexb[3]=0;                    
                            }   //b3
                            Indexb[2]=0;
                        }   //b2
                        Indexb[1]=0;
                    }   //b1
                    Indexb[0]=0;
                }
                Arrays.fill(Indexb,0);
                return outdata;   
            }
            case 8:{
                while (iswatch < swatchSize){
                    for (int b1 =Indexb[0];b1<numbersOfValue;b1++){
                        for (int b2 =Indexb[1];b2<numbersOfValue;b2++){
                            for (int b3 =Indexb[2];b3<numbersOfValue;b3++){
                                for (int b4 =Indexb[3];b4<numbersOfValue;b4++){
                                    for (int b5 =Indexb[4];b5<numbersOfValue;b5++){
                                        for (int b6 =Indexb[5];b6<numbersOfValue;b6++){
                                            for (int b7 =Indexb[6];b7<numbersOfValue;b7++){
                                                for (int b8 =Indexb[7];b8<numbersOfValue;b8++){
                                                    outdata[iswatch][0]=standardValue[b1];
                                                    outdata[iswatch][1]=standardValue[b2];
                                                    outdata[iswatch][2]=standardValue[b3];
                                                    outdata[iswatch][3]=standardValue[b4];
                                                    outdata[iswatch][4]=standardValue[b5];
                                                    outdata[iswatch][5]=standardValue[b6];
                                                    outdata[iswatch][6]=standardValue[b7];
                                                    outdata[iswatch][7]=standardValue[b8];
                                                    iswatch++;
                                                    if(iswatch >=swatchSize ){
                                                        Indexb[0] = b1;
                                                        Indexb[1] = b2;
                                                        Indexb[2] = b3;
                                                        Indexb[3] = b4;
                                                        Indexb[4] = b5;
                                                        Indexb[5] = b6;
                                                        Indexb[6] = b7;
                                                        Indexb[7] = b8;
                                                        return outdata;
                                                    }
                                                }   //b8
                                                Indexb[7]=0;  
                                            }   //b7
                                            Indexb[6]=0;  
                                        }   //b6
                                        Indexb[5]=0;  
                                    }   //b5
                                    Indexb[4]=0;  
                                }   // b4
                                Indexb[3]=0;                    
                            }   //b3
                            Indexb[2]=0;
                        }   //b2
                        Indexb[1]=0;
                    }   //b1
                    Indexb[0]=0;
                }
                Arrays.fill(Indexb,0);
                return outdata;             
            }
            default:{
                return null;
            }
        }
    }
}
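The hand-unrolled `case 4` / `case 8` loops in `nextGroup()` are equivalent to counting in base `numbersOfValue` with one digit per dendrite. A sketch of a width-independent replacement using a single mixed-radix counter (`MixedRadixSwatch` is a hypothetical name, not part of the original code); besides handling any input width, it sidesteps the copy-paste loop-variable slips the unrolled version invites:

```java
/**
 * Width-independent swatch generator: a mixed-radix counter replaces the
 * nested b1..b8 loops, one digit per dendrite, base numbersOfValue.
 */
class MixedRadixSwatch {
    private final int numbersOfValue;
    private final int swatchSize;
    private final int dendrites;
    private final int[] index;            // current counter digits
    private final double[] standardValue;

    MixedRadixSwatch(int group, int swatchsize, int branch) {
        numbersOfValue = group;
        swatchSize = swatchsize;
        dendrites = branch;
        index = new int[dendrites];
        standardValue = new double[numbersOfValue];
        double dif = 0.8D / numbersOfValue;
        for (int i = 0; i < numbersOfValue; i++) {
            standardValue[i] = i * dif + 0.1258;   // same value mapping as Swatch
        }
    }

    double[][] nextGroup() {
        double[][] out = new double[swatchSize][dendrites];
        for (int s = 0; s < swatchSize; s++) {
            for (int d = 0; d < dendrites; d++) {
                out[s][d] = standardValue[index[d]];
            }
            // increment the counter from the least significant digit,
            // carrying into the next digit on overflow (wraps at the end)
            for (int d = dendrites - 1; d >= 0; d--) {
                if (++index[d] < numbersOfValue) break;
                index[d] = 0;
            }
        }
        return out;
    }
}
```

Note one behavioral difference: the original saves the indices of the last emitted swatch, so the next call re-emits that combination first; the counter here advances past it, so consecutive groups do not overlap.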
