
📄 neuralfileencoder.java

📁 Core algorithm of a BP (backpropagation) neural network, implemented in Java
💻 JAVA
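The class below compresses a file by training a bottleneck autoencoder (default topology 8-4-8): each group of normalized input values is pushed through the net, and the smaller middle-layer activations are written out as the compressed code. The following is a minimal sketch of that idea only, with illustrative names and a sigmoid activation assumed; the real code delegates everything to the unshown `neuralNetwork.BackPropagation` class.

```java
// Minimal sketch of the 8-4-8 bottleneck idea behind NeuralFileEncoder.
// All names here are illustrative, not the project's actual API.
public class BottleneckSketch {
    static double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

    // one dense layer: out[j] = sigmoid(sum_i w[j][i] * in[i] + b[j])
    static double[] layer(double[] in, double[][] w, double[] b) {
        double[] out = new double[w.length];
        for (int j = 0; j < w.length; j++) {
            double s = b[j];
            for (int i = 0; i < in.length; i++) s += w[j][i] * in[i];
            out[j] = sigmoid(s);
        }
        return out;
    }

    public static void main(String[] args) {
        java.util.Random r = new java.util.Random(42);
        double[][] w1 = new double[4][8]; double[] b1 = new double[4];
        double[][] w2 = new double[8][4]; double[] b2 = new double[8];
        for (double[] row : w1) for (int i = 0; i < 8; i++) row[i] = r.nextGaussian() * 0.1;
        for (double[] row : w2) for (int i = 0; i < 4; i++) row[i] = r.nextGaussian() * 0.1;

        double[] input = {0.1, 0.9, 0.3, 0.7, 0.5, 0.2, 0.8, 0.4}; // 8 normalized bytes
        double[] code = layer(input, w1, b1);   // 4 values: the compressed form
        double[] decoded = layer(code, w2, b2); // 8 values: the reconstruction
        System.out.println(code.length + " " + decoded.length); // prints "4 8"
    }
}
```

Decompression would run only the second half of a trained net on the stored codes; in the listing below this split is handled by `BackPropagation.forwardCalculate` and `encode()`.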
package Compressor;
/*
 * NeuralFileEncoder.java
 *
 * Created in 2007 at 3:39 PM
 *
 */
import myUtil.*;
//import Compressor.FileAdapter;
import neuralNetwork.*;
import java.util.*;
import java.io.*;
import javax.swing.*;
/**
 * Compress.
 * @author yuhui_bear
 */
public class NeuralFileEncoder implements Runnable{
    private Diagram showB;
    private int branch =0 ;
    private int[] net;
    private File networkFile ;
    private BackPropagation compressNet;
    private BufferedFileAdapter data= null;
    private JProgressBar progressBar;
    private JButton startBt;
    private JDialog bd;
    private String patch ,netinf;
    private double goal ;
    private final float STEP =0.008F , MOMENTUM = 0.25F;
    private final double ADJUST =-0.000008D;
    /**
     * Creates a new instance of NeuralFileEncoder.
     * This one is just for testing.
     * The trained net will be stored in ./weight.xml.
     * @param sb display panel for the trainer.
     * @param jb progress bar.
     * @param source file to compress.
     * @param br number of dendrites per neuron.
     * @param top net topology.
     */
    public NeuralFileEncoder(Diagram sb ,JProgressBar jb , File source, int br ,int[] top) {
        progressBar  = jb;
        showB = sb;     //new TestShowBoard(this);
        branch =br;
        net =top;
        networkFile = new File("weight.xml");
        compressNet = new BackPropagation(net,branch);
    }
    /**
     * Creates a new instance of NeuralFileEncoder.
     *
     * @param md dialog the MSE diagram is displayed in.
     * @param sb diagram used to display the MSE.
     * @param jb progress bar.
     * @param bt button disabled while the job runs.
     * @param source file to compress.
     * @param neuralnet neural-net weight file, or null to train a new net.
     * @param precise target MSE for training.
     */
    public NeuralFileEncoder(JDialog md,Diagram sb ,JProgressBar jb, JButton bt,File source , File neuralnet,double precise) {
        goal = precise;
        bd = md;
        progressBar  = jb;
        startBt = bt;
        showB = sb;     //new TestShowBoard(this);
        networkFile = neuralnet;
        if(networkFile ==null){
            compressNet = new BackPropagation(new int[]{8,4,8},4);
            bd.setVisible(true);
        }else{
            compressNet =new BackPropagation(networkFile);
        }
        net = compressNet.network.getNetInformation().nettop;
        branch =compressNet.network.getNetInformation().branch;
        patch =source.getPath().substring(0 , source.getPath().length()-source.getName().length());
        //open the source file and prepare the output file.
        try {
            data = new BufferedFileAdapter(source,net[0],true);
        } catch (IOException ex) {
            ex.printStackTrace();
            System.exit(-1);
        }
        netinf=String.valueOf(net[0]);
        for(int i =1;i<net.length;i++){
            netinf = netinf + "-" + String.valueOf(net[i]);
        }
    }
         
    /**
     * Required by the Runnable interface.
     */
    public void run(){
        //train the net first if no pretrained network file was supplied.
        if (networkFile ==null){
            trainNet(compressNet , data);
        }
       //prepare training swatch.
       double[] inputdata = new double[net[0]];
       //sized to the bottleneck (middle) layer, where the compressed code is read.
       double[] encoded = new double[net[net.length / 2]];
       double[] outdata =new double[net[0]]; 
       int err=0;
       data.resetIterator();
       inputdata = data.nextGroup();  
       while(inputdata !=null){
           // get value ...
           outdata = compressNet.forwardCalculate(inputdata);                   
           encoded=compressNet.encode();
           err = data.writeGroup(encoded);
           if(err==2){
               break;
           }
           if(err <0){               
               System.err.print("write back error. err = "+ err);
               System.exit(err);
           }
           progressBar.setValue(data.getProgress());
           inputdata = data.nextGroup(); 
       }
       //close the stream.
       data.writeGroup(null);
       progressBar.setValue(100);
       startBt.setText("Job Done");
       startBt.setEnabled(true);
    }
    
    private void trainNet(BackPropagation trainnet ,BufferedFileAdapter trainSwatch){
        int[] netvector = compressNet.network.getNetInformation().nettop;
        int[] group = data.subGroupSize();
        int[] trainSequen;
        double[] output =new double[0];
        double mse = 0;
        double[][] inputdata ;
        long swatchCNT=0;
        int trainCNT =0 ,modify =0 ;
        float istep =STEP, imomentum = MOMENTUM;
        try{
            String filename ="log_In_"+netinf;
            File logfile = new File(patch,filename + ".csv");
            int fcnt= 0;
            while(logfile.exists()){
                filename = "log_In_"+netinf+"_" + fcnt;
                fcnt ++;
                logfile = new File(patch,filename +".csv");
            }
            PrintWriter stateslog = new PrintWriter(new BufferedWriter(new FileWriter(logfile)),true);
            
            do{
                if( modify ==512 ){
                    double d=  Math.pow(2, ADJUST* trainCNT);             //rand.nextFloat() *2;    
                    istep =(float) (STEP * d);
                    imomentum = (float) (MOMENTUM * d);
                    modify =0;
                }          
                modify ++;
                //turn begin.      
                swatchCNT= 0;
                for(int ig : group){
                    inputdata = new double[ig][branch];
                    for(int iget=0; iget<ig;iget++){
                        inputdata[iget] = trainSwatch.nextGroup();
                    }
                    trainSequen = MyRand.randSequence(ig,ig);                    
                    //random order in sub group.
                    for(int trainLoadIndex :trainSequen){                               
                        output=trainnet.forwardCalculate(inputdata[trainLoadIndex]);
                        trainnet.backwardCalculate(inputdata[trainLoadIndex],output,istep,imomentum);
                        mse += Evaluate.averageSquareError(inputdata[trainLoadIndex],output); 
                        swatchCNT ++;
                        //log
                        if( swatchCNT==8){
                            stateslog.print("TrainCNT:,"+trainCNT+",mse:,"+mse+",In:,");
                            for(int i = 0 ; i< netvector[0];i++){
                                stateslog.print(inputdata[trainLoadIndex][i]+",");
                            }
                            stateslog.print("Out:,");
                            for(int i = 0 ; i< netvector[0];i++){
                                stateslog.print(output[i]+",");
                            }
                            stateslog.println("");
                        }//log
                    }//subgroup complete.
                }// complete one turn.
                trainSwatch.resetIterator();                
                mse = mse / swatchCNT;
                trainCNT ++;
                showB.redraw(mse ,"CNT="+trainCNT,"Aim="+goal+"/step="+istep);
            }while (mse > goal);
            stateslog.close();
            saveResult(compressNet.network.neuralNet_A , mse ,trainCNT);
            showB.redraw(mse ,"CNT="+trainCNT,"Training done!");
            bd.setVisible(false);
        }catch(IOException ex){
            showB.redraw(mse ,"CNT="+trainCNT,"No log file !!!");
        }        
        
        
    }
    
    /**
     * stop training.
     */
    public void stop(){
        goal =100;
    }
    
    private void saveResult(Neuron[][] net, double mseLast, int cnt){
        PrintWriter wout;        
        try {
            //prepare the record file.
            String filename ="NeuralNet_"+netinf+"_layers";
            File outfile = new File(patch,filename + ".xml");
            int fcnt= 0;
            while(outfile.exists()){
                filename = "NeuralNet_"+netinf+"_layers_" + fcnt;
                fcnt ++;
                outfile = new File(patch,filename +".xml");
            }
            
            wout = new PrintWriter(new BufferedWriter(new FileWriter(outfile)), true);
            //XML tag names may not contain spaces, so CamelCase tags are used.
            wout.println("<NeuralNetwork>");
            wout.println("<TrainedTurns>"+ cnt+"</TrainedTurns>");
            wout.println("<LastMeanSquareError>"+ mseLast+"</LastMeanSquareError>");
            wout.println("<InitialStep>"+ STEP+"</InitialStep>");
            wout.println("<StepAdjustExponent>"+ ADJUST+"</StepAdjustExponent>");
            wout.println("<Momentum>"+ MOMENTUM+"</Momentum>");
            wout.println("<Title>weight of neural network</Title>");
            for (int lay =0;lay<net.length;lay++){
                wout.println("<lay>" +lay);
                for(int cur =0; cur<net[lay].length; cur++){
                    wout.print("<Neuron>"+cur +"<weight>" );
                    for(int iw =0 ; iw<branch +1;iw++){
                        wout.print(net[lay][cur].dendrites[iw].weight+",");
                    }
                    wout.print("</weight></Neuron>\n");
                }
                wout.println("</lay>");
            }
            wout.println("</NeuralNetwork>");
            wout.close();
            wout.close();
            } catch (IOException ex) {
            ex.printStackTrace();
            System.exit(-1);
        }
    }       
}
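One detail worth isolating from `trainNet` above: every 512 epochs the learning step and momentum are rescaled by the factor 2^(ADJUST · trainCNT), giving a slow exponential decay. A standalone sketch of that schedule, with the constants copied from the class (the helper name is illustrative):

```java
// Sketch of the step/momentum decay schedule used in trainNet():
// every 512 epochs, the initial STEP and MOMENTUM are rescaled by
// 2^(ADJUST * epoch), where ADJUST is a small negative exponent.
public class DecaySketch {
    static final float STEP = 0.008f, MOMENTUM = 0.25f;
    static final double ADJUST = -0.000008;

    // effective step size after a given number of training epochs
    static float stepAt(int epoch) {
        return (float) (STEP * Math.pow(2, ADJUST * epoch));
    }

    public static void main(String[] args) {
        System.out.println(stepAt(0));      // 0.008 at the start
        System.out.println(stepAt(100000)); // ~0.0046 after 100k epochs
    }
}
```

Because ADJUST is only -8e-6, the step halves roughly every 125,000 epochs; the schedule cools the training very gradually rather than stepping the rate down sharply.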

