// neuralnetwork.java.svn-base
// (Copied from a web code viewer; its "Font size:" control label appeared here.)
package neuralnetworks;
import java.util.Random;
import java.io.*;
import perceptron.*;
public class NeuralNetwork {
    /** Separator printed above and below the dump produced by {@link #showNet()}. */
    private static final String SEPARATOR =
            "****************************************************************************************";

    /** File used by the no-argument {@link #saveNet()} and {@link #loadNet()}. */
    private static final String DEFAULT_FILE = "AI_NeuralNetworks";

    /**
     * The number of layers, including the output layer.
     */
    int nLayer;
    /**
     * Layer sizes: {@code nInput} is the length of the input vector fed to the first layer,
     * {@code nHidden} the number of perceptrons in every non-output layer, and
     * {@code nOutput} the number of perceptrons in the last (output) layer.
     */
    int nInput, nHidden, nOutput;
    /**
     * {@code perceptron[i][j]} is perceptron {@code j} of layer {@code i}; layer
     * {@code nLayer - 1} is the output layer.
     */
    Perceptron[][] perceptron;

    /**
     * Creates an uninitialized network (no perceptrons); populate it with
     * {@link #loadNet()} or {@link #loadNet(String)} before use.
     */
    public NeuralNetwork() {
    }

    /**
     * Builds a fully connected network of {@link SigmodPerceptron}s with a single output.
     *
     * @param nl    number of layers, including the output layer; must be at least 1
     * @param nin   length of the input vector fed to the first layer
     * @param nh    number of perceptrons in each non-output layer
     * @param lrate learning rate handed to every perceptron's constructor
     * @throws IllegalArgumentException if {@code nl < 1}
     */
    public NeuralNetwork(int nl, int nin, int nh, double lrate) {
        if (nl < 1) {
            // Without at least an output layer, classify() would index perceptron[-1].
            throw new IllegalArgumentException("nl must be at least 1, got " + nl);
        }
        nLayer = nl;
        nInput = nin;
        nHidden = nh;
        nOutput = 1;
        // One generator for the whole network instead of a new Random per perceptron.
        Random rng = new Random();
        perceptron = new Perceptron[nLayer][];
        for (int i = 0; i < nLayer; i++) {
            perceptron[i] = newLayer(i == nLayer - 1, nHidden, nOutput);
            // The first layer reads the raw input; every later layer reads a hidden layer.
            int fanIn = (i == 0) ? nInput : nHidden;
            for (int j = 0; j < perceptron[i].length; j++) {
                // NOTE(review): the last argument is a random value in [-0.5, 0.5),
                // presumably an initial weight/bias - confirm against SigmodPerceptron.
                perceptron[i][j] = new SigmodPerceptron(i, j, fanIn, lrate,
                        rng.nextDouble() - 0.5);
            }
        }
        linkLayers(perceptron);
    }

    /**
     * Creates a network by reading it from a file written by {@link #saveNet(String)}.
     *
     * @param fileName file to read; load errors are printed, not thrown
     */
    public NeuralNetwork(String fileName) {
        loadNet(fileName);
    }

    /**
     * Allocates the (still empty) perceptron array for one layer.
     *
     * @param isOutput whether the layer is the last (output) layer
     * @param hidden   size of a non-output layer
     * @param outputs  size of the output layer
     * @return a new array of the appropriate size
     */
    private static Perceptron[] newLayer(boolean isOutput, int hidden, int outputs) {
        return new SigmodPerceptron[isOutput ? outputs : hidden];
    }

    /**
     * Connects every perceptron to its neighbouring layers: the previous layer
     * ({@code null} for the first) and the next layer ({@code null} for the last).
     *
     * @param layers the fully populated layer array
     */
    private static void linkLayers(Perceptron[][] layers) {
        for (int i = 0; i < layers.length; i++) {
            Perceptron[] pre = (i == 0) ? null : layers[i - 1];
            Perceptron[] post = (i == layers.length - 1) ? null : layers[i + 1];
            for (Perceptron p : layers[i]) {
                p.link(pre, post);
            }
        }
    }

    /**
     * Prints every perceptron, layer by layer, between two separator lines.
     */
    public void showNet() {
        System.out.println(SEPARATOR);
        for (Perceptron[] layer : perceptron) {
            for (Perceptron p : layer) {
                System.out.println(p);
            }
        }
        System.out.println(SEPARATOR);
    }

    /**
     * Sets the weights of perceptron {@code j} in layer {@code i}.
     *
     * @param i layer index
     * @param j perceptron index within the layer
     * @param w weights, forwarded to {@link Perceptron#initWeight}
     */
    public void initWeight(int i, int j, double[] w) {
        perceptron[i][j].initWeight(w);
    }

    /**
     * Writes the network to {@code fileName}.
     *
     * <p>Layout: {@code nLayer}, {@code nInput}, {@code nHidden}, {@code nOutput} as ints,
     * then the class string of {@code perceptron[0][0]} as modified UTF-8, then each
     * perceptron's own {@code save} output, layer by layer. Errors are printed, not thrown.
     *
     * @param fileName destination file
     */
    public void saveNet(String fileName) {
        // try-with-resources closes the file even when a write fails.
        try (FileOutputStream fos = new FileOutputStream(fileName);
             DataOutputStream dos = new DataOutputStream(fos)) {
            dos.writeInt(nLayer);
            dos.writeInt(nInput);
            dos.writeInt(nHidden);
            dos.writeInt(nOutput);
            dos.writeUTF(perceptron[0][0].getClass().toString());
            // DataOutputStream does not buffer, so perceptrons may write straight to the file.
            for (Perceptron[] layer : perceptron) {
                for (Perceptron p : layer) {
                    p.save(fos);
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Replaces this network with one read from a file written by {@link #saveNet(String)}.
     *
     * <p>The network is only replaced once the whole file has been read successfully;
     * on error the stack trace is printed and the current state is left untouched.
     *
     * @param fileName source file
     */
    public void loadNet(String fileName) {
        try (FileInputStream fis = new FileInputStream(fileName);
             DataInputStream dis = new DataInputStream(fis)) {
            int layers = dis.readInt();
            int inputs = dis.readInt();
            int hidden = dis.readInt();
            int outputs = dis.readInt();
            // The stored class name must be consumed to advance the stream, but it is not
            // used: every perceptron is reconstructed as a SigmodPerceptron.
            dis.readUTF();
            Perceptron[][] net = new Perceptron[layers][];
            for (int i = 0; i < layers; i++) {
                net[i] = newLayer(i == layers - 1, hidden, outputs);
                for (int j = 0; j < net[i].length; j++) {
                    // DataInputStream does not read ahead, so the raw stream is positioned
                    // exactly at this perceptron's data.
                    net[i][j] = new SigmodPerceptron(fis);
                }
            }
            linkLayers(net);
            // Commit only after everything was read.
            nLayer = layers;
            nInput = inputs;
            nHidden = hidden;
            nOutput = outputs;
            perceptron = net;
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Saves the network to the default file {@code AI_NeuralNetworks}. */
    public void saveNet() {
        saveNet(DEFAULT_FILE);
    }

    /** Loads the network from the default file {@code AI_NeuralNetworks}. */
    public void loadNet() {
        loadNet(DEFAULT_FILE);
    }

    /**
     * Runs a forward pass and returns the output of the first output-layer perceptron.
     *
     * <p>NOTE(review): {@link #train} calls this before {@code updata}, so the perceptrons
     * presumably keep state from this pass - confirm against Perceptron.
     *
     * @param input input vector for the first layer
     * @return output of {@code perceptron[nLayer - 1][0]}
     */
    public double classify(double[] input) {
        double[] signal = input;
        // Propagate through every layer except the output layer.
        for (int i = 0; i < nLayer - 1; i++) {
            Perceptron[] layer = perceptron[i];
            double[] next = new double[layer.length];
            for (int j = 0; j < layer.length; j++) {
                layer[j].in(signal);
                next[j] = layer[j].out();
            }
            signal = next;
        }
        Perceptron output = perceptron[nLayer - 1][0];
        output.in(signal);
        return output.out();
    }

    /**
     * Performs one training step on a single sample: a forward pass followed by an
     * update of every perceptron, from the output layer back to the first layer.
     *
     * @param input input vector
     * @param d     desired output for {@code input}
     */
    public void train(double[] input, double d) {
        classify(input);
        for (int i = nLayer - 1; i >= 0; i--) {
            for (Perceptron p : perceptron[i]) {
                p.updata(d);
            }
        }
    }

    /**
     * Demo: builds a 3-layer network, prints its outputs on four 3-element samples,
     * trains it for 100000 epochs and prints the outputs again.
     */
    public static void play1() {
        double[][] trainSet = { { 1, 0, 1 }, { 0, 0, 1 }, { 0, 1, 1 }, { 1, 1, 1 } };
        double[] d = { 0, 1, 0, 1 };
        NeuralNetwork nn = new NeuralNetwork(3, 3, 3, 0.1);
        for (double[] sample : trainSet) {
            System.out.println("old:" + nn.classify(sample));
        }
        for (int tms = 0; tms < 100000; tms++) {
            for (int i = 0; i < trainSet.length; i++) {
                nn.train(trainSet[i], d[i]);
            }
        }
        for (double[] sample : trainSet) {
            System.out.println(nn.classify(sample));
        }
    }

    /**
     * Demo: loads a {@code MultiClassNeuralNetwork} through its no-argument
     * {@code loadNet()}, trains it for 200000 epochs on five 6-element samples labelled
     * 0..4, prints each sample's class and output vector, and saves it to {@code AI_HW_MOD}.
     */
    public static void play2() {
        double[][] trainSet = {
                { 1.8705656e-002, -3.3776245e-002, -1.7389711e-002,
                        8.7999413e-003, -1.7687717e-002, 2.7313823e-002 },
                { 1.9415015e-002, -1.7825606e-002, 8.9698163e-003,
                        6.1297178e-003, 2.8228115e-002, -1.2865974e-002 },
                { 1.8308895e-002, 9.2715043e-003, 2.0354729e-002,
                        2.1904986e-002, -1.8979602e-002, -2.9368540e-002 },
                { 1.9026230e-002, -1.5388940e-002, 6.0216560e-003,
                        -7.4243494e-003, -2.3993461e-002, 2.2068728e-002 },
                { 1.7064405e-002, 1.4599663e-002, 2.2824530e-003,
                        5.5328306e-003, 2.3268452e-002, 7.8431617e-003 } };
        // Zero-based class labels (an earlier version used 1-based { 3, 2, 5, 4, 1 }).
        double[] d = { 2, 1, 4, 3, 0 };
        MultiClassNeuralNetwork nn = new MultiClassNeuralNetwork();
        nn.loadNet();
        for (double[] sample : trainSet) {
            System.out.println("old:" + nn.classify(sample));
        }
        for (int tms = 0; tms < 200000; tms++) {
            for (int i = 0; i < trainSet.length; i++) {
                nn.train(trainSet[i], d[i]);
            }
        }
        for (double[] sample : trainSet) {
            double[] out = nn.classifyVec(sample);
            System.out.print(nn.classify(sample) + ": ");
            // Print the scores of the five classes 0..4, without overrunning a shorter vector.
            int shown = Math.min(5, out.length);
            for (int j = 0; j < shown; j++) {
                System.out.print(out[j] + " ");
            }
            System.out.println();
        }
        nn.saveNet("AI_HW_MOD");
    }

    /**
     * Entry point; runs the {@link #play2()} demo.
     *
     * @param args command-line arguments (ignored)
     */
    public static void main(String[] args) {
        play2();
    }
}
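// Keyboard shortcuts of the web code viewer this file was copied from:
//   Copy code:             Ctrl + C
//   Search code:           Ctrl + F
//   Full-screen mode:      F11
//   Toggle theme:          Ctrl + Shift + D
//   Show shortcuts:        ?
//   Increase font size:    Ctrl + =
//   Decrease font size:    Ctrl + -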