📄 neuralnetwork.java
字号:
*/ private void calculateOutputs() { for (int noc = 0; noc < m_numClasses; noc++) { //get the values. m_outputs[noc].outputValue(true); } } /** * This will cause the error values to be calculated for all nodes. * Note that the m_currentInstance is used to calculate these values. * Also the output values should have been calculated first. * @return The squared error. */ private double calculateErrors() throws Exception { double ret = 0, temp = 0; for (int noc = 0; noc < m_numAttributes; noc++) { //get the errors. m_inputs[noc].errorValue(true); } for (int noc = 0; noc < m_numClasses; noc++) { temp = m_outputs[noc].errorValue(false); ret += temp * temp; } return ret; } /** * This will cause the weight values to be updated based on the learning * rate, momentum and the errors that have been calculated for each node. * @param l The learning rate to update with. * @param m The momentum to update with. */ private void updateNetworkWeights(double l, double m) { for (int noc = 0; noc < m_numClasses; noc++) { //update weights m_outputs[noc].updateWeights(l, m); } } /** * This creates the required input units. */ private void setupInputs() throws Exception { m_inputs = new NeuralEnd[m_numAttributes]; int now = 0; for (int noa = 0; noa < m_numAttributes+1; noa++) { if (m_instances.classIndex() != noa) { m_inputs[noa - now] = new NeuralEnd(m_instances.attribute(noa).name()); m_inputs[noa - now].setX(.1); m_inputs[noa - now].setY((noa - now + 1.0) / (m_numAttributes + 1)); m_inputs[noa - now].setLink(true, noa); } else { now = 1; } } } /** * This creates the required output units. */ private void setupOutputs() throws Exception { m_outputs = new NeuralEnd[m_numClasses]; for (int noa = 0; noa < m_numClasses; noa++) { if (m_numeric) { m_outputs[noa] = new NeuralEnd(m_instances.classAttribute().name()); } else { m_outputs[noa]= new NeuralEnd(m_instances.classAttribute().value(noa)); } m_outputs[noa].setX(.9); m_outputs[noa].setY((noa + 1.0) / (m_numClasses + 1)); m_outputs[noa].setLink(false, noa); NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random, m_sigmoidUnit); m_nextId++; temp.setX(.75); temp.setY((noa + 1.0) / (m_numClasses + 1)); addNode(temp); NeuralConnection.connect(temp, m_outputs[noa]); } } /** * Call this function to automatically generate the hidden units */ private void setupHiddenLayer() { StringTokenizer tok = new StringTokenizer(m_hiddenLayers, ","); int val = 0; //num of nodes in a layer int prev = 0; //used to remember the previous layer int num = tok.countTokens(); //number of layers String c; for (int noa = 0; noa < num; noa++) { //note that I am using the Double to get the value rather than the //Integer class, because for some reason the Double implementation can //handle leading white space and the integer version can't!?! c = tok.nextToken().trim(); if (c.equals("a")) { val = (m_numAttributes + m_numClasses) / 2; } else if (c.equals("i")) { val = m_numAttributes; } else if (c.equals("o")) { val = m_numClasses; } else if (c.equals("t")) { val = m_numAttributes + m_numClasses; } else { val = Double.valueOf(c).intValue(); } for (int nob = 0; nob < val; nob++) { NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random, m_sigmoidUnit); m_nextId++; temp.setX(.5 / (num) * noa + .25); temp.setY((nob + 1.0) / (val + 1)); addNode(temp); if (noa > 0) { //then do connections for (int noc = m_neuralNodes.length - nob - 1 - prev; noc < m_neuralNodes.length - nob - 1; noc++) { NeuralConnection.connect(m_neuralNodes[noc], temp); } } } prev = val; } tok = new StringTokenizer(m_hiddenLayers, ","); c = tok.nextToken(); if (c.equals("a")) { val = (m_numAttributes + m_numClasses) / 2; } else if (c.equals("i")) { val = m_numAttributes; } else if (c.equals("o")) { val = m_numClasses; } else if (c.equals("t")) { val = m_numAttributes + m_numClasses; } else { val = Double.valueOf(c).intValue(); } if (val == 0) { for (int noa = 0; noa < m_numAttributes; noa++) { for (int nob = 0; nob < m_numClasses; nob++) { NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]); } } } else { for (int noa = 0; noa < m_numAttributes; noa++) { for (int nob = m_numClasses; nob < m_numClasses + val; nob++) { NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]); } } for (int noa = m_neuralNodes.length - prev; noa < m_neuralNodes.length; noa++) { for (int nob = 0; nob < m_numClasses; nob++) { NeuralConnection.connect(m_neuralNodes[noa], m_neuralNodes[nob]); } } } } /** * This will go through all the nodes and check if they are connected * to a pure output unit. If so they will be set to be linear units. * If not they will be set to be sigmoid units. */ private void setEndsToLinear() { for (int noa = 0; noa < m_neuralNodes.length; noa++) { if ((m_neuralNodes[noa].getType() & NeuralConnection.OUTPUT) == NeuralConnection.OUTPUT) { ((NeuralNode)m_neuralNodes[noa]).setMethod(m_linearUnit); } else { ((NeuralNode)m_neuralNodes[noa]).setMethod(m_sigmoidUnit); } } } /** * Call this function to build and train a neural network for the training * data provided. * @param i The training data. * @exception Throws exception if can't build classification properly. */ public void buildClassifier(Instances i) throws Exception { if (i.checkForStringAttributes()) { throw new UnsupportedAttributeTypeException("Cannot handle string attributes!"); } if (i.numInstances() == 0) { throw new IllegalArgumentException("No training instances."); } m_epoch = 0; m_error = 0; m_instances = null; m_currentInstance = null; m_controlPanel = null; m_nodePanel = null; m_outputs = new NeuralEnd[0]; m_inputs = new NeuralEnd[0]; m_numAttributes = 0; m_numClasses = 0; m_neuralNodes = new NeuralConnection[0]; m_selected = new FastVector(4); m_graphers = new FastVector(2); m_nextId = 0; m_stopIt = true; m_stopped = true; m_accepted = false; m_instances = new Instances(i); m_instances.deleteWithMissingClass(); if (m_instances.numInstances() == 0) { m_instances = null; throw new IllegalArgumentException("All class values missing."); } m_random = new Random(m_randomSeed); m_instances.randomize(m_random); if (m_useNomToBin) { m_nominalToBinaryFilter = new NominalToBinary(); m_nominalToBinaryFilter.setInputFormat(m_instances); m_instances = Filter.useFilter(m_instances, m_nominalToBinaryFilter); } m_numAttributes = m_instances.numAttributes() - 1; m_numClasses = m_instances.numClasses(); setClassType(m_instances); //this sets up the validation set. Instances valSet = null; //numinval is needed later int numInVal = (int)(m_valSize / 100.0 * m_instances.numInstances()); if (m_valSize > 0) { if (numInVal == 0) { numInVal = 1; } valSet = new Instances(m_instances, 0, numInVal); } /////////// setupInputs(); setupOutputs(); if (m_autoBuild) { setupHiddenLayer(); } ///////////////////////////// //this sets up the gui for usage if (m_gui) { m_win = new JFrame(); m_win.addWindowListener(new WindowAdapter() { public void windowClosing(WindowEvent e) { boolean k = m_stopIt; m_stopIt = true; int well =JOptionPane.showConfirmDialog(m_win, "Are You Sure...\n" + "Click Yes To Accept" + " The Neural Network" + "\n Click No To Return", "Accept Neural Network", JOptionPane.YES_NO_OPTION); if (well == 0) { m_win.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE); m_accepted = true; blocker(false); } else { m_win.setDefaultCloseOperation(JFrame.DO_NOTHING_ON_CLOSE); } m_stopIt = k; } }); m_win.getContentPane().setLayout(new BorderLayout()); m_win.setTitle("Neural Network"); m_nodePanel = new NodePanel(); JScrollPane sp = new JScrollPane(m_nodePanel, JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, JScrollPane.HORIZONTAL_SCROLLBAR_NEVER); m_controlPanel = new ControlPanel(); m_win.getContentPane().add(sp, BorderLayout.CENTER); m_win.getContentPane().add(m_controlPanel, BorderLayout.SOUTH); m_win.setSize(640, 480); m_win.show(); } //This sets up the initial state of the gui if (m_gui) { blocker(true); m_controlPanel.m_changeEpochs.setEnabled(false); m_controlPanel.m_changeLearning.setEnabled(false); m_controlPanel.m_changeMomentum.setEnabled(false); } //For silly situations in which the network gets accepted before training //commenses if (m_numeric) { setEndsToLinear(); } if (m_accepted) { m_win.dispose(); m_controlPanel = null; m_nodePanel = null; m_instances = new Instances(m_instances, 0); return; } //connections done. double right = 0; double driftOff = 0; double lastRight = Double.POSITIVE_INFINITY; double tempRate; double totalWeight = 0; double totalValWeight = 0; double origRate = m_learningRate; //only used for when reset //ensure that at least 1 instance is trained through. if (numInVal == m_instances.numInstances()) { numInVal--; } if (numInVal < 0) { numInVal = 0; } for (int noa = numInVal; noa < m_instances.numInstances(); noa++) { if (!m_instances.instance(noa).classIsMissing()) { totalWeight += m_instances.instance(noa).weight(); } } if (m_valSize != 0) { for (int noa = 0; noa < valSet.numInstances(); noa++) { if (!valSet.instance(noa).classIsMissing()) { totalValWeight += valSet.instance(noa).weight(); } } } m_stopped = false; for (int noa = 1; noa < m_numEpochs + 1; noa++) { right = 0; for (int nob = numInVal; nob < m_instances.numInstances(); nob++) { m_currentInstance = m_instances.instance(nob); if (!m_currentInstance.classIsMissing()) { //this is where the network updating (and training occurs, for the //training set resetNetwork(); calculateOutputs(); tempRate = m_learningRate * m_currentInstance.weight(); if (m_decay) { tempRate /= noa; } right += (calculateErrors() / m_instances.numClasses()) * m_currentInstance.weight(); updateNetworkWeights(tempRate, m_momentum); } } right /= totalWeight; if (Double.isInfinite(right) || Double.isNaN(right)) { if (!m_reset) { m_instances = null; throw new Exception("Network cannot train. Try restarting with a" + " smaller learning rate."); } else { //reset the network m_learningRate /= 2; buildClassifier(i); m_learningRate = origRate; m_instances = new Instances(m_instances, 0); return; } } ////////////////////////do validation testing if applicable if (m_valSize != 0) { right = 0; for (int nob = 0; nob < valSet.numInstances(); nob++) { m_currentInstance = valSet.instance(nob); if (!m_currentInstance.classIsMissing()) { //this is where the network updating occurs, for the validation set resetNetwork(); calculateOutputs(); right += (calculateErrors() / valSet.numClasses()) * m_currentInstance.weight(); //note 'right' could be calculated here just using //the calculate output values. This would be faster. //be less modular } } if (right < lastRight) { driftOff = 0; } else { driftOff++; } lastRight = right; if (driftOff > m_driftThreshold || noa + 1 >= m_numEpochs) { m_accepted = true; } right /= totalValWeight; } m_epoch = noa; m_error = right; //shows what the neuralnet is upto if a gui exists. updateDisplay(); //This junction controls what state the gui is in at the end of each //epoch, Such as if it is paused, if it is resumable etc... if (m_gui) { while ((m_stopIt || (m_epoch >= m_numEpochs && m_valSize == 0)) && !m_accepted) { m_stopIt = true; m_stopped = true; if (m_epoch >= m_numEpochs && m_valSize == 0) { m_controlPanel.m_startStop.setEnabled(false); }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -