📄 multilayerperceptron.java
字号:
throw new Exception("Network cannot train. Try restarting with a" +
" smaller learning rate.");
}
else {
//reset the network
m_learningRate /= 2;
buildClassifier(i);
m_learningRate = origRate;
m_instances = new Instances(m_instances, 0);
return;
}
}
////////////////////////do validation testing if applicable
if (m_valSize != 0) {
right = 0;
for (int nob = 0; nob < valSet.numInstances(); nob++) {
m_currentInstance = valSet.instance(nob);
if (!m_currentInstance.classIsMissing()) {
//this is where the network updating occurs, for the validation set
resetNetwork();
calculateOutputs();
right += (calculateErrors() / valSet.numClasses())
* m_currentInstance.weight();
//note 'right' could be calculated here just using
//the calculate output values. This would be faster.
//be less modular
}
}
if (right < lastRight) {
driftOff = 0;
}
else {
driftOff++;
}
lastRight = right;
if (driftOff > m_driftThreshold || noa + 1 >= m_numEpochs) {
m_accepted = true;
}
right /= totalValWeight;
}
m_epoch = noa;
m_error = right;
//shows what the neuralnet is upto if a gui exists.
updateDisplay();
//This junction controls what state the gui is in at the end of each
//epoch, Such as if it is paused, if it is resumable etc...
if (m_gui) {
while ((m_stopIt || (m_epoch >= m_numEpochs && m_valSize == 0)) &&
!m_accepted) {
m_stopIt = true;
m_stopped = true;
if (m_epoch >= m_numEpochs && m_valSize == 0) {
m_controlPanel.m_startStop.setEnabled(false);
}
else {
m_controlPanel.m_startStop.setEnabled(true);
}
m_controlPanel.m_startStop.setText("Start");
m_controlPanel.m_startStop.setActionCommand("Start");
m_controlPanel.m_changeEpochs.setEnabled(true);
m_controlPanel.m_changeLearning.setEnabled(true);
m_controlPanel.m_changeMomentum.setEnabled(true);
blocker(true);
if (m_numeric) {
setEndsToLinear();
}
}
m_controlPanel.m_changeEpochs.setEnabled(false);
m_controlPanel.m_changeLearning.setEnabled(false);
m_controlPanel.m_changeMomentum.setEnabled(false);
m_stopped = false;
//if the network has been accepted stop the training loop
if (m_accepted) {
m_win.dispose();
m_controlPanel = null;
m_nodePanel = null;
m_instances = new Instances(m_instances, 0);
return;
}
}
if (m_accepted) {
m_instances = new Instances(m_instances, 0);
return;
}
}
if (m_gui) {
m_win.dispose();
m_controlPanel = null;
m_nodePanel = null;
}
m_instances = new Instances(m_instances, 0);
}
/**
 * Call this function to predict the class of an instance once a
 * classification model has been built with the buildClassifier call.
 * @param i The instance to classify.
 * @return A double array filled with the probabilities of each class type,
 * or the raw network output if the class is numeric.
 * @exception Exception if the instance can't be classified.
 */
public double[] distributionForInstance(Instance i) throws Exception {
//run the instance through the nominal-to-binary filter if one is in use
m_currentInstance = i;
if (m_useNomToBin) {
m_nominalToBinaryFilter.input(i);
m_currentInstance = m_nominalToBinaryFilter.output();
}
if (m_normalizeAttributes) {
//rescale every non-class attribute the same way the training data was
for (int att = 0; att < m_instances.numAttributes(); att++) {
if (att == m_instances.classIndex()) {
continue;
}
double shifted = m_currentInstance.value(att) - m_attributeBases[att];
if (m_attributeRanges[att] != 0) {
shifted /= m_attributeRanges[att];
}
m_currentInstance.setValue(att, shifted);
}
}
resetNetwork();
//since all the output values are needed, they are calculated manually
//here and the values collected.
double[] dist = new double[m_numClasses];
for (int cls = 0; cls < m_numClasses; cls++) {
dist[cls] = m_outputs[cls].outputValue(true);
}
//a numeric class gets the raw network output back untouched
if (m_instances.classAttribute().isNumeric()) {
return dist;
}
//otherwise normalize the outputs into a probability distribution
double total = 0;
for (int cls = 0; cls < m_numClasses; cls++) {
total += dist[cls];
}
if (total <= 0) {
return null;
}
for (int cls = 0; cls < m_numClasses; cls++) {
dist[cls] /= total;
}
return dist;
}
/**
 * Returns an enumeration describing the available options.
 *
 * @return an enumeration of all the available options.
 */
public Enumeration listOptions() {
//one Option per command line switch this classifier understands
Vector newVector = new Vector(14);
newVector.addElement(new Option(
"\tLearning Rate for the backpropagation algorithm.\n"
+"\t(Value should be between 0 - 1, Default = 0.3).",
"L", 1, "-L <learning rate>"));
newVector.addElement(new Option(
"\tMomentum Rate for the backpropagation algorithm.\n"
+"\t(Value should be between 0 - 1, Default = 0.2).",
"M", 1, "-M <momentum>"));
newVector.addElement(new Option(
"\tNumber of epochs to train through.\n"
+"\t(Default = 500).",
"N", 1,"-N <number of epochs>"));
newVector.addElement(new Option(
"\tPercentage size of validation set to use to terminate" +
" training (if this is non zero it can pre-empt num of epochs).\n"
+"\t(Value should be between 0 - 100, Default = 0).",
"V", 1, "-V <percentage size of validation set>"));
newVector.addElement(new Option(
"\tThe value used to seed the random number generator.\n" +
"\t(Value should be >= 0 and a long, Default = 0).",
"S", 1, "-S <seed>"));
newVector.addElement(new Option(
"\tThe consecutive number of errors allowed for validation" +
" testing before the network terminates.\n" +
"\t(Value should be > 0, Default = 20).",
"E", 1, "-E <threshold for number of consecutive errors>"));
newVector.addElement(new Option(
"\tGUI will be opened.\n"
+"\t(Use this to bring up a GUI).",
"G", 0,"-G"));
newVector.addElement(new Option(
"\tAutocreation of the network connections will NOT be done.\n"
+"\t(This will be ignored if -G is NOT set)",
"A", 0,"-A"));
newVector.addElement(new Option(
"\tA NominalToBinary filter will NOT automatically be used.\n"
+"\t(Set this to not use a NominalToBinary filter).",
"B", 0,"-B"));
newVector.addElement(new Option(
"\tThe hidden layers to be created for the network.\n"
+"\t(Value should be a list of comma separated Natural numbers" +
" or the letters 'a' = (attribs + classes) / 2, 'i'" +
" = attribs, 'o' = classes, 't' = attribs + classes" +
" for wildcard values" +
", Default = a).",
"H", 1, "-H <comma separated numbers for nodes on each layer>"));
newVector.addElement(new Option(
"\tNormalizing a numeric class will NOT be done.\n"
+"\t(Set this to not normalize the class if it's numeric).",
"C", 0,"-C"));
newVector.addElement(new Option(
"\tNormalizing the attributes will NOT be done.\n"
+"\t(Set this to not normalize the attributes).",
"I", 0,"-I"));
newVector.addElement(new Option(
"\tReseting the network will NOT be allowed.\n"
+"\t(Set this to not allow the network to reset).",
"R", 0,"-R"));
newVector.addElement(new Option(
"\tLearning rate decay will occur.\n"
+"\t(Set this to cause the learning rate to decay).",
"D", 0,"-D"));
return newVector.elements();
}
/**
 * Parses a given list of options. Valid options are:<p>
 *
 * -L num <br>
 * Set the learning rate.
 * (default 0.3) <p>
 *
 * -M num <br>
 * Set the momentum
 * (default 0.2) <p>
 *
 * -N num <br>
 * Set the number of epochs to train through.
 * (default 500) <p>
 *
 * -V num <br>
 * Set the percentage size of the validation set from the training to use.
 * (default 0 (no validation set is used, instead num of epochs is used)) <p>
 *
 * -S num <br>
 * Set the seed for the random number generator.
 * (default 0) <p>
 *
 * -E num <br>
 * Set the threshold for the number of consecutive errors allowed during
 * validation testing.
 * (default 20) <p>
 *
 * -G <br>
 * Bring up a GUI for the neural net.
 * <p>
 *
 * -A <br>
 * Do not automatically create the connections in the net.
 * (can only be used if -G is specified) <p>
 *
 * -B <br>
 * Do Not automatically Preprocess the instances with a nominal to binary
 * filter. <p>
 *
 * -H str <br>
 * Set the number of nodes to be used on each layer. Each number represents
 * its own layer and the num of nodes on that layer. Each number should be
 * comma separated. There are also the wildcards 'a', 'i', 'o', 't'
 * (default a) <p>
 *
 * -C <br>
 * Do not automatically Normalize the class if it's numeric. <p>
 *
 * -I <br>
 * Do not automatically Normalize the attributes. <p>
 *
 * -R <br>
 * Do not allow the network to be automatically reset. <p>
 *
 * -D <br>
 * Cause the learning rate to decay as training is done. <p>
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported
 */
public void setOptions(String[] options) throws Exception {
//the defaults can be found here!!!!
//each numeric option falls back to its documented default when absent
String learningString = Utils.getOption('L', options);
if (learningString.length() != 0) {
setLearningRate((new Double(learningString)).doubleValue());
} else {
setLearningRate(0.3);
}
String momentumString = Utils.getOption('M', options);
if (momentumString.length() != 0) {
setMomentum((new Double(momentumString)).doubleValue());
} else {
setMomentum(0.2);
}
String epochsString = Utils.getOption('N', options);
if (epochsString.length() != 0) {
setTrainingTime(Integer.parseInt(epochsString));
} else {
setTrainingTime(500);
}
String valSizeString = Utils.getOption('V', options);
if (valSizeString.length() != 0) {
setValidationSetSize(Integer.parseInt(valSizeString));
} else {
setValidationSetSize(0);
}
String seedString = Utils.getOption('S', options);
if (seedString.length() != 0) {
setRandomSeed(Long.parseLong(seedString));
} else {
setRandomSeed(0);
}
String thresholdString = Utils.getOption('E', options);
if (thresholdString.length() != 0) {
setValidationThreshold(Integer.parseInt(thresholdString));
} else {
setValidationThreshold(20);
}
String hiddenLayers = Utils.getOption('H', options);
if (hiddenLayers.length() != 0) {
setHiddenLayers(hiddenLayers);
} else {
setHiddenLayers("a");
}
if (Utils.getFlag('G', options)) {
setGUI(true);
} else {
setGUI(false);
} //small note. since the gui is the only option that can change the other
//options this should be set first to allow the other options to set
//properly
//NOTE(review): despite the comment above, -G is parsed AFTER the options
//preceding it in this method — confirm whether the ordering matters here.
if (Utils.getFlag('A', options)) {
setAutoBuild(false);
} else {
setAutoBuild(true);
}
if (Utils.getFlag('B', options)) {
setNominalToBinaryFilter(false);
} else {
setNominalToBinaryFilter(true);
}
if (Utils.getFlag('C', options)) {
setNormalizeNumericClass(false);
} else {
setNormalizeNumericClass(true);
}
if (Utils.getFlag('I', options)) {
setNormalizeAttributes(false);
} else {
setNormalizeAttributes(true);
}
if (Utils.getFlag('R', options)) {
setReset(false);
} else {
setReset(true);
}
if (Utils.getFlag('D', options)) {
setDecay(true);
} else {
setDecay(false);
}
//any option left over at this point is unrecognized
Utils.checkForRemainingOptions(options);
}
/**
 * Gets the current settings of NeuralNet.
 *
 * @return an array of strings suitable for passing to setOptions()
 */
public String [] getOptions() {
//7 value options (flag + value each) plus up to 7 boolean flags = 21 slots
String[] options = new String[21];
int pos = 0;
options[pos++] = "-L";
options[pos++] = "" + getLearningRate();
options[pos++] = "-M";
options[pos++] = "" + getMomentum();
options[pos++] = "-N";
options[pos++] = "" + getTrainingTime();
options[pos++] = "-V";
options[pos++] = "" + getValidationSetSize();
options[pos++] = "-S";
options[pos++] = "" + getRandomSeed();
options[pos++] = "-E";
options[pos++] = "" + getValidationThreshold();
options[pos++] = "-H";
options[pos++] = getHiddenLayers();
//boolean flags are emitted only when they differ from the default
if (getGUI()) {
options[pos++] = "-G";
}
if (!getAutoBuild()) {
options[pos++] = "-A";
}
if (!getNominalToBinaryFilter()) {
options[pos++] = "-B";
}
if (!getNormalizeNumericClass()) {
options[pos++] = "-C";
}
if (!getNormalizeAttributes()) {
options[pos++] = "-I";
}
if (!getReset()) {
options[pos++] = "-R";
}
if (getDecay()) {
options[pos++] = "-D";
}
//pad the unused tail of the array with empty strings
for (; pos < options.length; pos++) {
options[pos] = "";
}
return options;
}
/**
* @return string describing the model.
*/
public String toString() {
StringBuffer model = new StringBuffer(m_neuralNodes.length * 100);
//just a rough size guess
NeuralNode con;
double[] weights;
NeuralConnection[] inputs;
for (int noa = 0; noa < m_neuralNodes.len
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -