// File: multiclassneuralnetwork.java.svn-base (SVN base copy)
// Project: AI course assignment, a small program implementing a neural network
package neuralnetworks;

import java.util.Random;

import perceptron.Perceptron;
import perceptron.SigmodPerceptron;

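/**
 * Feed-forward network of sigmoid perceptrons for multi-class classification:
 * one output unit per class, and the predicted class is the output unit with
 * the largest activation.
 */
public class MultiClassNeuralNetwork extends NeuralNetwork {

	/** Creates an empty network; no layers are built. */
	public MultiClassNeuralNetwork() {
	}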

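	/**
	 * Builds a fully connected network.
	 *
	 * @param nl    number of perceptron layers (the last one is the output layer)
	 * @param nin   number of input features
	 * @param nh    number of units in each hidden layer
	 * @param lrate learning rate passed to every perceptron
	 * @param no    number of output units, i.e. number of classes
	 */
	public MultiClassNeuralNetwork(int nl, int nin, int nh, double lrate, int no) {
		nLayer = nl;
		nInput = nin;
		nHidden = nh;
		nOutput = no;

		// One shared generator instead of a new Random per perceptron.
		Random random = new Random();

		perceptron = new Perceptron[nLayer][];
		for (int i = 0; i < nLayer; i++) {
			// The last layer has one unit per class; all others are hidden layers.
			int units = (i == nLayer - 1) ? nOutput : nHidden;
			// The first layer reads the raw input; later layers read the previous hidden layer.
			int fanIn = (i == 0) ? nInput : nHidden;
			perceptron[i] = new SigmodPerceptron[units];
			for (int j = 0; j < units; j++) {
				// Last argument: a random initial value in [-0.5, 0.5).
				perceptron[i][j] = new SigmodPerceptron(i, j, fanIn, lrate,
						random.nextDouble() - 0.5);
			}
		}

		// Connect each perceptron to the previous and next layers (null at the ends).
		for (int i = 0; i < nLayer; i++) {
			Perceptron[] pre = (i != 0) ? perceptron[i - 1] : null;
			Perceptron[] post = (i != nLayer - 1) ? perceptron[i + 1] : null;
			for (int j = 0; j < perceptron[i].length; j++)
				perceptron[i][j].link(pre, post);
		}
	}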

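	/**
	 * Forward pass: feeds the input through every layer and returns the
	 * activations of the output layer (one value per class).
	 */
	public double[] classifyVec(double[] input) {
		double[] in = input, out = null;
		for (int i = 0; i < nLayer; i++) {
			out = new double[perceptron[i].length];
			for (int j = 0; j < perceptron[i].length; j++) {
				perceptron[i][j].in(in);
				out[j] = perceptron[i][j].out();
			}
			// This layer's outputs are the next layer's inputs.
			in = out;
		}
		return out;
	}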

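	/**
	 * Returns the index of the output unit with the highest activation, i.e.
	 * the predicted class. The index is returned as a double.
	 */
	public double classify(double[] input) {
		double[] out = classifyVec(input);
		int bestIndex = 0;
		double best = Double.NEGATIVE_INFINITY;
		for (int i = 0; i < out.length; i++) {
			if (out[i] > best) {
				best = out[i];
				bestIndex = i;
			}
		}
		return bestIndex;
	}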

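	/**
	 * Trains on one example with an explicit target vector: a forward pass,
	 * then the output layer is updated towards {@code d[j]}, then the hidden
	 * layers are updated from the last to the first.
	 */
	public void train(double[] input, double[] d) {
		classifyVec(input); // forward pass, run for its effect on each perceptron's state
		for (int j = 0; j < perceptron[nLayer - 1].length; j++)
			perceptron[nLayer - 1][j].updata(d[j]);
		// Hidden units have no target value of their own; 0 is passed.
		for (int i = nLayer - 2; i >= 0; i--) {
			for (int j = 0; j < perceptron[i].length; j++)
				perceptron[i][j].updata(0);
		}
	}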

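	/**
	 * Trains on one example labelled with class index {@code d}, using a
	 * one-hot target: 1 for output unit {@code d}, 0 for all other units.
	 */
	public void train(double[] input, int d) {
		classifyVec(input); // forward pass, run for its effect on each perceptron's state
		for (int j = 0; j < perceptron[nLayer - 1].length; j++)
			perceptron[nLayer - 1][j].updata(d == j ? 1 : 0);
		// Hidden units have no target value of their own; 0 is passed.
		for (int i = nLayer - 2; i >= 0; i--) {
			for (int j = 0; j < perceptron[i].length; j++)
				perceptron[i][j].updata(0);
		}
	}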

	/** Convenience overload: the label is truncated to an int class index. */
	public void train(double[] input, double d) {
		train(input, (int) d);
	}
}
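
The forward pass in `classifyVec` and the arg-max in `classify` can be illustrated without the `perceptron` package. This is a minimal sketch that assumes each `SigmodPerceptron` computes the logistic function of a bias plus a weighted sum of its inputs. The class `SigmoidForwardSketch`, its weight layout (index 0 holds the bias), and the per-weight random initialisation in [-0.5, 0.5) are hypothetical and are not taken from the original `Perceptron` sources. The layer sizing mirrors the constructor above: the last layer has one unit per class, and only the first layer reads the raw input.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sketch: layered sigmoid network, forward pass plus arg-max. */
class SigmoidForwardSketch {
    // weights[layer][unit][k]: k = 0 is the bias, k >= 1 multiplies input k - 1 (assumed layout).
    final double[][][] weights;

    SigmoidForwardSketch(int nInput, int nHidden, int nOutput, int nLayer, long seed) {
        Random rng = new Random(seed); // one shared generator for all weights
        weights = new double[nLayer][][];
        for (int i = 0; i < nLayer; i++) {
            int units = (i == nLayer - 1) ? nOutput : nHidden;
            int fanIn = (i == 0) ? nInput : nHidden;
            weights[i] = new double[units][fanIn + 1];
            for (double[] w : weights[i])
                for (int k = 0; k < w.length; k++)
                    w[k] = rng.nextDouble() - 0.5; // uniform in [-0.5, 0.5)
        }
    }

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** Mirrors classifyVec: each layer's outputs become the next layer's inputs. */
    double[] forward(double[] input) {
        double[] in = input;
        for (double[][] layer : weights) {
            double[] out = new double[layer.length];
            for (int j = 0; j < layer.length; j++) {
                double sum = layer[j][0];
                for (int k = 0; k < in.length; k++)
                    sum += layer[j][k + 1] * in[k];
                out[j] = sigmoid(sum);
            }
            in = out;
        }
        return in;
    }

    /** Mirrors classify: index of the largest output unit. */
    int classify(double[] input) {
        double[] out = forward(input);
        int best = 0;
        for (int i = 1; i < out.length; i++)
            if (out[i] > out[best])
                best = i;
        return best;
    }

    public static void main(String[] args) {
        SigmoidForwardSketch net = new SigmoidForwardSketch(2, 3, 4, 2, 42L);
        double[] x = {0.5, -1.0};
        System.out.println(Arrays.toString(net.forward(x)));
        System.out.println("class = " + net.classify(x));
    }
}
```

Because every output is a sigmoid value in (0, 1), any starting value below 0 works for the running maximum in `classify`. Comparing against `out[best]` as done here avoids the sentinel altogether.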

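The `train` overloads delegate the actual learning to `updata`, whose body is not shown here. The sketch below assumes that `updata` performs standard backpropagation for sigmoid units with squared error, and shows one training step on a one-hot target, which is what `train(double[], int)` builds with `d == j ? 1 : 0`. It is limited to a single hidden layer. `SigmoidBackpropSketch` and its methods are hypothetical. The sketch computes every delta before changing any weight (the textbook ordering). The original calls `updata` on the output layer first, so whether its hidden-layer updates see the old or the new output weights depends on the unseen `Perceptron` implementation.

```java
import java.util.Random;

/** Hypothetical sketch: one hidden sigmoid layer trained by backpropagation on one-hot targets. */
class SigmoidBackpropSketch {
    final double[][] wHidden; // [hidden unit][0 = bias, 1.. = inputs]
    final double[][] wOut;    // [output unit][0 = bias, 1.. = hidden outputs]
    final double lrate;
    double[] hidden, output;  // activations cached by the last forward pass

    SigmoidBackpropSketch(int nInput, int nHidden, int nOutput, double lrate, long seed) {
        Random rng = new Random(seed);
        wHidden = randomMatrix(nHidden, nInput + 1, rng);
        wOut = randomMatrix(nOutput, nHidden + 1, rng);
        this.lrate = lrate;
    }

    static double[][] randomMatrix(int rows, int cols, Random rng) {
        double[][] m = new double[rows][cols];
        for (double[] row : m)
            for (int k = 0; k < cols; k++)
                row[k] = rng.nextDouble() - 0.5;
        return m;
    }

    /** Sigmoid of bias plus weighted sum, for every unit in one layer. */
    static double[] layer(double[][] w, double[] in) {
        double[] out = new double[w.length];
        for (int j = 0; j < w.length; j++) {
            double sum = w[j][0];
            for (int k = 0; k < in.length; k++)
                sum += w[j][k + 1] * in[k];
            out[j] = 1.0 / (1.0 + Math.exp(-sum));
        }
        return out;
    }

    double[] forward(double[] input) {
        hidden = layer(wHidden, input);
        output = layer(wOut, hidden);
        return output;
    }

    /** One-hot target, as in train(double[], int): 1 for the true class, 0 elsewhere. */
    static double[] oneHot(int cls, int n) {
        double[] t = new double[n];
        t[cls] = 1.0;
        return t;
    }

    /** Squared error 0.5 * sum (t - o)^2 for the given input and target. */
    double loss(double[] input, double[] target) {
        double[] o = forward(input);
        double e = 0;
        for (int j = 0; j < o.length; j++)
            e += 0.5 * (target[j] - o[j]) * (target[j] - o[j]);
        return e;
    }

    /** One gradient-descent step on the squared error. */
    void train(double[] input, double[] target) {
        forward(input);
        // Output deltas for sigmoid units: (t - o) * o * (1 - o).
        double[] dOut = new double[output.length];
        for (int j = 0; j < output.length; j++)
            dOut[j] = (target[j] - output[j]) * output[j] * (1 - output[j]);
        // Hidden deltas use the output weights before they are updated.
        double[] dHid = new double[hidden.length];
        for (int h = 0; h < hidden.length; h++) {
            double s = 0;
            for (int j = 0; j < output.length; j++)
                s += dOut[j] * wOut[j][h + 1];
            dHid[h] = s * hidden[h] * (1 - hidden[h]);
        }
        update(wOut, dOut, hidden);
        update(wHidden, dHid, input);
    }

    void update(double[][] w, double[] delta, double[] in) {
        for (int j = 0; j < w.length; j++) {
            w[j][0] += lrate * delta[j];
            for (int k = 0; k < in.length; k++)
                w[j][k + 1] += lrate * delta[j] * in[k];
        }
    }

    public static void main(String[] args) {
        SigmoidBackpropSketch net = new SigmoidBackpropSketch(2, 3, 3, 0.5, 7L);
        double[] x = {1.0, 0.0};
        double[] t = oneHot(2, 3);
        System.out.println("loss before: " + net.loss(x, t));
        for (int epoch = 0; epoch < 1000; epoch++)
            net.train(x, t);
        System.out.println("loss after:  " + net.loss(x, t));
    }
}
```

Repeated steps on one example should raise output unit 2 towards 1 and push the other two towards 0. After training, the arg-max then returns the labelled class.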