📄 neuron.java
字号:
package bp;
import java.lang.Math;
import java.io.Serializable;
/**
 * A single neuron of a back-propagation network.
 *
 * <p>The neuron holds one connection weight per neuron of the previous layer plus a bias weight
 * (whose input is the constant 1). It also caches the most recent weighted input sum
 * ({@code net}), its output and its back-propagated error term.
 *
 * <p>NOTE(review): this class is {@link Serializable} without an explicit
 * {@code serialVersionUID}, so the JVM derives one from the class shape. Changing field
 * declarations/modifiers or non-private member signatures changes that identifier and can
 * break loading of previously serialized networks. If an explicit value is ever added, use
 * the one reported by the {@code serialver} tool for the current class.
 */
public class Neuron implements Serializable {
    /** Activation type: identity, {@code f(net) = net}. */
    public static final int LINEAR_ACTIVATION = 0;
    /** Activation type: logistic sigmoid, {@code f(net) = 1 / (1 + e^(-flatness * net))}. */
    public static final int BINARY_SIGMOID_ACTIVATION = 1;
    /** Activation type: bipolar sigmoid, {@code f(net) = 2 / (1 + e^(-flatness * net)) - 1}. */
    public static final int BIPOLARY_SIGMOID_ACTIVATION = 2;

    /** Steepness multiplier applied to {@code net} inside the sigmoid activations. */
    private double flatness = 1.0;
    /** Weight of the bias input (the bias input value is always 1). */
    private double biasWeight;
    /** Connection weights; {@code weights[i]} multiplies the output of neuron {@code i} of the previous layer. */
    private double[] weights;
    /** Most recently stored output value. */
    private double output;
    /** Most recent weighted input sum including the bias, as computed by {@link #calculateNet}. */
    private double net;
    /** Most recent back-propagated error term (delta). */
    private double error;
    /** One of the {@code *_ACTIVATION} constants. */
    private int activationType;

    /**
     * Creates a binary-sigmoid neuron whose weights and bias start near 0.5.
     *
     * @param noOfWeights number of input connections (neurons in the previous layer)
     */
    public Neuron(int noOfWeights) {
        this(noOfWeights, 0.5, BINARY_SIGMOID_ACTIVATION);
    }

    /**
     * Creates a neuron whose weights and bias are each initialised to
     * {@code weightDefaultValue} plus an independent random jitter in (-0.05, 0.05].
     *
     * @param noOfWeights        number of input connections (neurons in the previous layer)
     * @param weightDefaultValue centre value for the initial weights and bias
     * @param activationType     one of the {@code *_ACTIVATION} constants
     */
    public Neuron(int noOfWeights, double weightDefaultValue, int activationType) {
        weights = new double[noOfWeights];
        for (int i = 0; i < weights.length; ++i) {
            weights[i] = weightDefaultValue + randomJitter();
        }
        biasWeight = weightDefaultValue + randomJitter();
        this.activationType = activationType;
        output = 0.5;
        net = 0.0;
        error = 0.0;
    }

    /**
     * Copy constructor: creates an independent deep copy of {@code neuron}, including its
     * weights, bias, activation settings and cached net/output/error values.
     *
     * @param neuron the neuron to copy
     */
    public Neuron(Neuron neuron) {
        weights = neuron.weights.clone();
        activationType = neuron.activationType;
        flatness = neuron.flatness;
        biasWeight = neuron.biasWeight;
        output = neuron.output;
        net = neuron.net;
        error = neuron.error;
    }

    /** Returns a random offset in (-0.05, 0.05] used to break weight symmetry at initialisation. */
    private static double randomJitter() {
        return (0.5 - Math.random()) / 10;
    }

    /** @return the steepness multiplier used by the sigmoid activations */
    public double getFlatness() {
        return flatness;
    }

    /** @param flatness new steepness multiplier for the sigmoid activations */
    public void setFlatness(double flatness) {
        this.flatness = flatness;
    }

    /** @return the most recently computed error term */
    public double getError() {
        return error;
    }

    /** @return the number of input connection weights (excluding the bias) */
    public int getNoOfWeights() {
        return weights.length;
    }

    /** @return the activation type, one of the {@code *_ACTIVATION} constants */
    public int getActivationType() {
        return activationType;
    }

    /** @return the bias weight */
    public double getBias() {
        return biasWeight;
    }

    /** @param bias new bias weight */
    public void setBias(double bias) {
        biasWeight = bias;
    }

    /**
     * @param index input connection index
     * @return the weight of that connection
     */
    public double getWeight(int index) {
        return weights[index];
    }

    /**
     * @param index input connection index
     * @param value new weight for that connection
     */
    public void setWeight(int index, double value) {
        weights[index] = value;
    }

    /** @return the stored output value */
    public double getOutput() {
        return output;
    }

    /** @param output new output value to store */
    public void setOutput(double output) {
        this.output = output;
    }

    /** @return the most recently computed weighted input sum (including bias) */
    public double getNet() {
        return net;
    }

    /** @param type new activation type, one of the {@code *_ACTIVATION} constants */
    public void setActivationType(int type) {
        activationType = type;
    }

    /**
     * Computes the weighted input sum: the sum of each previous-layer neuron's output times
     * its connection weight, plus the bias weight (bias input is the constant 1).
     *
     * <p>Only {@code net} is updated; this method does not set {@code output}.
     * NOTE(review): presumably the caller applies {@link #activationFunction} and
     * {@link #setOutput} afterwards. Confirm against {@code Layer}.
     *
     * @param previousLayer layer whose neuron {@code i} feeds {@code weights[i]}
     */
    public void calculateNet(Layer previousLayer) {
        double sum = 0.0;
        for (int i = 0; i < weights.length; ++i) {
            sum += previousLayer.getNeuron(i).getOutput() * weights[i];
        }
        sum += biasWeight;
        net = sum;
    }

    /**
     * Computes the error term for an output-layer neuron:
     * {@code (desiredOutput - output) * f'(net)}.
     *
     * @param desiredOutput target value for this neuron
     */
    public void calculateError(double desiredOutput) {
        error = (desiredOutput - output) * derivatedActivationFunction();
    }

    /**
     * Computes the error term for a hidden-layer neuron by back-propagating the next layer's
     * errors through the weights that connect this neuron to it:
     * {@code f'(net) * sum_i(next[i].error * next[i].weight[index])}.
     *
     * @param nextLayer the layer this neuron feeds into
     * @param index     this neuron's position in its own layer, i.e. the index of its
     *                  connection weight inside each next-layer neuron
     */
    public void calculateError(Layer nextLayer, int index) {
        double weightedErrorSum = 0.0;
        int neurons = nextLayer.getNoOfNeurons();
        for (int i = 0; i < neurons; ++i) {
            weightedErrorSum += nextLayer.getNeuron(i).getError() * nextLayer.getNeuron(i).getWeight(index);
        }
        error = derivatedActivationFunction() * weightedErrorSum;
    }

    /**
     * Applies the delta rule: each weight moves by {@code learningRate * error * input},
     * where the input is the corresponding previous-layer output (1 for the bias).
     *
     * @param learningRate  step size
     * @param previousLayer layer whose neuron {@code i} feeds {@code weights[i]}
     */
    public void updateWeights(double learningRate, Layer previousLayer) {
        for (int i = 0; i < weights.length; ++i) {
            weights[i] += learningRate * error * previousLayer.getNeuron(i).getOutput();
        }
        biasWeight += learningRate * error; // bias input is the constant 1
    }

    /**
     * Evaluates this neuron's activation function at the given input.
     * The argument is used, not the stored {@code net} field.
     *
     * @param net weighted input sum
     * @return the activated value, or 0 if the activation type is unrecognised
     */
    public double activationFunction(double net) {
        switch (activationType) {
            case LINEAR_ACTIVATION:
                return net;
            case BINARY_SIGMOID_ACTIVATION:
                return 1.0 / (1.0 + Math.exp(-flatness * net));
            case BIPOLARY_SIGMOID_ACTIVATION:
                return 2.0 / (1.0 + Math.exp(-flatness * net)) - 1.0;
            default:
                return 0.0;
        }
    }

    /**
     * Returns the derivative of the activation function, expressed in terms of the stored
     * {@code output}.
     *
     * <p>NOTE(review): this assumes {@code output == activationFunction(net)} has been stored
     * beforehand. Nothing in this class enforces that.
     *
     * @return {@code f'(net)}, or 0 if the activation type is unrecognised
     */
    public double derivatedActivationFunction() {
        switch (activationType) {
            case LINEAR_ACTIVATION:
                return 1.0;
            case BINARY_SIGMOID_ACTIVATION:
                // f = 1 / (1 + e^(-k*net))  =>  f' = k * f * (1 - f)
                return flatness * output * (1.0 - output);
            case BIPOLARY_SIGMOID_ACTIVATION:
                // f = 2 / (1 + e^(-k*net)) - 1 = tanh(k*net/2)  =>  f' = (k/2) * (1 - f^2).
                // The previous k * (1 - f^2) was twice the true derivative of this activation.
                return 0.5 * flatness * (1.0 - output * output);
            default:
                return 0.0;
        }
    }

    /**
     * Returns a multi-line, tab-indented dump of the neuron's state and weights.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("\t\t outputvalue: ").append(output)
                .append(" - error: ").append(error)
                .append(" - net: ").append(net)
                .append(" - bias: ").append(biasWeight)
                .append('\n');
        for (int i = 0; i < weights.length; ++i) {
            sb.append("\t\t\tweight ").append(i).append(": ").append(weights[i]).append('\n');
        }
        return sb.toString();
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -