⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bpann.java

📁 人工神经网络BP算法
💻 JAVA
字号:
package BPAnn;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * A fully connected feed-forward neural network with one hidden layer,
 * trained online (weights updated after every example) by error
 * back-propagation with momentum.
 *
 * <p>Hidden nodes always apply {@code ANNUtil.sigmoid}; output nodes apply it
 * only when {@code sigmoidOutput} is true and are linear otherwise. There are
 * no bias weights: a node's net input is just the weighted sum of the previous
 * layer's outputs. The delta computations use {@code s * (1 - s)} as the
 * activation derivative, i.e. they assume {@code ANNUtil.sigmoid} is the
 * logistic function (NOTE(review): confirm in ANNUtil).
 *
 * <p>Instances hold mutable training state (weights, last activations, error)
 * and use no synchronization.
 */
public class BPAnn {

	private double[][] wI2H = null; // weights for Input(row) -> Hidden(col)

	private double[][] wH2O = null; // weights for Hidden(row) -> Output(col)

	private double[][] dwI2H = null; // previous weight change Input -> Hidden (momentum term)

	private double[][] dwH2O = null; // previous weight change Hidden -> Output (momentum term)

	private double[] oHidden = null; // hidden-node outputs from the last computeValues()

	private double[] oOutput = null; // output-node outputs from the last computeValues()

	private int nIn; // number of input nodes

	private int nOut; // number of output nodes

	private int nHidden; // number of hidden nodes

	private double learnRate; // learning rate (step size of each weight change)

	private double momentum; // fraction of the previous weight change added to the current one

	private boolean sigmoidOutput; // whether the output nodes apply the sigmoid

	private int nLoop; // training passes (epochs) run by the last training() call

	private double err; // RMS error: sqrt(Sum(squared error over cases and outputs) / (nCase * nOut))

	/**
	 * Renders the loop count, error, hyper-parameters and both weight
	 * matrices (one matrix row per line).
	 */
	@Override
	public String toString() {
		StringBuilder sb = new StringBuilder();

		sb.append("[BPANN] :");
		sb.append("\n");

		sb.append(String.format("nLoops = %5d ", nLoop));
		sb.append(String.format(" ; AveABSError = %.12f", err));
		sb.append(String.format(" ; learnRate = %.12f", learnRate));
		sb.append(String.format(" ; momentum = %.12f", momentum));
		sb.append(String.format(" ; sigmoidOutput = %s", sigmoidOutput));
		sb.append("\n");

		sb.append("wI2H:");
		sb.append("\n");
		appendMatrix(sb, wI2H);

		sb.append("wH2O:");
		sb.append("\n");
		appendMatrix(sb, wH2O);

		sb.append("\n");
		return sb.toString();
	}

	/** Appends {@code m} to {@code sb}, one row per line, each value as "%10.5f ". */
	private static void appendMatrix(StringBuilder sb, double[][] m) {
		for (double[] row : m) {
			for (double v : row)
				sb.append(String.format("%10.5f ", v));
			sb.append("\n");
		}
	}

	/**
	 * Creates a network whose weights are all 0; call {@link #initWeights}
	 * before training.
	 *
	 * @param nIn number of input nodes
	 * @param nOut number of output nodes
	 * @param nHidden number of hidden nodes; when {@code <= 0} the heuristic
	 *        {@code (int) Math.sqrt(nIn + nOut) + 5} is used instead
	 * @param learnRate learning rate for weight updates
	 * @param momentum fraction of the previous weight change carried into the
	 *        next one (0 disables momentum)
	 * @param sigmoidOutput whether output nodes apply the sigmoid (otherwise linear)
	 */
	public BPAnn(int nIn, int nOut, int nHidden, double learnRate,
			double momentum, boolean sigmoidOutput) {

		this.nIn = nIn;
		this.nOut = nOut;
		// An explicit positive size is honored; the heuristic is only a fallback.
		this.nHidden = (nHidden > 0) ? nHidden : (int) Math.sqrt(nIn + nOut) + 5;

		wI2H = new double[this.nIn][this.nHidden];
		wH2O = new double[this.nHidden][this.nOut];
		dwI2H = new double[this.nIn][this.nHidden];
		dwH2O = new double[this.nHidden][this.nOut];
		oHidden = new double[this.nHidden];
		oOutput = new double[this.nOut];

		this.learnRate = learnRate;
		this.momentum = momentum;
		this.sigmoidOutput = sigmoidOutput;

		this.nLoop = 0;
		this.err = 0.0;
	}

	/**
	 * Sets every weight to {@code ANNUtil.randDouble(min, max)} (Input->Hidden
	 * first, then Hidden->Output, row by row) and clears the momentum history.
	 */
	public void initWeights(double min, double max) {
		randomize(wI2H, min, max);
		randomize(wH2O, min, max);
		// Fresh weights must not inherit momentum from an earlier run.
		clear(dwI2H);
		clear(dwH2O);
	}

	/** Fills {@code m} row by row with {@code ANNUtil.randDouble(min, max)}. */
	private static void randomize(double[][] m, double min, double max) {
		for (double[] row : m)
			for (int j = 0; j < row.length; j++)
				row[j] = ANNUtil.randDouble(min, max);
	}

	/** Sets every element of {@code m} to 0. */
	private static void clear(double[][] m) {
		for (double[] row : m)
			for (int j = 0; j < row.length; j++)
				row[j] = 0.0;
	}

	/**
	 * Trains on {@code trainingExamples} until a stopping condition is met.
	 * Does nothing for an empty list.
	 *
	 * @param terminateByLoops if true, run exactly {@code maxLoop} passes;
	 *        otherwise run until the pass error is {@code <= maxErr}
	 *        (NOTE: no iteration cap in that mode, so an unreachable
	 *        {@code maxErr} never terminates)
	 * @param maxLoop number of passes when {@code terminateByLoops} is true
	 * @param maxErr target RMS error when {@code terminateByLoops} is false
	 * @param prtInterval print progress every this many passes (loop mode
	 *        only); {@code <= 0} disables progress output
	 */
	public void training(List<ANNData> trainingExamples,
			boolean terminateByLoops, int maxLoop, double maxErr,
			int prtInterval) {

		if (trainingExamples.isEmpty())
			return;

		if (terminateByLoops) {
			for (this.nLoop = 0; nLoop < maxLoop; nLoop++) {
				trainingOnce(trainingExamples);
				// Guard prevents '% 0' when printing is disabled.
				if (prtInterval > 0 && nLoop % prtInterval == 0) {
					System.out.printf("Loop: %8d  AverageABSError: %.16f\n",
							nLoop, this.err);
				}
			}
		} else {
			this.nLoop = 0;
			do {
				trainingOnce(trainingExamples);
				nLoop++;
			} while (this.err > maxErr);
		}
	}

	/**
	 * Runs one pass over all examples, updating the weights after each one,
	 * and leaves the pass's RMS error in {@code err}.
	 */
	void trainingOnce(List<ANNData> trainingExamples) {
		this.err = 0.0;
		for (ANNData example : trainingExamples) {
			// Error is measured with the weights as they were just before
			// this example's update.
			computeValues(example);
			updateWeights(example);
		}
		this.err = Math.sqrt(this.err / (trainingExamples.size() * this.nOut));
	}

	/**
	 * Runs the forward pass into {@code oHidden}/{@code oOutput} and adds
	 * this example's squared output error to {@code err}.
	 */
	void computeValues(ANNData data) {
		forward(data.getInput(), this.oHidden, this.oOutput);

		double[] expected = data.getExpectedOutput();
		for (int o = 0; o < this.nOut; o++) {
			double diff = this.oOutput[o] - expected[o];
			this.err += diff * diff;
		}
	}

	/**
	 * Forward pass shared by training and prediction.
	 *
	 * @param input input values (at least {@code nIn} of them)
	 * @param hiddenOut receives the {@code nHidden} hidden-node outputs
	 * @param outputOut receives the {@code nOut} output-node outputs
	 */
	private void forward(double[] input, double[] hiddenOut, double[] outputOut) {
		for (int h = 0; h < this.nHidden; h++) {
			double sum = 0.0;
			for (int i = 0; i < this.nIn; i++)
				sum += input[i] * this.wI2H[i][h];
			hiddenOut[h] = ANNUtil.sigmoid(sum); // hidden nodes are always sigmoid
		}
		for (int o = 0; o < this.nOut; o++) {
			double sum = 0.0;
			for (int h = 0; h < this.nHidden; h++)
				sum += hiddenOut[h] * this.wH2O[h][o];
			outputOut[o] = this.sigmoidOutput ? ANNUtil.sigmoid(sum) : sum;
		}
	}

	/**
	 * Returns the network's outputs for {@code annData}'s input without
	 * touching training state ({@code oHidden}, {@code oOutput}, {@code err}).
	 *
	 * @return a new array of {@code nOut} output values
	 */
	public double[] predict(ANNData annData) {
		double[] hidden = new double[this.nHidden];
		double[] output = new double[this.nOut];
		forward(annData.getInput(), hidden, output);
		return output;
	}

	/**
	 * Back-propagates the error of the last {@link #computeValues} call for
	 * {@code data} and updates both weight matrices with
	 * {@code dw(t) = learnRate * delta * input + momentum * dw(t-1)}.
	 */
	void updateWeights(ANNData data) {

		double[] input = data.getInput();
		double[] expected = data.getExpectedOutput();

		double[] deltaOutput = new double[this.nOut];
		double[] deltaHidden = new double[this.nHidden];

		// Error terms of the output nodes.
		for (int o = 0; o < this.nOut; o++) {
			double diff = expected[o] - this.oOutput[o];
			deltaOutput[o] = this.sigmoidOutput
					? this.oOutput[o] * (1.0 - this.oOutput[o]) * diff
					: diff;
		}

		// Error terms of the hidden nodes; must use wH2O BEFORE it is updated.
		for (int h = 0; h < this.nHidden; h++) {
			double sum = 0.0;
			for (int o = 0; o < this.nOut; o++)
				sum += deltaOutput[o] * this.wH2O[h][o];
			deltaHidden[h] = this.oHidden[h] * (1.0 - this.oHidden[h]) * sum;
		}

		// Hidden -> Output weights.
		for (int h = 0; h < this.nHidden; h++) {
			for (int o = 0; o < this.nOut; o++) {
				double change = this.learnRate * deltaOutput[o] * this.oHidden[h]
						+ this.momentum * this.dwH2O[h][o];
				this.wH2O[h][o] += change;
				this.dwH2O[h][o] = change;
			}
		}

		// Input -> Hidden weights.
		for (int i = 0; i < this.nIn; i++) {
			for (int h = 0; h < this.nHidden; h++) {
				double change = this.learnRate * deltaHidden[h] * input[i]
						+ this.momentum * this.dwI2H[i][h];
				this.wI2H[i][h] += change;
				this.dwI2H[i][h] = change;
			}
		}
	}

	/**
	 * Prints expected vs. actual outputs for each item (one line per item)
	 * and stores the RMS error over {@code testData} in {@code err}. Weights
	 * are not changed. Does nothing for an empty list.
	 */
	public void test(List<ANNData> testData) {
		if (testData.isEmpty())
			return;

		this.err = 0.0;
		for (ANNData data : testData) {
			computeValues(data);
			double[] expected = data.getExpectedOutput();
			for (int o = 0; o < this.nOut; o++)
				System.out.printf("Expected: %12.6f, Actual: %12.6f ;",
						expected[o], this.oOutput[o]);
			System.out.println();
		}
		this.err = Math.sqrt(this.err / (testData.size() * this.nOut));
	}

	/** Demo: trains on the built-in examples and prints the results. */
	public static void main(String[] args) {

		List<ANNData> trainingExamples = getTrainingExamples();
		if (trainingExamples.isEmpty())
			return;
		printExamples(trainingExamples);

		int nIn = trainingExamples.get(0).getInput().length;
		int nOut = trainingExamples.get(0).getExpectedOutput().length;
		int nHidden = 0; // <= 0: constructor picks (int) sqrt(nIn + nOut) + 5
		double learnRate = 0.6;
		double momentum = 0.00002;
		boolean sigmoidOutput = false;

		double minInitW = -1.0;
		double maxInitW = 1.0;
		int maxLoop = 70000;
		double maxErr = 0.01;
		// At least 1: maxLoop / 50 is 0 whenever maxLoop < 50.
		int prtInterval = Math.max(1, maxLoop / 50);

		BPAnn bpann = new BPAnn(nIn, nOut, nHidden, learnRate, momentum,
				sigmoidOutput);
		bpann.initWeights(minInitW, maxInitW);

		System.out.println("Before training:\n" + bpann);
		bpann.training(trainingExamples, true, maxLoop, maxErr, prtInterval);
		System.out.println("After training:\n" + bpann);

		bpann.test(trainingExamples);
		// To evaluate on held-out data, implement getTestData() and pass its
		// result to bpann.test(...).
	}

	/** Builds the nine two-input, one-output demo examples. */
	static List<ANNData> getTrainingExamples() {

		// Each row: input x1, input x2, expected output.
		double[][] table = {
				{ 0.0, 0.0, 0.0 },
				{ 1.0, 0.0, 1.0 },
				{ 1.0, 1.0, 1.0 },
				{ 0.0, 1.0, 1.0 },
				{ 1.0, -1.0, 0.0 },
				{ 0.0, -1.0, 1.0 },
				{ -1.0, -1.0, 1.0 },
				{ -1.0, 0.0, 1.0 },
				{ -1.0, 1.0, 0.0 },
		};

		List<ANNData> examples = new ArrayList<ANNData>();
		for (double[] row : table) {
			ANNData example = new ANNData(2, 1);
			example.setInput(new double[] { row[0], row[1] });
			example.setExpectedOutput(new double[] { row[2] });
			examples.add(example);
		}
		return examples;
	}

	/** Prints a header followed by each example's {@code toString()}. */
	static void printExamples(List<ANNData> examples) {
		System.out.println("exmples :");
		for (ANNData example : examples)
			System.out.println(example);
	}

	/** Placeholder: returns an empty list until real test data is provided. */
	static List<ANNData> getTestData() {
		List<ANNData> testData = new ArrayList<ANNData>();
		// TODO supply held-out test data
		return testData;
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -