
neuralnetwork.java.svn-base

AI course project: a small program implementing a neural network
package neuralnetworks;

import java.util.Random;
import java.io.*;

import perceptron.*;

public class NeuralNetwork {
	/**
	 * The number of layers.
	 */
	int nLayer;

	/**
	 * The number of network inputs, of units per hidden layer, and of output
	 * units (fixed to 1 by the constructor).
	 */
	int nInput, nHidden, nOutput;

	/**
	 * The units of the network, indexed by layer and by position within the
	 * layer.
	 */
	Perceptron[][] perceptron;

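	/**
	 * Creates an empty network; the layers are filled in later, e.g. by
	 * loadNet(String).
	 */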
	public NeuralNetwork() {

	}

	/**
	 * Builds and links a fully connected network of SigmodPerceptron units.
	 * 
	 * @param nl    number of layers; the last layer holds the single output unit
	 * @param nin   number of inputs to the network
	 * @param nh    number of units in each hidden layer
	 * @param lrate learning rate handed to every perceptron
	 */
	public NeuralNetwork(int nl, int nin, int nh, double lrate) {
		nLayer = nl;
		nInput = nin;
		nHidden = nh;
		nOutput = 1;

		perceptron = new Perceptron[nLayer][];
		for (int i = 0; i < nLayer; i++) {
			if (i == nl - 1)
				perceptron[i] = new SigmodPerceptron[nOutput];
			else
				perceptron[i] = new SigmodPerceptron[nHidden];
			for (int j = 0; j < perceptron[i].length; j++)
				if (i == 0)
					perceptron[i][j] = new SigmodPerceptron(i, j, nInput,
							lrate, (new Random()).nextDouble() - 0.5);
				else
					perceptron[i][j] = new SigmodPerceptron(i, j, nHidden,
							lrate, (new Random()).nextDouble() - 0.5);
		}
		for (int i = 0; i < nLayer; i++)
			for (int j = 0; j < perceptron[i].length; j++) {
				Perceptron[] pre = null, post = null;
				if (i != 0)
					pre = perceptron[i - 1];
				if (i != nl - 1)
					post = perceptron[i + 1];
				perceptron[i][j].link(pre, post);
			}
	}

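	/**
	 * Restores a network previously written with saveNet(String).
	 * 
	 * @param fileName path of the saved network file
	 */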
	public NeuralNetwork(String fileName) {
		loadNet(fileName);
	}

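	/**
	 * Prints every perceptron to standard output, framed by separator lines.
	 */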
	public void showNet() {
		System.out
				.println("****************************************************************************************");
		for (int i = 0; i < perceptron.length; i++) {
			for (int j = 0; j < perceptron[i].length; j++)
				System.out.println(perceptron[i][j]);
		}
		System.out
				.println("****************************************************************************************");
	}

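	/**
	 * Overwrites the weights of the unit at layer i, position j.
	 */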
	public void initWeight(int i, int j, double w[]) {
		perceptron[i][j].initWeight(w);
	}

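	/**
	 * Writes the network to a binary file: the topology (nLayer, nInput,
	 * nHidden, nOutput), the perceptron class name, and then every unit's own
	 * state.
	 */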
	public void saveNet(String fileName) {
		try {
			FileOutputStream fos = new FileOutputStream(fileName);
			DataOutputStream dos = new DataOutputStream(fos);

			// Header: topology followed by the perceptron class name.
			dos.writeInt(nLayer);
			dos.writeInt(nInput);
			dos.writeInt(nHidden);
			dos.writeInt(nOutput);
			dos.writeUTF(perceptron[0][0].getClass().toString());

			// Body: every unit writes its own state.
			for (int i = 0; i < perceptron.length; i++) {
				for (int j = 0; j < perceptron[i].length; j++)
					perceptron[i][j].save(fos);
			}
			dos.close();

		} catch (Exception e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
	}

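	/**
	 * Reads back the format written by saveNet(String), rebuilds the layers
	 * and re-links neighbouring units.
	 */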
	public void loadNet(String fileName) {
		try {
			FileInputStream fis = new FileInputStream(fileName);

			DataInputStream dis = new DataInputStream(fis);

			nLayer = dis.readInt();
			int nl = nLayer;
			nInput = dis.readInt();
			nHidden = dis.readInt();
			nOutput = dis.readInt();

			// The class name must be read to advance past the header; it and
			// cls are only used by the commented-out reflective construction
			// below.
			String classString = dis.readUTF();

			Class<?>[] cls = { InputStream.class };

			perceptron = new Perceptron[nLayer][];
			for (int i = 0; i < nLayer; i++) {
				if (i == nl - 1)
					perceptron[i] = new SigmodPerceptron[nOutput];
				else
					perceptron[i] = new SigmodPerceptron[nHidden];
				for (int j = 0; j < perceptron[i].length; j++) {
					// perceptron[i][j] = (Perceptron) ClassLoader
					// .getSystemClassLoader().loadClass(classString)
					// .getConstructor(cls).newInstance(fis);

					perceptron[i][j] = new SigmodPerceptron(fis);
				}
			}
			for (int i = 0; i < nLayer; i++)
				for (int j = 0; j < perceptron[i].length; j++) {
					Perceptron[] pre = null, post = null;
					if (i != 0)
						pre = perceptron[i - 1];
					if (i != nl - 1)
						post = perceptron[i + 1];
					perceptron[i][j].link(pre, post);
				}

			fis.close();

		} catch (Exception e) {
			e.printStackTrace();
		}
	}

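	/**
	 * Saves the network to the default file "AI_NeuralNetworks".
	 */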
	public void saveNet() {
		saveNet("AI_NeuralNetworks");
	}

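	/**
	 * Loads the network from the default file "AI_NeuralNetworks".
	 */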
	public void loadNet() {
		loadNet("AI_NeuralNetworks");
	}

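	/**
	 * Forward pass: propagates the input through the hidden layers and returns
	 * the activation of the single output unit.
	 */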
	public double classify(double[] input) {
		double f = 0;
		double[] in = input, out;
		for (int i = 0; i < nLayer - 1; i++) {
			out = new double[perceptron[i].length];

			for (int j = 0; j < perceptron[i].length; j++) {
				perceptron[i][j].in(in);
				out[j] = perceptron[i][j].out();
			}
			in = out;
		}
		perceptron[nLayer - 1][0].in(in);
		f = perceptron[nLayer - 1][0].out();
		return f;
	}

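	/**
	 * One training step: runs a forward pass on the input, then updates every
	 * unit layer by layer from the output back toward the input with target d.
	 */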
	public void train(double[] input, double d) {
		classify(input);
		for (int i = nLayer - 1; i >= 0; i--) {
			for (int j = 0; j < perceptron[i].length; j++)
				perceptron[i][j].updata(d);
		}
	}

	/**
	 * Demo 1: trains a fresh network on four 3-dimensional binary patterns and
	 * prints its outputs before and after training.
	 */
	public static void play1() {
		double trainSet[][] = { { 1, 0, 1 }, { 0, 0, 1 }, { 0, 1, 1 },
				{ 1, 1, 1 } };
		double d[] = { 0, 1, 0, 1 };

		NeuralNetwork nn = new NeuralNetwork(3, 3, 3, 0.1);
		// MultiClassNeuralNetwork nn = new MultiClassNeuralNetwork(3, 3, 2, 1, 2);
		// MultiClassNeuralNetwork nn = new MultiClassNeuralNetwork();
		// nn.loadNet();

		System.out.println("old:" + nn.classify(trainSet[0]));
		System.out.println("old:" + nn.classify(trainSet[1]));
		System.out.println("old:" + nn.classify(trainSet[2]));
		System.out.println("old:" + nn.classify(trainSet[3]));
		for (int tms = 0; tms < 100000; tms++) {
			nn.train(trainSet[0], d[0]);
			nn.train(trainSet[1], d[1]);
			nn.train(trainSet[2], d[2]);
			nn.train(trainSet[3], d[3]);
			// nn.showNet();
		}
		for (int i = 0; i < 4; i++) {
			// double[] out = nn.classifyVec(trainSet[i]);
			// System.out.println(nn.classify(trainSet[i]) + ": " + out[0] + " , " + out[1]);
			System.out.println(nn.classify(trainSet[i]));
		}

		// nn.saveNet();

		// NeuralNetwork ne = new NeuralNetwork();
	}

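	/**
	 * Demo 2: loads a previously saved MultiClassNeuralNetwork, continues
	 * training it on five 6-dimensional samples with class labels 0..4, prints
	 * the class scores and saves the result as "AI_HW_MOD".
	 */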
	public static void play2() {
		double trainSet[][] = {
				{ 1.8705656e-002, -3.3776245e-002, -1.7389711e-002,
						8.7999413e-003, -1.7687717e-002, 2.7313823e-002 },
				{ 1.9415015e-002, -1.7825606e-002, 8.9698163e-003,
						6.1297178e-003, 2.8228115e-002, -1.2865974e-002 },
				{ 1.8308895e-002, 9.2715043e-003, 2.0354729e-002,
						2.1904986e-002, -1.8979602e-002, -2.9368540e-002 },
				{ 1.9026230e-002, -1.5388940e-002, 6.0216560e-003,
						-7.4243494e-003, -2.3993461e-002, 2.2068728e-002 },
				{ 1.7064405e-002, 1.4599663e-002, 2.2824530e-003,
						5.5328306e-003, 2.3268452e-002, 7.8431617e-003 } };

		// double d[] = { 3, 2, 5, 4, 1 };
		double d[] = { 2, 1, 4, 3, 0 };

		// NeuralNetwork nn = new NeuralNetwork(3, 3, 4, 1);

		// MultiClassNeuralNetwork nn = new MultiClassNeuralNetwork(3, 6, 3, 0.1, 6);

		MultiClassNeuralNetwork nn = new MultiClassNeuralNetwork();
		nn.loadNet();

		for (int i = 0; i < trainSet.length; i++) {
			System.out.println("old:" + nn.classify(trainSet[i]));

		}
		for (int tms = 0; tms < 200000; tms++) {
			for (int i = 0; i < trainSet.length; i++) {
				nn.train(trainSet[i], d[i]);
				// nn.showNet();
			}
		}
		for (int i = 0; i < trainSet.length; i++) {
			double[] out = nn.classifyVec(trainSet[i]);
			System.out.print(nn.classify(trainSet[i]) + ": ");
			for (int j = 0; j < 5; j++)
				System.out.print(out[j] + "  ");
			System.out.println();
		}

		nn.saveNet("AI_HW_MOD");

		// NeuralNetwork ne = new NeuralNetwork();
	}

	public static void main(String args[]) {
		// play1();
		play2();
	}
}
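
Note: this listing depends on a separate perceptron package (Perceptron, SigmodPerceptron) and on a MultiClassNeuralNetwork class that are not shown on this page. As a rough orientation only, the sketch below collects the members that NeuralNetwork actually calls; every name and signature is inferred from a call site above and is an assumption, not the real implementation.

package perceptron;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

/**
 * Hypothetical outline of the perceptron package, reconstructed only from the
 * call sites in NeuralNetwork above. The real classes may differ.
 */
public abstract class Perceptron {

	/** Overwrite this unit's weight vector (used by NeuralNetwork.initWeight). */
	public abstract void initWeight(double[] w);

	/** Connect this unit to the previous and next layer (either may be null). */
	public abstract void link(Perceptron[] pre, Perceptron[] post);

	/** Present an input vector to this unit. */
	public abstract void in(double[] x);

	/** Return the activation computed from the most recent input. */
	public abstract double out();

	/** One learning step toward target d (method name as spelled in NeuralNetwork). */
	public abstract void updata(double d);

	/** Write this unit's state; read back by the subclass's InputStream constructor. */
	public abstract void save(OutputStream out) throws IOException;

	// NeuralNetwork also expects a concrete subclass roughly like:
	//
	//   public class SigmodPerceptron extends Perceptron {
	//       // layer index, unit index, input count, learning rate,
	//       // and a small random initial value (exact role unknown)
	//       public SigmodPerceptron(int layer, int index, int nIn,
	//               double lrate, double init) { ... }
	//       // restore a unit previously written with save(...)
	//       public SigmodPerceptron(InputStream in) throws IOException { ... }
	//   }
}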
