📄 testlearningrule.java
字号:
package net.openai.ai.nn.learning;import java.util.*;import net.openai.ai.nn.network.*;import net.openai.ai.nn.training.*;public class TestLearningRule extends LearningRule { private boolean isOutputLayer = false; // the learning rate private double alpha = 1; // the momentum term private double beta = .5; public TestLearningRule() { } public final void correctLayer(Layer layer, TrainingElement trainingElement) { //db("processing layer: " + layer.toString()); Vector neurons = layer.getNeurons(); if(neurons.isEmpty()) { db("layer has no neurons...cannot apply learning rule."); return; } calculateError(neurons, trainingElement); adjustWeights(neurons); } private final void calculateError(Vector neurons, TrainingElement trainingElement) { if(isOutputLayer) { Vector desiredOutput = trainingElement.getDesired(); int size = neurons.size(); for(int i = 0; i < size; i++) { Neuron neuron = (Neuron) neurons.elementAt(i); String desiredString = (String) desiredOutput.elementAt(i); double desired = 0; try { desired = Double.parseDouble(desiredString); } catch (NumberFormatException nfe) { db("problem parsing desired output..."); return; } double output = neuron.getOutput(); //db("desired output was: " + desired); //db("actual output was: " + output); double difference = desired - output; //db("difference was: " + difference); double error = 2*difference; //db("Error for this output was: " + error); neuron.setError(error); } } else { int size = neurons.size(); for(int i = 0; i < size; i++) { Neuron neuron = (Neuron) neurons.elementAt(i); double output = neuron.getOutput(); Vector downstreamConnections = neuron.getConnectionsFrom(); double weightedError = 0; int connectionsSize = downstreamConnections.size(); for(int j = 0; j < connectionsSize; j++) { Connection downstreamConnection = (Connection) downstreamConnections.elementAt(j); Weight weight = downstreamConnection.getWeight(); double weightValue = weight.getValue(); Neuron downstreamNeuron = downstreamConnection.getToNeuron(); double downstreamError = downstreamNeuron.getError(); weightedError += downstreamError*weightValue; } double error = output*(1 - output)*weightedError; //db("Error for this hidden neuron was: " + error); neuron.setError(error); } } } private final void adjustWeights(Vector neurons) { int size = neurons.size(); for(int i = 0; i < size; i++) { Neuron neuron = (Neuron) neurons.elementAt(i); Vector upstreamConnections = neuron.getConnectionsTo(); int connectionsSize = upstreamConnections.size(); for(int j = 0; j < connectionsSize; j++) { Connection connection = (Connection) upstreamConnections.elementAt(j); Neuron fromNeuron = connection.getFromNeuron(); double output = neuron.getOutput(); double upstreamOutput = fromNeuron.getOutput(); //db("------------------------------------------"); //db("neuron error: " + neuron.getError()); //db("neuron output: " + output); //db("upstream output: " + upstreamOutput); //db("weight delta: " + weightDelta); Weight weight = connection.getWeight(); double weightDelta = ((-1)*alpha*neuron.getError() *output*(1 - output)*upstreamOutput + beta*weight.getLastDelta()); double oldWeight = weight.getValue(); //db("old weight: " + oldWeight); double newWeight = oldWeight - weightDelta; weight.setValue(newWeight); //db("new weight: " + newWeight); } } } // to do this...we need to identify that the error for the layer // directly downstream has been calculated, and if there is no // layer downstream...then calculate this layer as the output // layer. public final boolean ready(Layer layer) { Vector neurons = layer.getNeurons(); if(neurons.isEmpty()) { db("layer has no neurons...cannot apply learning rule."); return false; } Neuron neuron = (Neuron) neurons.elementAt(0); if(neuron.isErrorSet()) return false; Vector connectionsFrom = neuron.getConnectionsFrom(); // if there are no connections from this layer, assume it's // an output layer... if(connectionsFrom.isEmpty()) { isOutputLayer = true; db("no connections from this layer, looks like an output layer."); return true; } // get a connection to a downstream neuron Connection connection = (Connection) connectionsFrom.elementAt(0); // grab one of the downstream neurons Neuron downstreamNeuron = connection.getToNeuron(); if(downstreamNeuron == null) { db("downstream neuron was null...could not process"); return false; } // if the downstream neuron has had it's error calculated // then we're ready to process this layer if(downstreamNeuron.isErrorSet()) return true; return false; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -