📄 neuron.java
字号:
weightsDelta[nweights] = 0;
}
// -----------------------------------------------------------------------
// Calculation methods
// -----------------------------------------------------------------------
/**
 * Clears the cached calculation state of this neuron.
 *
 * If at least one of the cached input, output or error values is set,
 * all three are set back to {@code Category.MISSING_VALUE}, the
 * weights-updated flag is cleared and the reset is propagated recursively
 * to every parent (input) node. If all three values are already missing
 * nothing is done, which also ends the recursion at nodes that have
 * already been reset.
 */
public void resetValues() {
  boolean nothingCached = Category.isMissingValue(input)
      && Category.isMissingValue(output)
      && Category.isMissingValue(error);
  if (nothingCached)
    return;

  input  = Category.MISSING_VALUE;
  output = Category.MISSING_VALUE;
  error  = Category.MISSING_VALUE;
  weightsUpdated = false;

  // Propagate the reset towards the input side of the network:
  int idx = 0;
  while (idx < getParentsCount()) {
    NeuralNode parent = (NeuralNode) getParentAt(idx);
    parent.resetValues();
    idx++;
  }
}
/**
 * Calculates the net input of the neuron: the bias weight plus the scalar
 * product of the parent nodes' output values and the corresponding
 * weights. Parent outputs are obtained via {@code outputValue(true)}, so
 * the calculation recurses towards the input side of the network.
 *
 * The result is cached in the {@code input} field. A new value is only
 * computed when {@code newcalc} is true and no value is cached yet;
 * otherwise the cached value (possibly missing) is returned unchanged.
 *
 * @param newcalc calculate the input new if it is not yet available
 * @return input value, missing if not calculated
 * @exception MiningException error while calculating input value
 */
public double inputValue(boolean newcalc) throws MiningException {
  if ( Category.isMissingValue(input) && newcalc ) {
    // Bias value (weights[0] with the constant input 1):
    double sum = weights[0]*1.0;
    // Scalar product, weights[i+1] belongs to parent i:
    for (int i = 0; i < getParentsCount(); i++) {
      double out = ((NeuralNode) getParentAt(i)).outputValue(true);
      sum = sum + weights[i+1]*out;
    }
    // Publish only the complete sum. If a parent throws above, 'input'
    // stays missing instead of caching a partial sum that later calls
    // (and errorValue) would treat as a finished forward pass.
    input = sum;
  }
  return input;
}
/**
 * Calculates the output of the neuron: first the net input is obtained
 * via {@code inputValue(true)}, then the activation function is applied.
 *
 * The result is cached in the {@code output} field. A new value is only
 * computed when {@code newcalc} is true and no output is cached yet;
 * otherwise the cached value (possibly missing) is returned unchanged.
 *
 * @param newcalc calculate the output new if it is not yet available
 * @return output value, missing if not calculated
 * @exception MiningException error while calculating output value
 */
public double outputValue(boolean newcalc) throws MiningException {
  if ( Category.isMissingValue(output) && newcalc ) {
    // Get input value:
    double net = inputValue(true);
    // Apply activation function. Only the final result is stored, so an
    // exception in the activation function cannot leave the raw net input
    // cached as if it were the neuron's output.
    output = activationFunction.function(net);
  }
  return output;
}
/**
 * Calculates the back-propagated error of the neuron.
 *
 * A new value is only computed when {@code newcalc} is true, no error is
 * cached yet and the net input is available (a forward pass has been
 * done). The error is then the sum, over all child nodes, of the child's
 * error (obtained recursively via {@code errorValue(true)}) times the
 * weight returned by {@code child.getWeightAt(child.getParentIndex(this))}.
 * That sum is multiplied by the derivative of the activation function at
 * the cached net input. Otherwise the cached value is returned unchanged.
 *
 * @param newcalc calculate the error new if it is not yet available
 * @return error value, missing if not calculated
 * @exception MiningException error while calculating error value
 */
public double errorValue(boolean newcalc) throws MiningException {
  boolean mustCompute = Category.isMissingValue(error)
      && !Category.isMissingValue(input)
      && newcalc;
  if (!mustCompute)
    return error;

  // Error back propagation, accumulated directly in the 'error' field:
  error = 0.0;
  for (int c = 0; c < getChildCount(); c++) {
    NeuralNode child = (NeuralNode) getChildAt(c);
    // NOTE(review): in inputValue/updateWeights weights[0] is the bias, so
    // confirm that getWeightAt() maps a parent index to weights[index+1].
    double connWeight = child.getWeightAt( child.getParentIndex(this) );
    error = error + child.errorValue(true)*connWeight;
  }

  // Multiply with the derivative of the activation function at the net input:
  double netInput = inputValue(false);
  error = error*activationFunction.derivation(netInput);
  return error;
}
/**
 * Updates the weights of this neuron by gradient descent with momentum,
 * then calls {@code super.updateWeights} to continue the update at the
 * input nodes.
 *
 * For every weight w_k the applied change is
 * {@code delta_k = learningRate * error * x_k + momentum * previousDelta_k},
 * where x_0 = 1 (bias) and x_k (k >= 1) is the cached output of parent k-1.
 * Each delta is stored in {@code weightsDelta} for the momentum term of
 * the next update. Nothing happens if the weights are already flagged as
 * updated or the error is NaN.
 *
 * NOTE(review): this method does not set {@code weightsUpdated} itself;
 * presumably {@code super.updateWeights} does - confirm.
 *
 * @param learningRate the learning rate
 * @param momentum the momentum
 * @exception MiningException cannot update weights
 */
public void updateWeights(double learningRate, double momentum)
    throws MiningException {
  if (weightsUpdated || Double.isNaN(error))
    return;

  // Learning rate times error, shared by all weights:
  double rateTimesError = learningRate*errorValue(false);

  // Bias weight (constant input 1):
  double biasDelta = rateTimesError*1.0 + momentum*weightsDelta[0];
  weights[0] += biasDelta;
  weightsDelta[0] = biasDelta;

  // Connection weights, weights[p+1] belongs to parent p:
  for (int p = 0; p < getParentsCount(); p++) {
    int w = p + 1;
    double parentOut = ((NeuralNode) getParentAt(p)).outputValue(false);
    double delta = rateTimesError*parentOut + momentum*weightsDelta[w];
    weights[w] += delta;
    weightsDelta[w] = delta;
  }

  // Recursive call to input nodes:
  super.updateWeights(learningRate, momentum);
}
// -----------------------------------------------------------------------
// Methods of PMML handling
// -----------------------------------------------------------------------
/**
 * Converts this neuron into a PMML 2.0 {@code Neuron} element.
 *
 * The element receives the following:
 * <ul>
 * <li>the neuron ID;</li>
 * <li>one {@code Conn} per parent node, carrying the connection weight
 *     {@code weights[i+1]} when the weight vector has more entries than
 *     there are parents (otherwise a warning is printed);</li>
 * <li>the activation function name, when this neuron has one and its layer
 *     defines none;</li>
 * <li>the bias {@code weights[0]}, when a bias is used;</li>
 * <li>the threshold, when it is not missing.</li>
 * </ul>
 *
 * @return PMML element of neuron
 * @exception MiningException error while creating the PMML element
 */
public Object createPmmlObject() throws MiningException
{
  com.prudsys.pdm.Adapters.PmmlVersion20.Neuron pmmlNeuron =
      new com.prudsys.pdm.Adapters.PmmlVersion20.Neuron();
  pmmlNeuron.setId(id);

  // Connections to the parent (input) neurons:
  int parentCount = getParentsCount();
  if (parentCount > 0) {
    com.prudsys.pdm.Adapters.PmmlVersion20.Conn[] connections =
        new com.prudsys.pdm.Adapters.PmmlVersion20.Conn[parentCount];
    // weights[0] is the bias, so a complete vector has parentCount+1 entries:
    boolean withWeights = getNumberOfWeights() > parentCount;
    if (!withWeights)
      System.out.println("Warning! Neuron #ID: " + id + " without valid weight vector");
    for (int k = 0; k < parentCount; k++) {
      com.prudsys.pdm.Adapters.PmmlVersion20.Conn connection =
          new com.prudsys.pdm.Adapters.PmmlVersion20.Conn();
      connection.setFrom( ((NeuralNode) getParentAt(k)).getId() );
      if (withWeights)
        connection.setWeight( String.valueOf( weights[k+1] ) );
      connections[k] = connection;
    }
    pmmlNeuron.setConn(connections);
  }

  // Activation function, only when the layer does not define one:
  if (activationFunction != null && neuralLayer.getActivationFunction() == null) {
    String pmmlName =
        ActivationFunction.convertTypeToPmml( activationFunction.getFunctionType() );
    // Fall back to the internal type name if there is no PMML equivalent:
    pmmlNeuron.setActivationFunction( pmmlName != null
        ? pmmlName
        : activationFunction.getFunctionType() );
  }

  // Bias:
  if (useBias && weights != null && weights.length > 0)
    pmmlNeuron.setBias( String.valueOf( weights[0] ) );

  // Threshold:
  if ( !Category.isMissingValue(threshold) )
    pmmlNeuron.setThreshold( String.valueOf(threshold) );

  return pmmlNeuron;
}
/**
 * Reads the neuron from a PMML 2.0 {@code Neuron} element.
 *
 * Sets the neuron ID and the IDs of the input neurons ({@code inputIDs}).
 * It also allocates a new weight vector of length (number of connections
 * + 1), where {@code weights[0]} holds the bias and {@code weights[i+1]}
 * the weight of connection i.
 *
 * Unparsable weight or bias values print a warning and stay 0.0. An
 * unparsable threshold prints a warning and keeps its previous value. The
 * activation function is replaced only when the element names one and
 * {@code ActivationFunction.getInstance} resolves it. Names without a
 * known PMML mapping are passed through unchanged.
 *
 * @param pmmlObject PMML element to read in (a PMML 2.0 Neuron)
 * @exception MiningException if the neuron has no input connections
 */
public void parsePmmlObject( Object pmmlObject ) throws MiningException
{
  // Get neuron:
  com.prudsys.pdm.Adapters.PmmlVersion20.Neuron neuron =
      (com.prudsys.pdm.Adapters.PmmlVersion20.Neuron) pmmlObject;
  // Get neural input ID:
  id = neuron.getId();
  // Get input neurons:
  com.prudsys.pdm.Adapters.PmmlVersion20.Conn[] inputs = neuron.getConn();
  if (inputs == null || inputs.length == 0)
    throw new MiningException("neuron: " + id + " has no inputs");
  int ninp = inputs.length;
  inputIDs = new String[ninp];
  // NOTE(review): weightsDelta is not resized here; confirm it is
  // (re)allocated to weights.length before updateWeights() is used.
  weights = new double[ninp+1];
  for (int i = 0; i < ninp; i++) {
    inputIDs[i] = inputs[i].getFrom();
    try {
      weights[i+1] = Double.parseDouble( inputs[i].getWeight() );
    }
    catch (Exception ex) {
      System.out.println("Warning! Invalid weight value of neuron #ID: " + id);
    }
  }
  // Get activation function:
  String afname = neuron.getActivationFunction();
  if (afname != null) {
    String aftype = ActivationFunction.convertPmmlToType(afname);
    // No PMML mapping found: try the raw name, it may be an internal type.
    // (Testing afname here would always be false inside this branch.)
    if (aftype == null)
      aftype = afname;
    ActivationFunction af = ActivationFunction.getInstance(aftype);
    if (af != null)
      activationFunction = af;
  }
  // Get bias:
  useBias = false;
  String bias = neuron.getBias();
  if (bias != null) {
    try {
      weights[0] = Double.parseDouble(bias);
      useBias = true;
    }
    catch (Exception ex) {
      System.out.println("Warning! Invalid bias value of neuron #ID: " + id);
    }
  }
  // Get threshold:
  String thresh = neuron.getThreshold();
  if (thresh != null) {
    try {
      threshold = Double.parseDouble(thresh);
    }
    catch (Exception ex) {
      System.out.println("Warning! Invalid threshold value of neuron #ID: " + id);
    }
  }
}
// -----------------------------------------------------------------------
// Further methods
// -----------------------------------------------------------------------
/**
 * Returns a textual description of the neuron.
 *
 * The text is the prefix {@code "Neuron, "} followed by the superclass
 * description. When a weight vector exists, it is followed by the bias
 * (only if a bias is used) and by all weights, each followed by a blank.
 *
 * @return string representation of neuron
 */
public String toString() {
  StringBuffer text = new StringBuffer("Neuron, ");
  text.append(super.toString());
  if (weights == null)
    return text.toString();

  if (useBias)
    text.append(", bias = ").append(weights[0]);
  text.append(" weights: ");
  for (int k = 0; k < weights.length; k++)
    text.append(weights[k]).append(' ');
  return text.toString();
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -