📄 multilayerperceptron.java
字号:
JOptionPane.YES_NO_OPTION); if (well == 0) { m_win.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE); m_accepted = true; blocker(false); } else { m_win.setDefaultCloseOperation(JFrame.DO_NOTHING_ON_CLOSE); } m_stopIt = k; } }); m_win.getContentPane().setLayout(new BorderLayout()); m_win.setTitle("Neural Network"); m_nodePanel = new NodePanel(); JScrollPane sp = new JScrollPane(m_nodePanel, JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, JScrollPane.HORIZONTAL_SCROLLBAR_NEVER); m_controlPanel = new ControlPanel(); m_win.getContentPane().add(sp, BorderLayout.CENTER); m_win.getContentPane().add(m_controlPanel, BorderLayout.SOUTH); m_win.setSize(640, 480); m_win.setVisible(true); } //This sets up the initial state of the gui if (m_gui) { blocker(true); m_controlPanel.m_changeEpochs.setEnabled(false); m_controlPanel.m_changeLearning.setEnabled(false); m_controlPanel.m_changeMomentum.setEnabled(false); } //For silly situations in which the network gets accepted before training //commenses if (m_numeric) { setEndsToLinear(); } if (m_accepted) { m_win.dispose(); m_controlPanel = null; m_nodePanel = null; m_instances = new Instances(m_instances, 0); return; } //connections done. double right = 0; double driftOff = 0; double lastRight = Double.POSITIVE_INFINITY; double tempRate; double totalWeight = 0; double totalValWeight = 0; double origRate = m_learningRate; //only used for when reset //ensure that at least 1 instance is trained through. 
if (numInVal == m_instances.numInstances()) { numInVal--; } if (numInVal < 0) { numInVal = 0; } for (int noa = numInVal; noa < m_instances.numInstances(); noa++) { if (!m_instances.instance(noa).classIsMissing()) { totalWeight += m_instances.instance(noa).weight(); } } if (m_valSize != 0) { for (int noa = 0; noa < valSet.numInstances(); noa++) { if (!valSet.instance(noa).classIsMissing()) { totalValWeight += valSet.instance(noa).weight(); } } } m_stopped = false; for (int noa = 1; noa < m_numEpochs + 1; noa++) { right = 0; for (int nob = numInVal; nob < m_instances.numInstances(); nob++) { m_currentInstance = m_instances.instance(nob); if (!m_currentInstance.classIsMissing()) { //this is where the network updating (and training occurs, for the //training set resetNetwork(); calculateOutputs(); tempRate = m_learningRate * m_currentInstance.weight(); if (m_decay) { tempRate /= noa; } right += (calculateErrors() / m_instances.numClasses()) * m_currentInstance.weight(); updateNetworkWeights(tempRate, m_momentum); } } right /= totalWeight; if (Double.isInfinite(right) || Double.isNaN(right)) { if (!m_reset) { m_instances = null; throw new Exception("Network cannot train. Try restarting with a" + " smaller learning rate."); } else { //reset the network m_learningRate /= 2; buildClassifier(i); m_learningRate = origRate; m_instances = new Instances(m_instances, 0); return; } } ////////////////////////do validation testing if applicable if (m_valSize != 0) { right = 0; for (int nob = 0; nob < valSet.numInstances(); nob++) { m_currentInstance = valSet.instance(nob); if (!m_currentInstance.classIsMissing()) { //this is where the network updating occurs, for the validation set resetNetwork(); calculateOutputs(); right += (calculateErrors() / valSet.numClasses()) * m_currentInstance.weight(); //note 'right' could be calculated here just using //the calculate output values. This would be faster. 
//be less modular } } if (right < lastRight) { driftOff = 0; } else { driftOff++; } lastRight = right; if (driftOff > m_driftThreshold || noa + 1 >= m_numEpochs) { m_accepted = true; } right /= totalValWeight; } m_epoch = noa; m_error = right; //shows what the neuralnet is upto if a gui exists. updateDisplay(); //This junction controls what state the gui is in at the end of each //epoch, Such as if it is paused, if it is resumable etc... if (m_gui) { while ((m_stopIt || (m_epoch >= m_numEpochs && m_valSize == 0)) && !m_accepted) { m_stopIt = true; m_stopped = true; if (m_epoch >= m_numEpochs && m_valSize == 0) { m_controlPanel.m_startStop.setEnabled(false); } else { m_controlPanel.m_startStop.setEnabled(true); } m_controlPanel.m_startStop.setText("Start"); m_controlPanel.m_startStop.setActionCommand("Start"); m_controlPanel.m_changeEpochs.setEnabled(true); m_controlPanel.m_changeLearning.setEnabled(true); m_controlPanel.m_changeMomentum.setEnabled(true); blocker(true); if (m_numeric) { setEndsToLinear(); } } m_controlPanel.m_changeEpochs.setEnabled(false); m_controlPanel.m_changeLearning.setEnabled(false); m_controlPanel.m_changeMomentum.setEnabled(false); m_stopped = false; //if the network has been accepted stop the training loop if (m_accepted) { m_win.dispose(); m_controlPanel = null; m_nodePanel = null; m_instances = new Instances(m_instances, 0); return; } } if (m_accepted) { m_instances = new Instances(m_instances, 0); return; } } if (m_gui) { m_win.dispose(); m_controlPanel = null; m_nodePanel = null; } m_instances = new Instances(m_instances, 0); } /** * Call this function to predict the class of an instance once a * classification model has been built with the buildClassifier call. * @param i The instance to classify. * @return A double array filled with the probabilities of each class type. * @throws if can't classify instance. 
*/ public double[] distributionForInstance(Instance i) throws Exception { if (m_useNomToBin) { m_nominalToBinaryFilter.input(i); m_currentInstance = m_nominalToBinaryFilter.output(); } else { m_currentInstance = i; } if (m_normalizeAttributes) { for (int noa = 0; noa < m_instances.numAttributes(); noa++) { if (noa != m_instances.classIndex()) { if (m_attributeRanges[noa] != 0) { m_currentInstance.setValue(noa, (m_currentInstance.value(noa) - m_attributeBases[noa]) / m_attributeRanges[noa]); } else { m_currentInstance.setValue(noa, m_currentInstance.value(noa) - m_attributeBases[noa]); } } } } resetNetwork(); //since all the output values are needed. //They are calculated manually here and the values collected. double[] theArray = new double[m_numClasses]; for (int noa = 0; noa < m_numClasses; noa++) { theArray[noa] = m_outputs[noa].outputValue(true); } if (m_instances.classAttribute().isNumeric()) { return theArray; } //now normalize the array double count = 0; for (int noa = 0; noa < m_numClasses; noa++) { count += theArray[noa]; } if (count <= 0) { return null; } for (int noa = 0; noa < m_numClasses; noa++) { theArray[noa] /= count; } return theArray; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. 
*/
  public Enumeration listOptions() {
    // Help strings are user-visible output and are reproduced verbatim,
    // including their original spelling and line breaks.
    Vector options = new Vector(14);

    String learningRateHelp =
        "\tLearning Rate for the backpropagation algorithm.\n"
        + "\t(Value should be between 0 - 1, Default = 0.3).";
    options.add(new Option(learningRateHelp, "L", 1, "-L <learning rate>"));

    String momentumHelp =
        "\tMomentum Rate for the backpropagation algorithm.\n"
        + "\t(Value should be between 0 - 1, Default = 0.2).";
    options.add(new Option(momentumHelp, "M", 1, "-M <momentum>"));

    String epochsHelp =
        "\tNumber of epochs to train through.\n"
        + "\t(Default = 500).";
    options.add(new Option(epochsHelp, "N", 1, "-N <number of epochs>"));

    String validationSizeHelp =
        "\tPercentage size of validation set to use to terminate"
        + " training (if this is non zero it can pre-empt num of epochs.\n"
        + "\t(Value should be between 0 - 100, Default = 0).";
    options.add(new Option(validationSizeHelp, "V", 1,
        "-V <percentage size of validation set>"));

    String seedHelp =
        "\tThe value used to seed the random number generator"
        + "\t(Value should be >= 0 and and a long, Default = 0).";
    options.add(new Option(seedHelp, "S", 1, "-S <seed>"));

    String thresholdHelp =
        "\tThe consequetive number of errors allowed for validation"
        + " testing before the netwrok terminates."
        + "\t(Value should be > 0, Default = 20).";
    options.add(new Option(thresholdHelp, "E", 1,
        "-E <threshold for number of consequetive errors>"));

    String guiHelp =
        "\tGUI will be opened.\n"
        + "\t(Use this to bring up a GUI).";
    options.add(new Option(guiHelp, "G", 0, "-G"));

    String autoBuildHelp =
        "\tAutocreation of the network connections will NOT be done.\n"
        + "\t(This will be ignored if -G is NOT set)";
    options.add(new Option(autoBuildHelp, "A", 0, "-A"));

    String nominalToBinaryHelp =
        "\tA NominalToBinary filter will NOT automatically be used.\n"
        + "\t(Set this to not use a NominalToBinary filter).";
    options.add(new Option(nominalToBinaryHelp, "B", 0, "-B"));

    String hiddenLayersHelp =
        "\tThe hidden layers to be created for the network.\n"
        + "\t(Value should be a list of comma seperated Natural numbers"
        + " or the letters 'a' = (attribs + classes) / 2, 'i'"
        + " = attribs, 'o' = classes, 't' = attribs .+ classes)"
        + " For wildcard values"
        + ",Default = a).";
    options.add(new Option(hiddenLayersHelp, "H", 1,
        "-H <comma seperated numbers for nodes on each layer>"));

    String normalizeClassHelp =
        "\tNormalizing a numeric class will NOT be done.\n"
        + "\t(Set this to not normalize the class if it's numeric).";
    options.add(new Option(normalizeClassHelp, "C", 0, "-C"));

    String normalizeAttributesHelp =
        "\tNormalizing the attributes will NOT be done.\n"
        + "\t(Set this to not normalize the attributes).";
    options.add(new Option(normalizeAttributesHelp, "I", 0, "-I"));

    String resetHelp =
        "\tReseting the network will NOT be allowed.\n"
        + "\t(Set this to not allow the network to reset).";
    options.add(new Option(resetHelp, "R", 0, "-R"));

    String decayHelp =
        "\tLearning rate decay will occur.\n"
        + "\t(Set this to cause the learning rate to decay).";
    options.add(new Option(decayHelp, "D", 0, "-D"));

    return options.elements();
  }

  /**
   * Parses a given list of options. <p/>
   * <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -L <learning rate>
   *  Learning Rate for the backpropagation algorithm.
   *  (Value should be between 0 - 1, Default = 0.3).</pre>
   *
   * <pre> -M <momentum>
   *  Momentum Rate for the backpropagation algorithm.
   *  (Value should be between 0 - 1, Default = 0.2).</pre>
   *
   * <pre> -N <number of epochs>
   *  Number of epochs to train through.
* (Default = 500).</pre> * * <pre> -V <percentage size of validation set> * Percentage size of validation set to use to terminate training (if this is non zero it can pre-empt num of epochs. * (Value should be between 0 - 100, Default = 0).</pre> * * <pre> -S <seed> * The value used to seed the random number generator (Value should be >= 0 and and a long, Default = 0).</pre> * * <pre> -E <threshold for number of consequetive errors> * The consequetive number of errors allowed for validation testing before the netwrok terminates. (Value should be > 0, Default = 20).</pre> * * <pre> -G * GUI will be opened. * (Use this to bring up a GUI).</pre> * * <pre> -A * Autocreation of the network connections will NOT be done. * (This will be ignored if -G is NOT set)</pre> * * <pre> -B * A NominalToBinary filter will NOT automatically be used. * (Set this to not use a NominalToBinary filter).</pre> * * <pre> -H <comma seperated numbers for nodes on each layer> * The hidden layers to be created for the network. * (Value should be a list of comma seperated Natural numbers or the letters 'a' = (attribs + classes) / 2, 'i' = attribs, 'o' = classes, 't' = attribs .+ classes) For wildcard values,Default = a).</pre> * * <pre> -C * Normalizing a numeric class will NOT be done. * (Set this to not normalize the class if it's numeric).</pre> * * <pre> -I * Normalizing the attributes will NOT be done. * (Set this to not normalize the attributes).</pre> * * <pre> -R * Reseting the network will NOT be allowed. * (Set this to not allow the network to reset).</pre> * * <pre> -D * Learning rate decay will occur. * (Set this to cause the learning rate to decay).</pre> * <!-- options-end --> * * @param options the list of options as an array of strings * @throws Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { //the defaults can be found here!!!! 
String learningString = Utils.getOption('L', options); if (learningString.length() != 0) { setLearningRate((new Double(learningString)).doubleValue()); } else { setLearningRate(0.3); } String momentumString = Utils.getOption('M', options); if (momentumString.length() != 0) { setMomentum((new Double(momentumString)).doubleValue()); } else { setMomentum(0.2); } String epochsString = Utils.getOption('N', options); if (epochsString.length() != 0) { setTrainingTime(Integer.parseInt(epochsString)); } else { setTrainingTime(500); } String valSizeString = Utils.getOption('V', options); if (valSizeString.length() != 0) { setValidationSetSize(Integer.parseInt(valSizeString)); } else { setValidationSetSize(0); } String seedString = Utils.getOption('S', options); if (seedString.length() != 0) { setRandomSeed(Long.parseLong(seedString)); } else { setRandomSeed(0); } String thresholdString = Utils.getOption('E', options); if (thresholdString.length() !=
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -