/*
* NeuralTrainer.java
*
* Created on 2007-11-29, 11:46 AM
*
* To change this template, choose Tools | Template Manager
* and open the template in the editor.
*/
package Compressor;
import neuralNetwork.*;
import myUtil.*;
import java.io.*;
import java.util.*;
import javax.swing.*;
import java.awt.*;
/**
 * Trainer that runs back-propagation over the full swatch space until the
 * batch MSE falls below the goal. The number of input ports is defined by
 * the size of the first layer ({@code top[0]}).
 * <p>Thread use: {@link #run()} executes on a worker thread while
 * {@link #stopTrain()}, {@link #setAim(double)} and {@link #setStep(float)}
 * may be called from the UI thread — hence the volatile fields below.
 * @author yuhui_bear
 */
public class NeuralTrainer implements Runnable{
    private BackPropagation network;
    private int[] netTop;                  // neurons per layer
    private int branch;                    // dendrites per neuron
    // volatile: setStep() mutates STEP from another thread while run()
    // reads it inside the training loop.
    private volatile float STEP = 0.01F;
    private float MOMENTUM = 0.3F;
    // decay exponent: step/momentum shrink as 2^(ADJUST * trainCNT)
    private final double ADJUST = -0.000001D;
    private int TOTAL_VALUES = 16;         // distinct values per input port
    private int subGROUP;                  // samples per sub-group
    private int BATCH_SWATCH_SIZE;         // sub-groups per full batch
    // volatile: stopTrain()/setAim() change the goal from the UI thread to
    // terminate or retarget the loop in run(); without volatile the training
    // thread might never observe the change and never stop.
    private volatile double goal = 0.00001;
    private String[] SwatchInf;            // printable standard input values, for the report
    private String netinf;                 // "[branch]-n0-n1-..." topology tag used in file names
    private JProgressBar turnProgressBar;
    private Diagram graph;
    private int recordInterval;            // log every recordInterval-th turn
    /**
     * Creates a new instance of NeuralTrainer.
     * @param top numbers of neurons in every layer; top[0] also fixes the input width.
     * @param branches dendrites per neuron.
     * @param aim target MSE of training; pass 0 to keep the default (0.00001).
     * @param numbersOfValues distinct standard values per input port.
     * @param ri interval (in turns) between log records.
     * @param show diagram used to plot the MSE curve.
     * @param stpb progress bar, updated once per sub-group.
     */
    public NeuralTrainer(int[] top,int branches ,double aim ,int numbersOfValues,int ri,Diagram show,JProgressBar stpb) {
        if(aim != 0){
            goal = aim;
        }
        recordInterval = ri;
        TOTAL_VALUES = numbersOfValues;
        // Split the full swatch space TOTAL_VALUES^top[0] into
        // BATCH_SWATCH_SIZE groups of subGROUP samples each.
        // Note top[0]/2 is integer division, so odd input widths round down.
        subGROUP = (int)Math.pow(TOTAL_VALUES, top[0] / 2);
        BATCH_SWATCH_SIZE = (int)(Math.pow(TOTAL_VALUES, top[0]) / subGROUP);
        turnProgressBar = stpb;
        netTop = top;
        network = new BackPropagation(top, branches);
        branch = branches;
        graph = show;
        // Build the topology tag, e.g. "[3]-4-2-4".
        StringBuilder inf = new StringBuilder("[").append(branch).append("]");
        for(int i = 0; i < netTop.length; i++){
            inf.append('-').append(netTop[i]);
        }
        netinf = inf.toString();
    }
    /**
     * Training loop (inherited from Runnable). Runs whole-batch turns until the
     * batch MSE drops below {@code goal} (or stopTrain() raises the goal),
     * logging progress to a CSV file and finally saving the trained weights.
     */
    public void run(){
        Swatch trainLoader = new Swatch(TOTAL_VALUES, subGROUP, netTop[0]);
        SwatchInf = trainLoader.printStandardValues();
        double[] output = new double[netTop[netTop.length / 2]];
        double[] mse = new double[subGROUP];
        double[][] inputdata = new double[subGROUP][netTop[0]];
        double mseOfBatch = 0, aimMSE = goal;
        float istep = STEP, imomentum = MOMENTUM;
        int trainCNT = 0, modifyCNT = 0;
        int[] trainSequence;
        try{
            // Pick a log file name that does not collide with earlier runs.
            String filename = "log_" + netinf;
            File logfile = new File(filename + ".csv");
            int fcnt = 0;
            while(logfile.exists()){
                filename = "log_" + netinf + "_" + fcnt;
                fcnt++;
                logfile = new File(filename + ".csv");
            }
            PrintWriter outm = new PrintWriter(new BufferedWriter(new FileWriter(logfile)), true);
            MyRand randSqeuence = new MyRand(subGROUP);
            MyRand randGroup = new MyRand(BATCH_SWATCH_SIZE);
            // Pre-load every group of the batch once; the groups are then
            // replayed in random order on every turn.
            double[][][] databuf = new double[BATCH_SWATCH_SIZE][][];
            for (int curG = 0; curG < BATCH_SWATCH_SIZE; curG++){
                databuf[curG] = trainLoader.nextGroup();
            }
            do{
                // One full turn over every group, in random order.
                mseOfBatch = 0;
                int gCNT = 0;
                for (int curG : randGroup.randSequence(BATCH_SWATCH_SIZE)){
                    inputdata = databuf[curG];
                    trainSequence = randSqeuence.randSequence(subGROUP);
                    gCNT++;
                    // Train on every sample of the sub-group, in random order.
                    // The network is an auto-encoder: input doubles as target.
                    for (int t : trainSequence){
                        output = network.forwardCalculate(inputdata[t]);
                        network.backwardCalculate(inputdata[t], output, istep, imomentum);
                        mse[t] = Evaluate.averageSquareError(inputdata[t], output);
                    }
                    turnProgressBar.setValue((int)(100 * gCNT / BATCH_SWATCH_SIZE));
                    mseOfBatch += Evaluate.mean(mse);
                }
                // NOTE(review): logged before the division below, so "mrse" is
                // the sum of per-group means, not the batch mean — kept as-is
                // to preserve the log format.
                if(trainCNT % recordInterval == 0){
                    outm.print("TrainCNT:,"+trainCNT+",mrse:,"+mseOfBatch+",In/Out,");
                    // Sample row 6 of the last trained group (assumes subGROUP > 6).
                    for(int i = 0; i < netTop[0]; i++){
                        outm.print(inputdata[6][i]+","+output[i]+",");
                    }
                    outm.println("");
                }
                trainCNT++;
                modifyCNT++;
                mseOfBatch = mseOfBatch / BATCH_SWATCH_SIZE;
                // Every 128 turns decay step and momentum by 2^(ADJUST * trainCNT).
                if(modifyCNT > 128){
                    double d = Math.pow(2, ADJUST * trainCNT);
                    istep = (float)(STEP * d);
                    imomentum = (float)(MOMENTUM * d);
                    modifyCNT = 0;
                }
                // Fix: show the adjusted momentum (imomentum) to match the
                // adjusted step already shown; previously the base MOMENTUM
                // was displayed.
                graph.redraw(mseOfBatch, "Aim = "+goal+"/CNT="+trainCNT, "Step="+istep+"/MOMENTUM ="+imomentum);
            }while (mseOfBatch > goal);
            turnProgressBar.setValue(100);
            outm.close();
            saveResult(network.network.neuralNet_A, aimMSE, mseOfBatch, trainCNT);
            graph.redraw(mseOfBatch, "CNT="+trainCNT, "Trainning done !");
        }catch(IOException ex){
            graph.redraw(mseOfBatch, "CNT="+trainCNT, "No log file !!!");
        }
    }
    /**
     * Stops training by raising the goal above any reachable MSE; the loop in
     * run() then exits at the end of the current turn and saves the result.
     */
    public void stopTrain(){
        goal = 100;
    }
    /**
     * Sets a new target MSE (takes effect at the next turn boundary).
     * @param am the new aim.
     */
    public void setAim(double am){
        goal = am;
    }
    /**
     * Sets a new base learning step (picked up at the next decay adjustment).
     * @param st the new step.
     */
    public void setStep(float st){
        STEP = st;
    }
    /**
     * Writes the trained weights and a training summary to an XML-like report
     * file (tag names contain spaces, so it is not strictly valid XML).
     * @param net layers of trained neurons.
     * @param aimMSE target MSE the run started with.
     * @param mseLast last batch MSE reached.
     * @param cnt number of completed turns.
     */
    private void saveResult(Neuron[][] net,double aimMSE, double mseLast, int cnt){
        PrintWriter wout;
        try {
            // Pick a report file name that does not collide with earlier runs.
            // (File-name spelling "NerualNet" kept for compatibility with
            // existing consumers of these files.)
            String filename = "NerualNet_" + netinf + "_layers";
            File outfile = new File(filename + ".xml");
            int fcnt = 0;
            while(outfile.exists()){
                filename = "NerualNet_" + netinf + "_layers_" + fcnt;
                fcnt++;
                outfile = new File(filename + ".xml");
            }
            wout = new PrintWriter(new BufferedWriter(new FileWriter(outfile)), true);
            wout.println("<Neural Network>");
            wout.println("<Aim MSE>"+aimMSE+"</Aim MSE>");
            wout.println("<Last Mean Squear Error>"+ mseLast+"</Last Mean Squear Error>");
            // Fix: report the real swatch space TOTAL_VALUES^netTop[0];
            // the exponent was previously hard-coded to 4.
            wout.println("<Swatch Space>" +Math.pow(TOTAL_VALUES ,netTop[0]) + "</Swatch Space>");
            wout.println("<Trained turns>"+ cnt+"</Trained turns>");
            wout.println("<Training Standard Values>");
            for (int i = 0; i < SwatchInf.length; i++){
                wout.println("<["+i+"]>"+SwatchInf[i]+"</["+i+"]>");
            }
            wout.println("</Training Standard Values>");
            wout.println("<Title>weight of neural network</Title>");
            // One <lay> element per layer, one <Neuron> per neuron; weights
            // comma-separated (branch inputs + 1 bias weight each).
            for (int lay = 0; lay < net.length; lay++){
                wout.println("<lay>" +lay);
                for(int cur = 0; cur < net[lay].length; cur++){
                    wout.print("<Neuron>"+cur +"<weight>" );
                    for(int iw = 0; iw < branch + 1; iw++){
                        wout.print(net[lay][cur].dendrites[iw].weight+",");
                    }
                    wout.print("</weight></Neuron>\n");
                }
                wout.println("</lay>");
            }
            wout.println("</Neural Network>");
            wout.close();
        } catch (IOException ex) {
            ex.printStackTrace();
            System.exit(-1);
        }
    }
}
/**
 * Generates training swatches (input vectors) for a network whose input layer
 * has {@code branch} ports, each port taking one of {@code group} evenly
 * spaced standard values. Successive calls to {@link #nextGroup()} walk the
 * full Cartesian product of standard values in batches of {@code swatchsize},
 * resuming where the previous batch stopped.
 * <p>Fixes vs. the original implementation: the 8-port enumeration loops all
 * incremented {@code b4} (copy-paste bug), so ports 5-8 were never varied;
 * resuming a batch re-emitted the last vector of the previous batch; and only
 * widths 4 and 8 were supported ({@code null} otherwise). The general odometer
 * below enumerates any width in the same order (last port varies fastest).
 */
class Swatch {
    private int numbersOfValue;      // distinct standard values per port
    private int swatchSize;          // vectors returned per nextGroup() call
    private int dendrities;          // ports (vector width)
    private int[] Indexb;            // odometer: current value index per port
    private double[] standardValue;  // evenly spaced values starting at 0.1258
    private double DIF;              // spacing between consecutive standard values
    /**
     * @param group total number of distinct values per port.
     * @param swatchsize size of one swatch batch.
     * @param branch dendrites (ports) of the input layer.
     */
    public Swatch (int group ,int swatchsize,int branch){
        numbersOfValue = group;
        swatchSize = swatchsize;
        dendrities = branch;
        standardValue = new double[numbersOfValue];
        DIF = 0.8D / (double)numbersOfValue;
        Indexb = new int[dendrities];
        Arrays.fill(Indexb, 0);
        // Values are offset by 0.1258 so they sit strictly inside (0, 1).
        for (int ig = 0; ig < numbersOfValue; ig++){
            standardValue[ig] = ig * DIF + 0.1258;
        }
    }
    /**
     * @return the standard values rendered as strings, for the training report.
     */
    public String[] printStandardValues(){
        String[] outs = new String[standardValue.length];
        for (int i = 0; i < standardValue.length; i++){
            outs[i] = String.valueOf(standardValue[i]);
        }
        return outs;
    }
    /**
     * Returns the next {@code swatchSize} input vectors of the Cartesian
     * product enumeration, resuming exactly where the previous call stopped
     * and wrapping to the first combination once the space is exhausted.
     * @return a {@code swatchSize x dendrities} array of standard values.
     */
    public double[][] nextGroup(){
        double[][] outdata = new double[swatchSize][dendrities];
        for (int s = 0; s < swatchSize; s++){
            // Emit the combination currently selected by the odometer.
            for (int d = 0; d < dendrities; d++){
                outdata[s][d] = standardValue[Indexb[d]];
            }
            // Advance the odometer: last port varies fastest, carries
            // propagate leftward, and the counter wraps to all zeros after
            // the final combination.
            for (int d = dendrities - 1; d >= 0; d--){
                Indexb[d]++;
                if (Indexb[d] < numbersOfValue){
                    break;
                }
                Indexb[d] = 0;
            }
        }
        return outdata;
    }
}