// mlp.java
/**
 * BackProp (here renamed MLP) - a backpropagation neural network class, 4/98 by
 * <a href="mailto:tiscione@hhs.net">Jason Tiscione</a>.
 * <br>Copyright (c) 1998. All Rights Reserved. You have a non-exclusive,
 * royalty-free LICENSE to use, modify and redistribute this software
 * in source and binary code form, provided that (i) this dubious legal
 * junk at the top appears on all copies of the software; and
 * (ii) you don't use the software in a manner which is "disparaging"
 * to me, like making fun of the way I program.
 * This software is provided "AS IS," with no warranty of any kind.
 * NO MATTER WHAT, you can't sue me for anything. This software is
 * neither designed nor intended to be used for air traffic control
 * or for maintaining a nuclear facility, so don't get any crazy ideas.
 *
 * <br>SUMMARY OF CLASS MEMBER VARIABLES:
 * <br>int rawDim = size (number of elements) of an input vector accepted by the network
 * <br>int nS1, nS2, nS3 = number of neurons in the input, hidden, and output layers
 * <br>double wmx1[][], wmx2[][], wmx3[][] = weight matrices of the input, hidden, and output layers
 * <br>double b1[][], b2[][], b3[][] = bias values of the input, hidden, and output layers
 * <br>int epoch = number of training presentations so far
 * <br>double learnRate = learning rate [default 0.01]
 * <br>double incLR = learning rate increase [default 1.05]
 * <br>double decLR = learning rate decrease [default 0.7]
 * <br>double moment = momentum coefficient [default 0.9]
 * <br>double error = maximum error ratio [default 1.04]
 * <br>(The following variables are not used until trainInit() is called:)
 * <br>int Q = number of training inputs
 * <br>double SSE = sum squared error
 * <br>double MC = current momentum
 * <br>double out1[][], out2[][], out3[][] = outputs from the input, hidden, and output layers
 * <br>double d1[][], d2[][], d3[][] = delta log arrays, used to calculate dWn and dbn
 * <br>double dW1[][], dW2[][], dW3[][] = weight matrix differentials
 * <br>double db1[][], db2[][], db3[][] = bias matrix differentials
 * <br>double errors[][] = error matrix
 * <br>double inputs[][] = training inputs matrix
 * <br>double targets[][] = training targets matrix (usually an n-by-n identity matrix)<br>
 * @author <a href="mailto:tiscione@hhs.net">Jason Tiscione</a>
 * @version 1.0. No future revisions planned.
 */
public class MLP extends Object {

    private int rawDim, nS1, nS2, nS3, epoch;
    private double[][] wmx1, wmx2, wmx3, b1, b2, b3;
    private double learnRate, incLR, decLR, moment, error;
    int Q;               // number of training inputs
    private double SSE;  // sum squared error
    private double MC;   // current momentum
    private double[][] out1, out2, out3, errors;
    private double[][] d1, d2, d3, dW1, dW2, dW3, db1, db2, db3;
    private double[][] inputs, targets;

    /**
     * Creates a new MLP object and initializes its weight and bias arrays.
     * @param rawDim Size (number of elements) of an input vector accepted by the network
     * @param nS1 Number of neurons in the input layer
     * @param nS2 Number of neurons in the hidden layer
     * @param nS3 Number of neurons in the output layer
     * @param learnRate Learning rate.
     * @param incLR Learning rate increase.
     * @param decLR Learning rate decrease.
     * @param moment Momentum coefficient.
     * @param error Maximum error ratio.
     */
    public MLP(int rawDim, int nS1, int nS2, int nS3,
               double learnRate, double incLR, double decLR,
               double moment, double error) {
        this.rawDim = rawDim;
        this.nS1 = nS1;
        this.nS2 = nS2;
        this.nS3 = nS3;
        this.learnRate = learnRate;
        this.incLR = incLR;
        this.decLR = decLR;
        this.moment = moment;
        this.error = error;
        // Call the Nguyen-Widrow random initializer for log-sigmoid neurons.
        InitWB(nS1, rawDim, 1);
        InitWB(nS2, nS1, 2);
        InitWB(nS3, nS2, 3);
        epoch = 0;
    }
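    // NOTE: The body of InitWB() is missing from this truncated listing.
    // The sketch below is a plausible reconstruction of a Nguyen-Widrow
    // initializer for log-sigmoid neurons, matching the way the constructor
    // calls it: (number of neurons, number of inputs, layer index). It is an
    // assumption, not the author's original code.
    private void InitWB(int S, int R, int layer) {
        double beta = 0.7 * Math.pow(S, 1.0 / R);     // Nguyen-Widrow scale factor
        double[][] w = new double[S][R];
        double[][] b = new double[S][1];
        for (int i = 0; i < S; i++) {
            double norm = 0;
            for (int j = 0; j < R; j++) {
                w[i][j] = 2 * Math.random() - 1;      // uniform in [-1, 1]
                norm += w[i][j] * w[i][j];
            }
            norm = Math.sqrt(norm);
            for (int j = 0; j < R; j++) {
                w[i][j] *= beta / norm;               // scale each row to magnitude beta
            }
            b[i][0] = beta * (2 * Math.random() - 1); // bias uniform in [-beta, beta]
        }
        switch (layer) {                              // store in the requested layer
            case 1:  wmx1 = w; b1 = b; break;
            case 2:  wmx2 = w; b2 = b; break;
            default: wmx3 = w; b3 = b; break;
        }
    }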
    /**
     * This constructor simply takes as parameters the number of elements in
     * each input vector and the number of neurons in the input, hidden, and
     * output layers. It invokes the main constructor with default values for
     * its remaining parameters.
     * @param rawDim Size (number of elements) of an input vector accepted by the network
     * @param nS1 Number of neurons in the input layer
     * @param nS2 Number of neurons in the hidden layer
     * @param nS3 Number of neurons in the output layer
     */
    public MLP(int rawDim, int nS1, int nS2, int nS3) {
        this(rawDim, nS1, nS2, nS3, 0.01, 1.05, 0.7, 0.9, 1.04);
    }

    /**
     * Tests the network on a single given input. This is merely a convenience
     * method; this version of testNet() - there are two - takes a
     * one-dimensional array (containing a single input vector) as its sole
     * argument, converts it from double[] to double[][], and passes it to the
     * other version of testNet() (which should preferably be used instead of
     * this one). The argument should contain rawDim elements. As modified
     * (see the "changed by wwwang" note), it returns a double[] with nS3
     * elements rather than a one-column double[][].
     * @param inputVector A single input vector, with rawDim elements.
     * @return Vector of network outputs.
     */
    // public double[][] testNet( double[] inputVector ) {
    public double[] testNet(double[] inputVector) {   // changed by wwwang
        double[][] inputarray = new double[rawDim][1];
        for (int i = 0; i < rawDim; i++) {
            inputarray[i][0] = inputVector[i];
        }
        double[][] outputMatrix = testNet(inputarray);
        double[] outputVector = new double[nS3];
        for (int i = 0; i < nS3; i++) {
            outputVector[i] = outputMatrix[i][0];
        }
        return outputVector;
    }

    /**
     * Tests the network on a given input or inputs. The argument is an array
     * of inputs. It must have rawDim rows; the number of columns is arbitrary
     * but must be at least one. This method returns an output array with nS3
     * rows and the same number of columns as the argument.
     * <p>Inputs to the network are constrained to the interval [0,1]; to use
     * this class with inputs that span a different range, a scaling algorithm
     * would need to be applied to the inputs and targets before calling this
     * method or trainInit().
     * <p>The trainInit() method (at least) should be called before invoking
     * testNet().
     * @param test_inputs A double[][] matrix of network inputs
     * @return A double[][] matrix of network outputs
     */
    public double[][] testNet(double[][] test_inputs) {
        double[][] test_ip, test_out1, test_out2, test_out3;
        test_ip = multiply(wmx1, test_inputs);
        test_out1 = logsig(test_ip, b1);
        test_ip = multiply(wmx2, test_out1);
        test_out2 = logsig(test_ip, b2);
        test_ip = multiply(wmx3, test_out2);
        test_out3 = logsig(test_ip, b3);
        return test_out3;
    }
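    // NOTE: The helper methods multiply(), logsig(), and getSize() are also
    // missing from this truncated listing. The sketches below are assumptions
    // consistent with their call sites: multiply() is an ordinary matrix
    // product, logsig() adds a per-row bias and applies the logistic function
    // 1/(1+e^-n), and getSize() returns {rows, columns}.
    private double[][] multiply(double[][] A, double[][] B) {
        int rows = A.length, inner = B.length, cols = B[0].length;
        double[][] C = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int k = 0; k < inner; k++) {
                for (int j = 0; j < cols; j++) {
                    C[i][j] += A[i][k] * B[k][j];   // standard matrix product
                }
            }
        }
        return C;
    }

    private double[][] logsig(double[][] n, double[][] bias) {
        double[][] out = new double[n.length][n[0].length];
        for (int i = 0; i < n.length; i++) {
            for (int j = 0; j < n[0].length; j++) {
                // The log-sigmoid squashes (input + bias) into the interval (0,1).
                out[i][j] = 1.0 / (1.0 + Math.exp(-(n[i][j] + bias[i][0])));
            }
        }
        return out;
    }

    private int[] getSize(double[][] m) {
        return new int[] { m.length, m[0].length };  // {rows, columns}
    }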
    /**
     * Establishes what the training inputs and targets are to be, and
     * prepares the network for subsequent calls to trainNet(). This method
     * need be called only once.
     * <p>The inputs[][] variable is set to the first argument, so it must
     * have rawDim rows and at least one column. The number of columns
     * establishes Q, the number of training inputs. The targets[][] variable
     * is set to the second argument, so it must have nS3 rows and Q columns.
     * If the two arguments do not have the same number of columns, or if they
     * do not have rawDim and nS3 rows respectively, the method returns false.
     * The individual input vectors and corresponding target vectors used to
     * train the network are thus the column vectors in the inputs[][] and
     * targets[][] arrays.
     * <p>All neuron activations in this network are constrained to the
     * interval [0,1] by a log-sigmoid function. Inputs to the network are
     * also constrained to this interval; to use this class with a set of
     * inputs that span a different range, a scaling algorithm would need to
     * be applied to the inputs and targets before calling this method or
     * testNet().
     * @param inputs A double[][] matrix of network inputs
     * @param targets A double[][] matrix of target network outputs
     * @return A boolean value indicating successful invocation.
     */
    public boolean trainInit(double[][] inputs, double[][] targets) {
        // First initialize inputs and targets.
        this.inputs = inputs;
        this.targets = targets;
        // Create arrays dW1,dW2,dW3 the same size as the wmx1,wmx2,wmx3 arrays.
        // Java initializes all elements to zero, so we don't need to.
        double[][] ip;  // used for inner products
        dW1 = new double[nS1][rawDim];
        dW2 = new double[nS2][nS1];
        dW3 = new double[nS3][nS2];
        db1 = new double[nS1][1];
        db2 = new double[nS2][1];
        db3 = new double[nS3][1];
        MC = 0;
        int[] dim_inputs = getSize(inputs);
        int[] dim_targets = getSize(targets);
        Q = dim_inputs[1];
        errors = new double[nS3][Q];
        if (dim_targets[1] != Q) {
            System.out.println("TrainNet: Inputs and targets are mismatched.");
            return false;
        }
        if (dim_inputs[0] != rawDim) {
            System.out.println("TrainNet: Inputs do not match net structure.");
            return false;
        }
        if (dim_targets[0] != nS3) {
            System.out.println("TrainNet: Targets do not match net structure.");
            return false;
        }
        // PRESENTATION PHASE
        ip = multiply(wmx1, inputs);
        out1 = logsig(ip, b1);
        ip = multiply(wmx2, out1);
        out2 = logsig(ip, b2);
        ip = multiply(wmx3, out2);
        out3 = logsig(ip, b3);
        // Compute the errors matrix and sum-squared error (SSE).
        SSE = 0;
        for (int i = 0; i < nS3; i++) {
            for (int j = 0; j < Q; j++) {
                errors[i][j] = targets[i][j] - out3[i][j];
                SSE += errors[i][j] * errors[i][j];
            }
        }
        // BACKPROPAGATION PHASE
        d3 = deltalog(out3, errors);
        d2 = deltalog(out2, d3, wmx3);
        d1 = deltalog(out1, d2, wmx2);
        return true;
    }
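    // NOTE: Both deltalog() overloads are missing from this truncated
    // listing; the sketches below are assumptions based on standard
    // backpropagation for log-sigmoid units. The output-layer form scales
    // the errors by the sigmoid derivative out*(1-out); the hidden-layer
    // form first backpropagates the next layer's deltas through the
    // transpose of its weight matrix.
    private double[][] deltalog(double[][] out, double[][] err) {
        double[][] d = new double[out.length][out[0].length];
        for (int i = 0; i < out.length; i++) {
            for (int j = 0; j < out[0].length; j++) {
                // derivative of the log-sigmoid is out*(1-out)
                d[i][j] = err[i][j] * out[i][j] * (1.0 - out[i][j]);
            }
        }
        return d;
    }

    private double[][] deltalog(double[][] out, double[][] dNext, double[][] wNext) {
        // Backpropagated error: e = wNext^T * dNext
        double[][] e = new double[out.length][dNext[0].length];
        for (int i = 0; i < out.length; i++) {
            for (int k = 0; k < dNext.length; k++) {
                for (int j = 0; j < dNext[0].length; j++) {
                    e[i][j] += wNext[k][i] * dNext[k][j];
                }
            }
        }
        return deltalog(out, e);   // then scale by the sigmoid derivative
    }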
    /**
     * Performs ONE iteration (i.e., one epoch) of the backpropagation
     * algorithm on the network. It updates the weight and bias arrays of the
     * network, and returns the sum-squared error of the network's output in
     * comparison to the target array.
     * @return The SSE (sum-squared error) of the network's output.
     */
    public double trainNet() {
        // Returns SSE after 1 epoch of training on Q input pattern vectors.
        double new_SSE;
        double[][] new_out1, new_out2, new_out3, new_errors;
        double[][] ip;    // used for inner products
        double[][] swap;  // used for swapping references
        new_errors = new double[nS3][Q];
        // Create arrays new_W1,new_W2,new_W3 and new_b1,new_b2,new_b3.
        double[][] new_W1 = new double[nS1][rawDim];
        double[][] new_W2 = new double[nS2][nS1];
        double[][] new_W3 = new double[nS3][nS2];
        double[][] new_b1 = new double[nS1][1];
        double[][] new_b2 = new double[nS2][1];
        double[][] new_b3 = new double[nS3][1];
        epoch++;
        // LEARNING PHASE
        // Calculate the dW and db arrays.
        // dW1:
        for (int i = 0; i < nS1; i++) {
            for (int j = 0; j < rawDim; j++) {
                dW1[i][j] *= MC;   // momentum term
                for (int k = 0; k < Q; k++) {
                    dW1[i][j] += learnRate * (1 - MC) * d1[i][k] * inputs[j][k];
                }
            }
        }
        // db1:
        for (int i = 0; i < nS1; i++) {
            db1[i][0] *= MC;       // momentum term
            for (int k = 0; k < Q; k++) {
                db1[i][0] += learnRate * (1 - MC) * d1[i][k];
            }
        }
        // dW2:
        for (int i = 0; i < nS2; i++) {
            for (int j = 0; j < nS1; j++) {
                dW2[i][j] *= MC;   // momentum term
                for (int k = 0; k < Q; k++) {
                    dW2[i][j] += learnRate * (1 - MC) * d2[i][k] * out1[j][k];
                }
            }
        }
        // db2:
        for (int i = 0; i < nS2; i++) {
            db2[i][0] *= MC;
            for (int k = 0; k < Q; k++) {
                db2[i][0] += learnRate * (1 - MC) * d2[i][k];
            }
        }
        // dW3:
        for (int i = 0; i < nS3; i++) {
            for (int j = 0; j < nS2; j++) {
                dW3[i][j] *= MC;   // momentum term
                for (int k = 0; k < Q; k++) {
                    dW3[i][j] += learnRate * (1 - MC) * d3[i][k] * out2[j][k];
                }
            }
        }
        // db3:
        for (int i = 0; i < nS3; i++) {
            db3[i][0] *= MC;
            for (int k = 0; k < Q; k++) {
                db3[i][0] += learnRate * (1 - MC) * d3[i][k];
            }
        }
        // Add the dW and db matrices to W and b to get new_W and new_b.
        MC = moment;
        for (int i = 0; i < nS1; i++) {
            new_b1[i][0] = b1[i][0] + db1[i][0];
            for (int j = 0; j < rawDim; j++) {
                new_W1[i][j] = wmx1[i][j] + dW1[i][j];
            }
        }
        for (int i = 0; i < nS2; i++) {
            new_b2[i][0] = b2[i][0] + db2[i][0];
            for (int j = 0; j < nS1; j++) {
                new_W2[i][j] = wmx2[i][j] + dW2[i][j];
            }
        }
        for (int i = 0; i < nS3; i++) {
            new_b3[i][0] = b3[i][0] + db3[i][0];
            for (int j = 0; j < nS2; j++) {
                new_W3[i][j] = wmx3[i][j] + dW3[i][j];
            }
        }
        // PRESENTATION PHASE
        ip = multiply(new_W1, inputs);
        new_out1 = logsig(ip, new_b1);
        ip = multiply(new_W2, new_out1);
        new_out2 = logsig(ip, new_b2);
        ip = multiply(new_W3, new_out2);
        new_out3 = logsig(ip, new_b3);
        // Compute the errors matrix and sum-squared error (SSE).
        new_SSE = 0;
        for (int i = 0; i < nS3; i++) {
            for (int j = 0; j < Q; j++) {
                new_errors[i][j] = targets[i][j] - new_out3[i][j];
                new_SSE += new_errors[i][j] * new_errors[i][j];
            }
        }
        // MOMENTUM AND ADAPTIVE LEARNING RATE PHASE
        if (new_SSE > SSE * error) {
            learnRate *= decLR;
            MC = 0;   // kill momentum off
        } else {
            if (new_SSE < SSE) {
                learnRate *= incLR;
            }
            // Rotate pointers between the old and new matrices.
            swap = wmx1; wmx1 = new_W1; new_W1 = swap;
            swap = wmx2; wmx2 = new_W2; new_W2 = swap;
            swap = wmx3; wmx3 = new_W3; new_W3 = swap;
            swap = b1; b1 = new_b1; new_b1 = swap;
            swap = b2; b2 = new_b2; new_b2 = swap;
            swap = b3; b3 = new_b3; new_b3 = swap;
            swap = out1; out1 = new_out1; new_out1 = swap;
            // NOTE: the original listing is truncated at this point; the rest
            // of this method is reconstructed by analogy with the swap pattern
            // above and the backpropagation phase of trainInit().
            swap = out2; out2 = new_out2; new_out2 = swap;
            swap = out3; out3 = new_out3; new_out3 = swap;
            swap = errors; errors = new_errors; new_errors = swap;
            SSE = new_SSE;
            // BACKPROPAGATION PHASE: recompute the deltas for the next epoch.
            d3 = deltalog(out3, errors);
            d2 = deltalog(out2, d3, wmx3);
            d1 = deltalog(out1, d2, wmx2);
        }
        return SSE;
    }
}
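// A minimal usage sketch (not part of the original file): trains the network
// on the XOR problem. The class name, data, and training loop are
// illustrative assumptions; only MLP's public constructor, trainInit(),
// trainNet(), and testNet() from the listing above are used.
class MLPDemo {
    public static void main(String[] args) {
        // 2-element inputs; 4/4/1 neurons in the input/hidden/output layers.
        MLP net = new MLP(2, 4, 4, 1);
        // Training patterns are the COLUMN vectors: one column per pattern.
        double[][] inputs  = { {0, 0, 1, 1},
                               {0, 1, 0, 1} };
        double[][] targets = { {0, 1, 1, 0} };
        if (!net.trainInit(inputs, targets)) return;
        double sse = Double.MAX_VALUE;
        for (int epoch = 0; epoch < 20000 && sse > 0.01; epoch++) {
            sse = net.trainNet();   // one epoch of batch backprop
        }
        System.out.println("Final SSE: " + sse);
        System.out.println("XOR(1,0) ~ " + net.testNet(new double[] {1, 0})[0]);
    }
}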