📄 neuralnetwork.java
字号:
else { m_controlPanel.m_startStop.setEnabled(true); } m_controlPanel.m_startStop.setText("Start"); m_controlPanel.m_startStop.setActionCommand("Start"); m_controlPanel.m_changeEpochs.setEnabled(true); m_controlPanel.m_changeLearning.setEnabled(true); m_controlPanel.m_changeMomentum.setEnabled(true); blocker(true); if (m_numeric) { setEndsToLinear(); } } m_controlPanel.m_changeEpochs.setEnabled(false); m_controlPanel.m_changeLearning.setEnabled(false); m_controlPanel.m_changeMomentum.setEnabled(false); m_stopped = false; //if the network has been accepted stop the training loop if (m_accepted) { m_win.dispose(); m_controlPanel = null; m_nodePanel = null; m_instances = new Instances(m_instances, 0); return; } } if (m_accepted) { m_instances = new Instances(m_instances, 0); return; } } if (m_gui) { m_win.dispose(); m_controlPanel = null; m_nodePanel = null; } m_instances = new Instances(m_instances, 0); } /** * Call this function to predict the class of an instance once a * classification model has been built with the buildClassifier call. * @param i The instance to classify. * @return A double array filled with the probabilities of each class type. * @exception if can't classify instance. */ public double[] distributionForInstance(Instance i) throws Exception { if (m_useNomToBin) { m_nominalToBinaryFilter.input(i); m_currentInstance = m_nominalToBinaryFilter.output(); } else { m_currentInstance = i; } if (m_normalizeAttributes) { for (int noa = 0; noa < m_instances.numAttributes(); noa++) { if (noa != m_instances.classIndex()) { if (m_attributeRanges[noa] != 0) { m_currentInstance.setValue(noa, (m_currentInstance.value(noa) - m_attributeBases[noa]) / m_attributeRanges[noa]); } else { m_currentInstance.setValue(noa, m_currentInstance.value(noa) - m_attributeBases[noa]); } } } } resetNetwork(); //since all the output values are needed. //They are calculated manually here and the values collected. double[] theArray = new double[m_numClasses]; for (int noa = 0; noa < m_numClasses; noa++) { theArray[noa] = m_outputs[noa].outputValue(true); } if (m_instances.classAttribute().isNumeric()) { return theArray; } //now normalize the array double count = 0; for (int noa = 0; noa < m_numClasses; noa++) { count += theArray[noa]; } if (count <= 0) { return null; } for (int noa = 0; noa < m_numClasses; noa++) { theArray[noa] /= count; } return theArray; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ public Enumeration listOptions() { Vector newVector = new Vector(14); newVector.addElement(new Option( "\tLearning Rate for the backpropagation algorithm.\n" +"\t(Value should be between 0 - 1, Default = 0.3).", "L", 1, "-L <learning rate>")); newVector.addElement(new Option( "\tMomentum Rate for the backpropagation algorithm.\n" +"\t(Value should be between 0 - 1, Default = 0.2).", "M", 1, "-M <momentum>")); newVector.addElement(new Option( "\tNumber of epochs to train through.\n" +"\t(Default = 500).", "N", 1,"-N <number of epochs>")); newVector.addElement(new Option( "\tPercentage size of validation set to use to terminate" + " training (if this is non zero it can pre-empt num of epochs.\n" +"\t(Value should be between 0 - 100, Default = 0).", "V", 1, "-V <percentage size of validation set>")); newVector.addElement(new Option( "\tThe value used to seed the random number generator" + "\t(Value should be >= 0 and and a long, Default = 0).", "S", 1, "-S <seed>")); newVector.addElement(new Option( "\tThe consequetive number of errors allowed for validation" + " testing before the netwrok terminates." + "\t(Value should be > 0, Default = 20).", "E", 1, "-E <threshold for number of consequetive errors>")); newVector.addElement(new Option( "\tGUI will be opened.\n" +"\t(Use this to bring up a GUI).", "G", 0,"-G")); newVector.addElement(new Option( "\tAutocreation of the network connections will NOT be done.\n" +"\t(This will be ignored if -G is NOT set)", "A", 0,"-A")); newVector.addElement(new Option( "\tA NominalToBinary filter will NOT automatically be used.\n" +"\t(Set this to not use a NominalToBinary filter).", "B", 0,"-B")); newVector.addElement(new Option( "\tThe hidden layers to be created for the network.\n" +"\t(Value should be a list of comma seperated Natural numbers" + " or the letters 'a' = (attribs + classes) / 2, 'i'" + " = attribs, 'o' = classes, 't' = attribs .+ classes)" + " For wildcard values" + ",Default = a).", "H", 1, "-H <comma seperated numbers for nodes on each layer>")); newVector.addElement(new Option( "\tNormalizing a numeric class will NOT be done.\n" +"\t(Set this to not normalize the class if it's numeric).", "C", 0,"-C")); newVector.addElement(new Option( "\tNormalizing the attributes will NOT be done.\n" +"\t(Set this to not normalize the attributes).", "I", 0,"-I")); newVector.addElement(new Option( "\tReseting the network will NOT be allowed.\n" +"\t(Set this to not allow the network to reset).", "R", 0,"-R")); newVector.addElement(new Option( "\tLearning rate decay will occur.\n" +"\t(Set this to cause the learning rate to decay).", "D", 0,"-D")); return newVector.elements(); } /** * Parses a given list of options. Valid options are:<p> * * -L num <br> * Set the learning rate. * (default 0.3) <p> * * -M num <br> * Set the momentum * (default 0.2) <p> * * -N num <br> * Set the number of epochs to train through. * (default 500) <p> * * -V num <br> * Set the percentage size of the validation set from the training to use. * (default 0 (no validation set is used, instead num of epochs is used) <p> * * -S num <br> * Set the seed for the random number generator. * (default 0) <p> * * -E num <br> * Set the threshold for the number of consequetive errors allowed during * validation testing. * (default 20) <p> * * -G <br> * Bring up a GUI for the neural net. * <p> * * -A <br> * Do not automatically create the connections in the net. * (can only be used if -G is specified) <p> * * -B <br> * Do Not automatically Preprocess the instances with a nominal to binary * filter. <p> * * -H str <br> * Set the number of nodes to be used on each layer. Each number represents * its own layer and the num of nodes on that layer. Each number should be * comma seperated. There are also the wildcards 'a', 'i', 'o', 't' * (default 4) <p> * * -C <br> * Do not automatically Normalize the class if it's numeric. <p> * * -I <br> * Do not automatically Normalize the attributes. <p> * * -R <br> * Do not allow the network to be automatically reset. <p> * * -D <br> * Cause the learning rate to decay as training is done. <p> * * @param options the list of options as an array of strings * @exception Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { //the defaults can be found here!!!! String learningString = Utils.getOption('L', options); if (learningString.length() != 0) { setLearningRate((new Double(learningString)).doubleValue()); } else { setLearningRate(0.3); } String momentumString = Utils.getOption('M', options); if (momentumString.length() != 0) { setMomentum((new Double(momentumString)).doubleValue()); } else { setMomentum(0.2); } String epochsString = Utils.getOption('N', options); if (epochsString.length() != 0) { setTrainingTime(Integer.parseInt(epochsString)); } else { setTrainingTime(500); } String valSizeString = Utils.getOption('V', options); if (valSizeString.length() != 0) { setValidationSetSize(Integer.parseInt(valSizeString)); } else { setValidationSetSize(0); } String seedString = Utils.getOption('S', options); if (seedString.length() != 0) { setRandomSeed(Long.parseLong(seedString)); } else { setRandomSeed(0); } String thresholdString = Utils.getOption('E', options); if (thresholdString.length() != 0) { setValidationThreshold(Integer.parseInt(thresholdString)); } else { setValidationThreshold(20); } String hiddenLayers = Utils.getOption('H', options); if (hiddenLayers.length() != 0) { setHiddenLayers(hiddenLayers); } else { setHiddenLayers("a"); } if (Utils.getFlag('G', options)) { setGUI(true); } else { setGUI(false); } //small note. since the gui is the only option that can change the other //options this should be set first to allow the other options to set //properly if (Utils.getFlag('A', options)) { setAutoBuild(false); } else { setAutoBuild(true); } if (Utils.getFlag('B', options)) { setNominalToBinaryFilter(false); } else { setNominalToBinaryFilter(true); } if (Utils.getFlag('C', options)) { setNormalizeNumericClass(false); } else { setNormalizeNumericClass(true); } if (Utils.getFlag('I', options)) { setNormalizeAttributes(false); } else { setNormalizeAttributes(true); } if (Utils.getFlag('R', options)) { setReset(false); } else { setReset(true); } if (Utils.getFlag('D', options)) { setDecay(true); } else { setDecay(false); } Utils.checkForRemainingOptions(options); } /** * Gets the current settings of NeuralNet. * * @return an array of strings suitable for passing to setOptions() */ public String [] getOptions() { String [] options = new String [21]; int current = 0; options[current++] = "-L"; options[current++] = "" + getLearningRate(); options[current++] = "-M"; options[current++] = "" + getMomentum(); options[current++] = "-N"; options[current++] = "" + getTrainingTime(); options[current++] = "-V"; options[current++] = "" +getValidationSetSize(); options[current++] = "-S"; options[current++] = "" + getRandomSeed(); options[current++] = "-E"; options[current++] =""+getValidationThreshold(); options[current++] = "-H"; options[current++] = getHiddenLayers(); if (getGUI()) { options[current++] = "-G"; } if (!getAutoBuild()) { options[current++] = "-A"; } if (!getNominalToBinaryFilter()) { options[current++] = "-B"; } if (!getNormalizeNumericClass()) { options[current++] = "-C"; } if (!getNormalizeAttributes()) { options[current++] = "-I"; } if (!getReset()) { options[current++] = "-R"; } if (getDecay()) { options[current++] = "-D"; } while (current < options.length) { options[current++] = ""; } return options; } /** * @return string describing the model. */ public String toString() { StringBuffer model = new StringBuffer(m_neuralNodes.length * 100); //just a rough size guess NeuralNode con; double[] weights; NeuralConnection[] inputs; for (int noa = 0; noa < m_neuralNodes.length; noa++) { con = (NeuralNode) m_neuralNodes[noa]; //this would need a change //for items other than nodes!!! weights = con.getWeights(); inputs = con.getInputs(); if (con.getMethod() instanceof SigmoidUnit) { model.append("Sigmoid "); } else if (con.getMethod() instanceof LinearUnit) { model.append("Linear "); } model.append("Node " + con.getId() + "\n Inputs Weights\n"); model.append(" Threshold " + weights[0] + "\n"); for (int nob = 1; nob < con.getNumInputs() + 1; nob++) { if ((inputs[nob - 1].getType() & NeuralConnection.PURE_INPUT) == NeuralConnection.PURE_INPUT) { model.append(" Attrib " + m_instances.attribute(((NeuralEnd)inputs[nob-1]). getLink()).name() + " " + weights[nob] + "\n"); } else { model.append(" Node " + inputs[nob-1].getId() + " " + weights[nob] + "\n"); } } } //now put in the ends for (int noa = 0; noa < m_outputs.length; noa++) { inputs = m_outputs[noa].getInputs(); model.append("Class " + m_instances.classAttribute(). value(m_outputs[noa].getLink()) + "\n Input\n"); for (int nob = 0; nob < m_outputs[noa].getNumInputs(); nob++) { if ((inputs[nob].getType() & NeuralConnection.PURE_INPUT) == NeuralConnection.PURE_INPUT) { model.append(" Attrib " + m_instances.attribute(((NeuralEnd)inputs[nob]). getLink()).name() + "\n"); } else { model.append(" Node " + inputs[nob].getId() + "\n"); } } } return model.toString(); } /** * This will return a string describing the classifier. * @return The string. */ public String globalInfo() { return "This neural network uses backpropagation to train."; } /** * @return a string to describe the learning rate option. */ public String learningRateTipText() { return "The amount the" + " weights are updated."; } /** * @return a string to describe the momentum option. */ public String momentumTipText() { return "Momentum applied to the weights during updating."; } /** * @return a string to describe the AutoBuild op
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -