📄 neuralfileencoder.java
字号:
package Compressor;
/*
* NeuralFileEncoder.java
*
* Created on 2007 (month and day lost to an encoding error), 3:39 PM
*
*/
import myUtil.*;
//import Compressor.FileAdapter;
import neuralNetwork.*;
import java.util.*;
import java.io.*;
import javax.swing.*;
/**
 * Compresses a file by running fixed-size groups of its data through a
 * back-propagation network and writing the network's encoded output.
 * <p>
 * When no pre-trained network file is supplied, the network is first trained on
 * the source data itself until the per-turn mean squared error is no greater than
 * the goal (or {@link #stop()} is called). The trained weights are then saved in
 * the source file's directory.
 * @author yuhui_bear
 */
public class NeuralFileEncoder implements Runnable {
    /** Diagram used to plot the training MSE. */
    private Diagram showB;
    /** Number of dendrites per neuron. */
    private int branch = 0;
    /** Network topology as reported by the network (nettop). */
    private int[] net;
    /** Network file that was loaded; null means the network must be trained first. */
    private File networkFile;
    private BackPropagation compressNet;
    /** Adapter over the source file: yields input groups and receives encoded output. */
    private BufferedFileAdapter data = null;
    private JProgressBar progressBar;
    private JButton startBt;
    /** Dialog shown while training is in progress. */
    private JDialog bd;
    // Directory of the source file; null means the current directory
    // (new File(null, child) resolves relative to the working directory).
    private String patch;
    // Topology rendered as e.g. "8-4-8"; embedded in output file names.
    private String netinf;
    // Target MSE. volatile: stop() writes it from another thread while trainNet() reads it.
    private volatile double goal;
    private static final float STEP = 0.008F, MOMENTUM = 0.25F;
    private static final double ADJUST = -0.000008D;

    /**
     * Creates an encoder around a freshly built network; intended for testing only.
     * <p>
     * NOTE(review): this constructor ignores {@code source} and leaves {@code data},
     * {@code patch}, {@code netinf}, {@code startBt} and {@code bd} unset. Calling
     * {@link #run()} on such an instance therefore fails with a NullPointerException.
     * Because {@code networkFile} is non-null, {@code run()} would also skip training.
     * Nothing in this class writes weights to "weight.xml".
     * @param sb display panel for the trainer.
     * @param jb progress bar.
     * @param source file to compress (currently unused).
     * @param br number of dendrites per neuron.
     * @param top network topology passed to {@code BackPropagation}.
     */
    public NeuralFileEncoder(Diagram sb, JProgressBar jb, File source, int br, int[] top) {
        progressBar = jb;
        showB = sb;
        branch = br;
        net = top;
        networkFile = new File("weight.xml");
        compressNet = new BackPropagation(net, branch);
    }

    /**
     * Creates an encoder for {@code source}, loading or preparing the network.
     *
     * @param md dialog made visible here when {@code neuralnet} is null, and hidden
     *           again once training completes.
     * @param sb diagram used to plot the training MSE.
     * @param jb progress bar updated while encoding.
     * @param bt button that is relabelled "Job Done" and re-enabled when encoding finishes.
     * @param source file to compress.
     * @param neuralnet previously saved network file, or null to train a new
     *                  8-4-8 network with 4 dendrites per neuron on the source data.
     * @param precise target MSE at which training stops.
     */
    public NeuralFileEncoder(JDialog md, Diagram sb, JProgressBar jb, JButton bt,
                             File source, File neuralnet, double precise) {
        goal = precise;
        bd = md;
        progressBar = jb;
        startBt = bt;
        showB = sb;
        networkFile = neuralnet;
        if (networkFile == null) {
            compressNet = new BackPropagation(new int[]{8, 4, 8}, 4);
            bd.setVisible(true);
        } else {
            compressNet = new BackPropagation(networkFile);
        }
        net = compressNet.network.getNetInformation().nettop;
        branch = compressNet.network.getNetInformation().branch;
        // getParent() is null for a bare file name, which File(String, String) treats
        // as "current directory". An empty-string parent would instead resolve
        // against the filesystem root.
        patch = source.getParent();
        // Open the source file and prepare the output file.
        try {
            data = new BufferedFileAdapter(source, net[0], true);
        } catch (IOException ex) {
            ex.printStackTrace();
            System.exit(-1);
        }
        netinf = describeTopology(net);
    }

    /**
     * Encodes the source file, training the network first if none was loaded.
     * <p>
     * Each group read from the source is run forward through the network, and the
     * network's encoded representation is written back through the file adapter.
     * Terminates the JVM if the adapter reports a negative write status.
     */
    public void run() {
        if (networkFile == null) {
            trainNet(compressNet, data);
        }
        data.resetIterator();
        for (double[] group = data.nextGroup(); group != null; group = data.nextGroup()) {
            // The return value is unused, but the call is kept.
            // NOTE(review): encode() presumably reads the state this call leaves behind.
            compressNet.forwardCalculate(group);
            int err = data.writeGroup(compressNet.encode());
            if (err == 2) {
                // NOTE(review): status 2 ends encoding early; its meaning is defined by BufferedFileAdapter.
                break;
            }
            if (err < 0) {
                System.err.print("write back error. err = " + err);
                System.exit(err);
            }
            progressBar.setValue(data.getProgress());
        }
        // Close the stream.
        data.writeGroup(null);
        progressBar.setValue(100);
        startBt.setText("Job Done");
        startBt.setEnabled(true);
    }

    /**
     * Trains {@code trainnet} on {@code trainSwatch} until a turn's mean squared
     * error is no greater than {@link #goal}.
     * <p>
     * Each turn walks the sub-groups reported by the adapter, trains on the
     * members of each sub-group in random order, and then rewinds the adapter.
     * Every 512 turns, the step and momentum are rescaled by
     * {@code 2^(ADJUST * turns)}. Progress is logged to a CSV file in the source
     * directory and plotted on {@link #showB}. The final weights are saved with
     * {@link #saveResult}. If the log file cannot be created, training is skipped.
     */
    private void trainNet(BackPropagation trainnet, BufferedFileAdapter trainSwatch) {
        int[] netvector = trainnet.network.getNetInformation().nettop;
        int[] group = trainSwatch.subGroupSize();
        double mse = 0;
        int trainCNT = 0, modify = 0;
        float istep = STEP, imomentum = MOMENTUM;
        PrintWriter stateslog = null;
        try {
            File logfile = uniqueFile("log_In_" + netinf, ".csv");
            stateslog = new PrintWriter(new BufferedWriter(new FileWriter(logfile)), true);
            do {
                if (modify == 512) {
                    double d = Math.pow(2, ADJUST * trainCNT);
                    istep = (float) (STEP * d);
                    imomentum = (float) (MOMENTUM * d);
                    modify = 0;
                }
                modify++;
                // Start of a turn: reset the sample counter AND the error sum. Without
                // this reset, the previous turn's average leaks into this turn's MSE.
                long swatchCNT = 0;
                mse = 0;
                for (int ig : group) {
                    // Inner arrays are supplied by nextGroup(); no need to preallocate them.
                    double[][] inputdata = new double[ig][];
                    for (int iget = 0; iget < ig; iget++) {
                        inputdata[iget] = trainSwatch.nextGroup();
                    }
                    // Visit the sub-group's samples in random order.
                    for (int trainLoadIndex : MyRand.randSequence(ig, ig)) {
                        double[] sample = inputdata[trainLoadIndex];
                        double[] output = trainnet.forwardCalculate(sample);
                        trainnet.backwardCalculate(sample, output, istep, imomentum);
                        mse += Evaluate.averageSquareError(sample, output);
                        swatchCNT++;
                        if (swatchCNT == 8) {
                            logSample(stateslog, trainCNT, mse, sample, output, netvector[0]);
                        }
                    }
                }
                trainSwatch.resetIterator();
                mse = mse / swatchCNT;
                trainCNT++;
                showB.redraw(mse, "CNT=" + trainCNT, "Aim=" + goal + "/step=" + istep);
            } while (mse > goal);
            // Close before saving. The finally-block close is then a harmless no-op.
            stateslog.close();
            saveResult(trainnet.network.neuralNet_A, mse, trainCNT);
            showB.redraw(mse, "CNT=" + trainCNT, "Trainning done !");
            bd.setVisible(false);
        } catch (IOException ex) {
            showB.redraw(mse, "CNT=" + trainCNT, "No log file !!!");
        } finally {
            if (stateslog != null) {
                stateslog.close();
            }
        }
    }

    /**
     * Writes one CSV row: the turn count, the error sum so far in this turn, and
     * the first {@code width} input values and output values.
     */
    private static void logSample(PrintWriter log, int trainCNT, double mse,
                                  double[] in, double[] out, int width) {
        log.print("TrainCNT:," + trainCNT + ",mse:," + mse + ",In:,");
        for (int i = 0; i < width; i++) {
            log.print(in[i] + ",");
        }
        log.print("Out:,");
        for (int i = 0; i < width; i++) {
            log.print(out[i] + ",");
        }
        log.println("");
    }

    /**
     * Requests that training stop after the current turn by raising the target MSE
     * to 100. Safe to call from another thread, because {@link #goal} is volatile.
     * NOTE(review): training only stops if that turn's MSE is at most 100.
     */
    public void stop() {
        goal = 100;
    }

    /**
     * Returns a file in the source directory named {@code base + ext}. If that
     * exists, returns {@code base + "_" + n + ext} for the smallest {@code n >= 0}
     * that does not exist yet.
     */
    private File uniqueFile(String base, String ext) {
        File f = new File(patch, base + ext);
        for (int n = 0; f.exists(); n++) {
            f = new File(patch, base + "_" + n + ext);
        }
        return f;
    }

    /** Joins the topology with '-', e.g. {@code {8,4,8}} becomes {@code "8-4-8"}. */
    private static String describeTopology(int[] top) {
        StringBuilder sb = new StringBuilder();
        sb.append(top[0]);
        for (int i = 1; i < top.length; i++) {
            sb.append('-').append(top[i]);
        }
        return sb.toString();
    }

    /**
     * Writes the trained weights plus training metadata to a new file named
     * {@code "NerualNet_" + topology + "_layers[_n].xml"} in the source directory.
     * Terminates the JVM if the file cannot be created.
     * NOTE(review): the tag names contain spaces, so the output is not well-formed
     * XML. It is kept byte-for-byte because BackPropagation(File) presumably parses
     * this exact format.
     * @param layers neurons indexed by [layer][neuron].
     * @param mseLast MSE of the final training turn.
     * @param cnt number of training turns performed.
     */
    private void saveResult(Neuron[][] layers, double mseLast, int cnt) {
        PrintWriter wout = null;
        try {
            File outfile = uniqueFile("NerualNet_" + netinf + "_layers", ".xml");
            wout = new PrintWriter(new BufferedWriter(new FileWriter(outfile)), true);
            wout.println("<Neural Network>");
            wout.println("<Trained turns>" + cnt + "</Trained turns>");
            wout.println("<Last Mean Squear Error>" + mseLast + "</Last Mean Squear Error>");
            wout.println("<Iniate Step>" + STEP + "</Iniate Step>");
            wout.println("<Step Adujst Exponent>" + ADJUST + "</Step Adujst Exponent>");
            wout.println("<Momentum>" + MOMENTUM + "</Momentum>");
            wout.println("<Title>weight of neural network</Title>");
            for (int lay = 0; lay < layers.length; lay++) {
                wout.println("<lay>" + lay);
                for (int cur = 0; cur < layers[lay].length; cur++) {
                    wout.print("<Neuron>" + cur + "<weight>");
                    // branch + 1 weights per neuron (NOTE(review): presumably the extra one is a bias).
                    for (int iw = 0; iw < branch + 1; iw++) {
                        wout.print(layers[lay][cur].dendrites[iw].weight + ",");
                    }
                    wout.print("</weight></Neuron>\n");
                }
                wout.println("</lay>");
            }
            wout.println("</Neural Network>");
        } catch (IOException ex) {
            ex.printStackTrace();
            System.exit(-1);
        } finally {
            if (wout != null) {
                wout.close();
            }
        }
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -