📄 bpann.java
字号:
package BPAnn;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
 * A fully connected feed-forward neural network with one hidden layer, trained
 * with online (per-example) backpropagation.
 *
 * <p>Hidden nodes always apply {@code ANNUtil.sigmoid}; output nodes apply it
 * only when {@code sigmoidOutput} is true and are linear otherwise. The weight
 * updates use the derivative {@code o * (1 - o)}, which assumes
 * {@code ANNUtil.sigmoid} is the logistic function (TODO confirm).
 *
 * <p>NOTE(review): the network has no bias terms, so for an all-zero input every
 * hidden node outputs {@code sigmoid(0)} whatever the weights are. This limits
 * which functions it can fit; consider adding bias weights.
 *
 * <p>Instances keep mutable, unsynchronized scratch state (node outputs, the
 * running error, the previous weight deltas) and are not meant to be shared
 * between threads.
 */
public class BPAnn {
    private double[][] wI2H; // weights for Input(row) -> Hidden(col)
    private double[][] wH2O; // weights for Hidden(row) -> Output(col)
    private double[][] dwI2H; // previous update applied to wI2H (momentum term)
    private double[][] dwH2O; // previous update applied to wH2O (momentum term)
    private double[] oHidden; // output values of the hidden nodes for the current example
    private double[] oOutput; // output values of the output nodes for the current example
    private int nIn; // number of input nodes
    private int nOut; // number of output nodes
    private int nHidden; // number of hidden nodes
    private double learnRate; // learning rate
    private double momentum; // momentum rate: fraction of the previous update re-applied
    private boolean sigmoidOutput; // whether the output nodes are sigmoid nodes
    private int nLoop; // number of completed (or current) training passes
    private double err; // RMS error: sqrt(Sum(squared err of each case and output) / (nCase * nOut))

    /**
     * Renders the training state and both weight matrices as text.
     *
     * @return a multi-line description of this network
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("[BPANN] :");
        sb.append("\n");
        sb.append(String.format("nLoops = %5d ", nLoop));
        // The label says AveABSError, but the value is the RMS error (see trainingOnce).
        sb.append(String.format(" ; AveABSError = %.12f", err));
        sb.append(String.format(" ; learnRate = %.12f", learnRate));
        sb.append(String.format(" ; momentum = %.12f", momentum));
        sb.append(String.format(" ; sigmoidOutput = %s", sigmoidOutput));
        sb.append("\n");
        sb.append("wI2H:");
        sb.append("\n");
        appendMatrix(sb, wI2H);
        sb.append("wH2O:");
        sb.append("\n");
        appendMatrix(sb, wH2O);
        sb.append("\n");
        return sb.toString();
    }

    /** Appends one line per matrix row, each value formatted as {@code %10.5f }. */
    private static void appendMatrix(StringBuilder sb, double[][] matrix) {
        for (double[] row : matrix) {
            for (double value : row) {
                sb.append(String.format("%10.5f ", value));
            }
            sb.append("\n");
        }
    }

    /**
     * Creates a network with all weights set to zero; call
     * {@link #initWeights(double, double)} before training.
     *
     * @param nIn number of input nodes
     * @param nOut number of output nodes
     * @param nHidden number of hidden nodes; a value {@code <= 0} selects the
     *        heuristic {@code (int) sqrt(nIn + nOut) + 5}
     * @param learnRate learning rate used in every weight update
     * @param momentum fraction of the previous weight update added to the next one
     * @param sigmoidOutput whether output nodes apply the sigmoid (otherwise linear)
     */
    public BPAnn(int nIn, int nOut, int nHidden, double learnRate,
            double momentum, boolean sigmoidOutput) {
        this.nIn = nIn;
        this.nOut = nOut;
        // Honour the caller's hidden-layer size; the heuristic is only a fallback.
        this.nHidden = (nHidden > 0) ? nHidden : defaultHiddenCount(nIn, nOut);
        this.wI2H = new double[nIn][this.nHidden];
        this.wH2O = new double[this.nHidden][nOut];
        this.dwI2H = new double[nIn][this.nHidden];
        this.dwH2O = new double[this.nHidden][nOut];
        this.oHidden = new double[this.nHidden];
        this.oOutput = new double[nOut];
        this.learnRate = learnRate;
        this.momentum = momentum;
        this.sigmoidOutput = sigmoidOutput;
        this.nLoop = 0;
        this.err = 0.0;
    }

    /** Heuristic hidden-layer size used when the caller does not choose one. */
    private static int defaultHiddenCount(int nIn, int nOut) {
        return (int) Math.sqrt(nIn + nOut) + 5;
    }

    /**
     * Sets every weight to a value drawn from {@code ANNUtil.randDouble(min, max)}
     * and clears the momentum history.
     *
     * @param min lower bound passed to {@code ANNUtil.randDouble}
     * @param max upper bound passed to {@code ANNUtil.randDouble}
     */
    public void initWeights(double min, double max) {
        for (int i = 0; i < this.nIn; i++) {
            for (int j = 0; j < this.nHidden; j++) {
                wI2H[i][j] = ANNUtil.randDouble(min, max);
            }
        }
        for (int i = 0; i < this.nHidden; i++) {
            for (int j = 0; j < this.nOut; j++) {
                wH2O[i][j] = ANNUtil.randDouble(min, max);
            }
        }
        // Fresh weights must not inherit momentum from an earlier training run.
        this.dwI2H = new double[this.nIn][this.nHidden];
        this.dwH2O = new double[this.nHidden][this.nOut];
    }

    /**
     * Trains the network on the given examples.
     *
     * @param trainingExamples examples used for every pass; a null or empty list is a no-op
     * @param terminateByLoops if true, run exactly {@code maxLoop} passes; otherwise
     *        repeat passes until the RMS error is at most {@code maxErr}
     * @param maxLoop number of passes when {@code terminateByLoops} is true
     * @param maxErr target RMS error when {@code terminateByLoops} is false
     * @param prtInterval print progress every {@code prtInterval} passes; {@code <= 0} disables printing
     */
    public void training(List<ANNData> trainingExamples,
            boolean terminateByLoops, int maxLoop, double maxErr,
            int prtInterval) {
        if (trainingExamples == null || trainingExamples.isEmpty()) {
            return;
        }
        if (terminateByLoops) {
            // Terminate after the predefined number of passes.
            for (this.nLoop = 0; this.nLoop < maxLoop; this.nLoop++) {
                trainingOnce(trainingExamples);
                printProgress(prtInterval);
            }
        } else {
            // Terminate when the error is no larger than maxErr.
            // NOTE(review): there is no pass limit; this never returns if the
            // error never drops to maxErr.
            this.nLoop = 0;
            do {
                trainingOnce(trainingExamples);
                printProgress(prtInterval);
                this.nLoop++;
            } while (this.err > maxErr);
        }
    }

    /** Prints the current pass number and error when the pass is a multiple of prtInterval. */
    private void printProgress(int prtInterval) {
        if (prtInterval > 0 && this.nLoop % prtInterval == 0) {
            System.out.printf("Loop: %8d AverageABSError: %.16f\n",
                    this.nLoop, this.err);
        }
    }

    /**
     * Runs one pass over all examples, updating the weights after each one, and
     * stores the pass's RMS error in {@code err}.
     */
    void trainingOnce(List<ANNData> trainingExamples) {
        this.err = 0.0;
        for (ANNData trainingExample : trainingExamples) {
            computeValues(trainingExample);
            updateWeights(trainingExample);
        }
        this.err = Math.sqrt(this.err / (trainingExamples.size() * this.nOut));
    }

    /**
     * Runs the forward pass for {@code data}, storing the node outputs in
     * {@code oHidden}/{@code oOutput}, and adds the example's squared errors to
     * {@code err}.
     */
    void computeValues(ANNData data) {
        forward(data.getInput(), this.oHidden, this.oOutput);
        double[] expectedOutput = data.getExpectedOutput();
        for (int o = 0; o < this.nOut; o++) {
            double diff = this.oOutput[o] - expectedOutput[o];
            this.err += diff * diff;
        }
    }

    /**
     * Computes the network outputs for one input without modifying any state.
     *
     * @param annData example whose input vector is fed forward
     * @return a new array holding the {@code nOut} output node values
     */
    public double[] predict(ANNData annData) {
        double[] hidden = new double[this.nHidden];
        double[] output = new double[this.nOut];
        forward(annData.getInput(), hidden, output);
        return output;
    }

    /**
     * Forward pass shared by training and prediction.
     *
     * @param input input vector of length {@code nIn}
     * @param hiddenOut receives the {@code nHidden} hidden node outputs
     * @param outputOut receives the {@code nOut} output node values
     */
    private void forward(double[] input, double[] hiddenOut, double[] outputOut) {
        for (int h = 0; h < this.nHidden; h++) {
            double sum = 0.0;
            for (int i = 0; i < this.nIn; i++) {
                sum += input[i] * this.wI2H[i][h];
            }
            hiddenOut[h] = ANNUtil.sigmoid(sum); // hidden nodes are always sigmoid
        }
        for (int o = 0; o < this.nOut; o++) {
            double sum = 0.0;
            for (int h = 0; h < this.nHidden; h++) {
                sum += hiddenOut[h] * this.wH2O[h][o];
            }
            outputOut[o] = this.sigmoidOutput ? ANNUtil.sigmoid(sum) : sum;
        }
    }

    /**
     * Applies one backpropagation update for {@code data}. It relies on
     * {@code oHidden}/{@code oOutput} having just been filled by
     * {@link #computeValues(ANNData)} for the same example.
     */
    void updateWeights(ANNData data) {
        double[] input = data.getInput();
        double[] expectedOutput = data.getExpectedOutput();
        double[] deltaOutput = new double[this.nOut];
        double[] deltaHidden = new double[this.nHidden];
        // Error terms of the output nodes.
        for (int o = 0; o < this.nOut; o++) {
            double diff = expectedOutput[o] - this.oOutput[o];
            deltaOutput[o] = this.sigmoidOutput
                    ? this.oOutput[o] * (1.0 - this.oOutput[o]) * diff
                    : diff;
        }
        // Error terms of the hidden nodes. They must be computed before wH2O
        // changes, because they use the weights that produced the current output.
        for (int h = 0; h < this.nHidden; h++) {
            double sum = 0.0;
            for (int o = 0; o < this.nOut; o++) {
                sum += deltaOutput[o] * this.wH2O[h][o];
            }
            deltaHidden[h] = this.oHidden[h] * (1.0 - this.oHidden[h]) * sum;
        }
        // Hidden -> Output weights: gradient step plus momentum * previous step.
        for (int h = 0; h < this.nHidden; h++) {
            for (int o = 0; o < this.nOut; o++) {
                double dw = this.learnRate * deltaOutput[o] * this.oHidden[h]
                        + this.momentum * this.dwH2O[h][o];
                this.wH2O[h][o] += dw;
                this.dwH2O[h][o] = dw;
            }
        }
        // Input -> Hidden weights: gradient step plus momentum * previous step.
        for (int i = 0; i < this.nIn; i++) {
            for (int h = 0; h < this.nHidden; h++) {
                double dw = this.learnRate * deltaHidden[h] * input[i]
                        + this.momentum * this.dwI2H[i][h];
                this.wI2H[i][h] += dw;
                this.dwI2H[i][h] = dw;
            }
        }
    }

    /**
     * Evaluates the network on {@code testData}. It prints expected vs. actual
     * values for every example and stores the RMS error in {@code err}.
     * A null or empty list is a no-op.
     */
    public void test(List<ANNData> testData) {
        if (testData == null || testData.isEmpty()) {
            return;
        }
        this.err = 0.0;
        for (ANNData data : testData) {
            computeValues(data);
            double[] expected = data.getExpectedOutput();
            for (int o = 0; o < this.nOut; o++) {
                System.out.printf("Expected: %12.6f, Actual: %12.6f ;",
                        expected[o], this.oOutput[o]);
            }
            System.out.println();
        }
        this.err = Math.sqrt(this.err / (testData.size() * this.nOut));
    }

    /** Demo: trains on the built-in examples and prints the network before and after. */
    public static void main(String[] args) {
        List<ANNData> trainingExamples = getTrainingExamples();
        if (trainingExamples.isEmpty()) {
            return;
        }
        printExamples(trainingExamples);
        int nIn = trainingExamples.get(0).getInput().length;
        int nOut = trainingExamples.get(0).getExpectedOutput().length;
        int nHidden = 0; // <= 0 lets the constructor pick (int) sqrt(nIn + nOut) + 5
        double learnRate = 0.6;
        double momentum = 0.00002;
        boolean sigmoidOutput = false;
        double minInitW = -1.0;
        double maxInitW = 1.0;
        int maxLoop = 70000;
        double maxErr = 0.01;
        int prtInterval = (maxLoop < 10) ? 1 : (maxLoop / 50);
        BPAnn bpann = new BPAnn(nIn, nOut, nHidden, learnRate, momentum,
                sigmoidOutput);
        bpann.initWeights(minInitW, maxInitW);
        System.out.println("Before training:\n" + bpann);
        bpann.training(trainingExamples, true, maxLoop, maxErr, prtInterval);
        System.out.println("After training:\n" + bpann);
        bpann.test(trainingExamples);
        // To evaluate on held-out data, fill in getTestData() and pass its
        // result to bpann.test(...).
    }

    /**
     * Builds the demo data set: 2 inputs and 1 expected output per example.
     *
     * @return a new mutable list of examples
     */
    static List<ANNData> getTrainingExamples() {
        // Each row is {input0, input1, expectedOutput}.
        double[][] rows = {
            { 0.0, 0.0, 0.0 },
            { 1.0, 0.0, 1.0 },
            { 1.0, 1.0, 1.0 },
            { 0.0, 1.0, 1.0 },
            { 1.0, -1.0, 0.0 },
            { 0.0, -1.0, 1.0 },
            { -1.0, -1.0, 1.0 },
            { -1.0, 0.0, 1.0 },
            { -1.0, 1.0, 0.0 },
        };
        List<ANNData> examples = new ArrayList<ANNData>(rows.length);
        for (double[] row : rows) {
            ANNData example = new ANNData(2, 1);
            example.setInput(new double[] { row[0], row[1] });
            example.setExpectedOutput(new double[] { row[2] });
            examples.add(example);
        }
        return examples;
    }

    /** Prints a header followed by each example's {@code toString()}. */
    static void printExamples(List<ANNData> examples) {
        System.out.println("examples :");
        for (ANNData example : examples) {
            System.out.println(example);
        }
    }

    /** Returns held-out test data; currently an empty placeholder list. */
    static List<ANNData> getTestData() {
        List<ANNData> testData = new ArrayList<ANNData>();
        // TODO: populate with held-out examples.
        return testData;
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -