📄 multiclassneuralnetwork.java.svn-base
字号:
package neuralnetworks;
import java.util.Random;
import perceptron.Perceptron;
import perceptron.SigmodPerceptron;
/**
 * Layered network of {@link SigmodPerceptron} units with one output unit per class.
 *
 * <p>Every layer except the last holds {@code nHidden} units; the last layer holds
 * {@code nOutput} units. The predicted class is the index of the output unit with the
 * largest activation.
 *
 * <p>The fields {@code nLayer}, {@code nInput}, {@code nHidden}, {@code nOutput} and
 * {@code perceptron} are not declared in this class; presumably they are inherited from
 * {@code NeuralNetwork} (TODO confirm).
 */
public class MultiClassNeuralNetwork extends NeuralNetwork {

    /** Creates an instance without allocating any layers. */
    public MultiClassNeuralNetwork() {
    }

    /**
     * Builds and links a network.
     *
     * @param nl    number of layers (the last one is the output layer)
     * @param nin   number of values in an input vector
     * @param nh    number of units in every non-output layer
     * @param lrate value passed to every unit's constructor, presumably the learning rate
     * @param no    number of units in the output layer, i.e. number of classes
     */
    public MultiClassNeuralNetwork(int nl, int nin, int nh, double lrate, int no) {
        nLayer = nl;
        nInput = nin;
        nHidden = nh;
        this.nOutput = no;
        perceptron = new Perceptron[nLayer][];

        // One generator for the whole network rather than a fresh Random per unit.
        Random rng = new Random();
        for (int i = 0; i < nLayer; i++) {
            boolean outputLayer = (i == nl - 1);
            perceptron[i] = new SigmodPerceptron[outputLayer ? nOutput : nHidden];

            // Layer 0 reads the raw input vector; every later layer reads the
            // nHidden outputs of the layer before it.
            int fanIn = (i == 0) ? nInput : nHidden;
            for (int j = 0; j < perceptron[i].length; j++) {
                // Last argument is a random value in [-0.5, 0.5); its exact role
                // (initial weight/bias?) is defined by SigmodPerceptron - TODO confirm.
                perceptron[i][j] = new SigmodPerceptron(i, j, fanIn, lrate,
                        rng.nextDouble() - 0.5);
            }
        }

        // Hand each unit its neighbouring layers; null marks "no such layer".
        for (int i = 0; i < nLayer; i++) {
            Perceptron[] pre = (i != 0) ? perceptron[i - 1] : null;
            Perceptron[] post = (i != nl - 1) ? perceptron[i + 1] : null;
            for (int j = 0; j < perceptron[i].length; j++) {
                perceptron[i][j].link(pre, post);
            }
        }
    }

    /**
     * Runs a forward pass: each layer's outputs become the next layer's inputs.
     *
     * @param input values fed to the first layer
     * @return the outputs of the last layer, or {@code null} if the network has no layers
     */
    public double[] classifyVec(double[] input) {
        double[] in = input;
        double[] out = null;
        for (int i = 0; i < nLayer; i++) {
            out = new double[perceptron[i].length];
            for (int j = 0; j < perceptron[i].length; j++) {
                perceptron[i][j].in(in);
                out[j] = perceptron[i][j].out();
            }
            in = out;
        }
        return out;
    }

    /**
     * Classifies an input vector.
     *
     * @param input values fed to the first layer
     * @return the index of the output unit with the largest activation, as a
     *         {@code double}; ties resolve to the lowest index
     */
    public double classify(double[] input) {
        double[] out = classifyVec(input);
        int bj = 0;
        // Negative infinity makes the arg-max independent of the activation range.
        double best = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < out.length; i++) {
            if (out[i] > best) {
                best = out[i];
                bj = i;
            }
        }
        return bj;
    }

    /**
     * Performs one training step against an explicit target vector.
     *
     * @param input training sample
     * @param d     target value for each output unit; must have at least as many
     *              entries as the output layer has units
     */
    public void train(double[] input, double[] d) {
        // Forward pass first; presumably updata() relies on the state it leaves
        // in each unit - TODO confirm in Perceptron.
        classifyVec(input);
        Perceptron[] outputs = perceptron[nLayer - 1];
        for (int j = 0; j < outputs.length; j++) {
            outputs[j].updata(d[j]);
        }
        updateHiddenLayersBackward();
    }

    /**
     * Performs one training step against a class index, using a one-hot target
     * (1 for output unit {@code d}, 0 for all others).
     *
     * @param input training sample
     * @param d     index of the correct class
     */
    public void train(double[] input, int d) {
        classifyVec(input);
        Perceptron[] outputs = perceptron[nLayer - 1];
        for (int j = 0; j < outputs.length; j++) {
            outputs[j].updata(d == j ? 1 : 0);
        }
        updateHiddenLayersBackward();
    }

    /**
     * Performs one training step against a class index supplied as a {@code double};
     * the value is truncated toward zero to obtain the index.
     *
     * @param input training sample
     * @param d     class index as a floating-point number
     */
    public void train(double[] input, double d) {
        train(input, (int) d);
    }

    /**
     * Updates every non-output layer, walking from the layer just below the output
     * layer down to layer 0. Units are updated with argument 0; presumably they take
     * their error signal from the linked following layer - TODO confirm in Perceptron.
     */
    private void updateHiddenLayersBackward() {
        for (int i = nLayer - 2; i >= 0; i--) {
            for (int j = 0; j < perceptron[i].length; j++) {
                perceptron[i][j].updata(0);
            }
        }
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -