📄 perceptron.java.svn-base
字号:
package perceptron;
import java.util.*;
import java.io.*;
import java.math.*;
/**
 * A single neuron that can be wired into a layered network and trained with
 * the delta rule.
 *
 * <p>Forward pass: {@code output = f(s)} with
 * {@code s = weight . input - threshold}.
 * Learning step: {@code weight[i] += learnRate * delta * input[i]}.
 *
 * <p>The default activation {@link #f(double)} is the identity and its
 * derivative {@link #differential(double)} is the constant 1; subclasses may
 * override both (and the {@code delta} methods) to obtain other unit types.
 *
 * @author huicongwm
 */
public class Perceptron {

    /** Connection weights, one per input. */
    public double[] weight;

    /** Most recently supplied input vector (stored by reference, see {@link #in(double[])}). */
    public double[] input;

    /** Result of the most recent call to {@link #out()}. */
    public double output;

    /** Learning rate applied to every weight correction. */
    public double learnRate;

    /** Threshold subtracted from the weighted input sum. */
    public double threshold;

    /** Neurons feeding this one, as passed to {@link #link}; not read by this class. */
    public Perceptron[] preP;

    /**
     * Neurons fed by this one. {@code null} marks this neuron as an output
     * unit for the purposes of {@link #updata(double)}.
     */
    public Perceptron[] postP;

    /** Error term computed by the most recent {@code delta} call. */
    public double myDelta;

    /** Index of the layer this neuron belongs to. */
    public int layer;

    /**
     * Index of this neuron; also the position of this neuron's output inside
     * the weight vector of each neuron in {@link #postP}.
     */
    public int id;

    /**
     * Creates a neuron with {@code nin} inputs and weights drawn uniformly
     * from {@code [-0.5, 0.5)}.
     *
     * @param layer     layer index stored in {@link #layer}
     * @param id        neuron index stored in {@link #id}
     * @param nin       number of inputs (length of the weight and input arrays)
     * @param lrate     learning rate
     * @param threshold threshold subtracted from the weighted sum
     */
    public Perceptron(int layer, int id, int nin, double lrate, double threshold) {
        // One generator for all weights instead of a fresh Random per element.
        Random rng = new Random();
        weight = new double[nin];
        for (int i = 0; i < nin; i++) {
            weight[i] = rng.nextDouble() - 0.5;
        }
        input = new double[nin];
        output = 0;
        learnRate = lrate;
        this.threshold = threshold;
        this.layer = layer;
        this.id = id;
    }

    /**
     * Creates a neuron from the binary form written by {@link #save(OutputStream)}.
     *
     * @param f stream positioned at the start of a saved neuron
     */
    public Perceptron(InputStream f) {
        load(f);
    }

    /**
     * Replaces the weight vector. The array is stored by reference, not copied.
     *
     * @param w new weight vector
     */
    public void initWeight(double w[]) {
        weight = w;
    }

    /**
     * Records the neighbouring neurons.
     *
     * @param pre  neurons feeding this one
     * @param post neurons fed by this one, or {@code null} for an output unit
     */
    public void link(Perceptron[] pre, Perceptron[] post) {
        preP = pre;
        postP = post;
    }

    /**
     * Computes the dot product of {@code a} and {@code b} minus the threshold.
     * If the arrays differ in length only the common prefix is used.
     *
     * @param a first vector
     * @param b second vector
     * @return {@code a . b - threshold}
     */
    double getS(double a[], double b[]) {
        double sum = 0;
        for (int i = 0; i < a.length && i < b.length; i++) {
            sum += a[i] * b[i];
        }
        return sum - threshold;
    }

    /**
     * Activation function; the identity in this class. Overridable.
     *
     * @param s weighted input sum minus threshold
     * @return the activation for {@code s}
     */
    double f(double s) {
        return s;
    }

    /**
     * Derivative of {@link #f(double)}; the constant 1 in this class. Overridable.
     *
     * @param s weighted input sum minus threshold
     * @return {@code f'(s)}
     */
    double differential(double s) {
        return 1;
    }

    /**
     * Error term used when a target value is available:
     * {@code (d - f(s)) * f'(s)}. The result is also stored in
     * {@link #myDelta}. Overridable.
     *
     * @param s weighted input sum minus threshold
     * @param d desired output
     * @return the error term
     */
    double delta(double s, double d) {
        myDelta = (d - f(s)) * differential(s);
        return myDelta;
    }

    /**
     * Error term derived from the following neurons: the sum over
     * {@link #postP} of each neuron's {@code myDelta} times the weight it
     * assigns to this neuron (index {@link #id}), multiplied by {@code f'(s)}.
     * The result is also stored in {@link #myDelta}. Requires {@code postP}
     * to be non-null.
     *
     * @param s weighted input sum minus threshold
     * @return the error term
     */
    double delta(double s) {
        double upstream = 0;
        for (int l = 0; l < postP.length; l++) {
            upstream += postP[l].myDelta * postP[l].weight[id];
        }
        myDelta = upstream * differential(s);
        return myDelta;
    }

    /**
     * Performs one learning step on the current input. Uses
     * {@link #finalUpdata(double)} when {@link #postP} is {@code null} and
     * {@link #intermediateUpdata()} otherwise.
     *
     * @param d desired output; ignored when {@code postP} is not {@code null}
     */
    public void updata(double d) {
        if (postP == null) {
            finalUpdata(d);
        } else {
            intermediateUpdata();
        }
    }

    /**
     * Learning step driven by a target value:
     * {@code weight[i] += learnRate * delta(s, d) * input[i]}.
     *
     * @param d desired output
     */
    public void finalUpdata(double d) {
        double s = getS(weight, input);
        double step = learnRate * delta(s, d);
        for (int i = 0; i < weight.length; i++) {
            weight[i] += step * input[i];
        }
    }

    /**
     * Learning step driven by the error terms of the following neurons:
     * {@code weight[i] += learnRate * delta(s) * input[i]}.
     */
    public void intermediateUpdata() {
        double s = getS(weight, input);
        double step = learnRate * delta(s);
        for (int i = 0; i < weight.length; i++) {
            weight[i] += step * input[i];
        }
    }

    /**
     * Sets the current input vector. The array is stored by reference, not copied.
     *
     * @param x input vector
     */
    public void in(double x[]) {
        input = x;
    }

    /**
     * Computes the activation for the current input and stores it in
     * {@link #output}.
     *
     * @return {@code f(weight . input - threshold)}
     */
    public double out() {
        double s = getS(weight, input);
        output = f(s);
        return output;
    }

    /**
     * Returns a human-readable, single-line description of this neuron.
     *
     * @return layer, index, learning rate, threshold, weights and last error term
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("Perceptron ");
        sb.append("(Layer:").append(layer).append(")");
        sb.append("(index:").append(id).append(")");
        sb.append("(lRate:").append(learnRate).append(")");
        sb.append("(Threshold:").append(threshold).append(")");
        for (int i = 0; i < weight.length; i++) {
            sb.append("(w[").append(i).append("]=").append(weight[i]).append(")");
        }
        sb.append("(NowDelta:").append(myDelta).append(")");
        return sb.toString();
    }

    /**
     * Returns this neuron as one line of space-separated text:
     * {@code layer id learnRate threshold weightCount w0 ... wn myDelta},
     * terminated by a newline.
     *
     * @return the text record
     */
    public String toFile() {
        StringBuilder sb = new StringBuilder();
        sb.append(layer).append(' ');
        sb.append(id).append(' ');
        // The separator after learnRate is required; without it learnRate and
        // threshold run together and the record cannot be split back apart.
        sb.append(learnRate).append(' ');
        sb.append(threshold).append(' ');
        sb.append(weight.length).append(' ');
        for (int i = 0; i < weight.length; i++) {
            sb.append(weight[i]).append(' ');
        }
        sb.append(myDelta);
        sb.append("\n");
        return sb.toString();
    }

    /**
     * Writes this neuron in binary form: layer, id, learnRate, threshold,
     * weight count, each weight, myDelta. The stream is not closed. Any
     * exception is printed to standard error and otherwise ignored.
     *
     * @param fw destination stream
     */
    public void save(OutputStream fw) {
        try {
            DataOutputStream dos = new DataOutputStream(fw);
            dos.writeInt(layer);
            dos.writeInt(id);
            dos.writeDouble(learnRate);
            dos.writeDouble(threshold);
            dos.writeInt(weight.length);
            for (int i = 0; i < weight.length; i++) {
                dos.writeDouble(weight[i]);
            }
            dos.writeDouble(myDelta);
        } catch (Exception e) {
            // Best effort: report the failure and carry on.
            e.printStackTrace();
        }
    }

    /**
     * Reads a neuron in the binary form written by {@link #save(OutputStream)},
     * replacing this object's state. The stream is not closed. An
     * {@link IOException} is printed to standard error and otherwise ignored;
     * fields read before the failure keep their new values.
     *
     * @param fw source stream
     */
    public void load(InputStream fw) {
        try {
            DataInputStream dis = new DataInputStream(fw);
            layer = dis.readInt();
            id = dis.readInt();
            learnRate = dis.readDouble();
            threshold = dis.readDouble();
            int nin = dis.readInt();
            input = new double[nin];
            output = 0;
            weight = new double[nin];
            // Pre-fill with random values so that any weight not reached
            // before a read failure is randomised rather than left at zero.
            Random rng = new Random();
            for (int i = 0; i < nin; i++) {
                weight[i] = rng.nextDouble();
            }
            for (int i = 0; i < nin; i++) {
                weight[i] = dis.readDouble();
            }
            myDelta = dis.readDouble();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -