📄 neuralnetwork.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * NeuralConnection.java * Copyright (C) 2000 Malcolm Ware */package weka.classifiers.neural;import java.util.*;import java.awt.*;import java.awt.event.*;import javax.swing.*;import weka.classifiers.*;import weka.core.*;import weka.filters.*;/** * A Classifier that uses backpropagation to classify instances. * This network can be built by hand, created by an algorithm or both. * The network can also be monitored and modified during training time. * The nodes in this network are all sigmoid (except for when the class * is numeric in which case the the output nodes become unthresholded linear * units). * * @author Malcolm Ware (mfw4@cs.waikato.ac.nz) * @version $Revision: 1.4.2.2 $ */public class NeuralNetwork extends DistributionClassifier implements OptionHandler, WeightedInstancesHandler { /** * Main method for testing this class. * * @param argv should contain command line options (see setOptions) */ public static void main(String [] argv) { try { System.out.println(Evaluation.evaluateModel(new NeuralNetwork(), argv)); } catch (Exception e) { System.err.println(e.getMessage()); e.printStackTrace(); } System.exit(0); } /** * This inner class is used to connect the nodes in the network up to * the data that they are classifying, Note that objects of this class are * only suitable to go on the attribute side or class side of the network * and not both. */ protected class NeuralEnd extends NeuralConnection { /** * the value that represents the instance value this node represents. * For an input it is the attribute number, for an output, if nominal * it is the class value. */ private int m_link; /** True if node is an input, False if it's an output. */ private boolean m_input; public NeuralEnd(String id) { super(id); m_link = 0; m_input = true; } /** * Call this function to determine if the point at x,y is on the unit. * @param g The graphics context for font size info. * @param x The x coord. * @param y The y coord. * @param w The width of the display. * @param h The height of the display. * @return True if the point is on the unit, false otherwise. */ public boolean onUnit(Graphics g, int x, int y, int w, int h) { FontMetrics fm = g.getFontMetrics(); int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2; int t = (int)(m_y * h) - fm.getHeight() / 2; if (x < l || x > l + fm.stringWidth(m_id) + 4 || y < t || y > t + fm.getHeight() + fm.getDescent() + 4) { return false; } return true; } /** * This will draw the node id to the graphics context. * @param g The graphics context. * @param w The width of the drawing area. * @param h The height of the drawing area. */ public void drawNode(Graphics g, int w, int h) { if ((m_type & PURE_INPUT) == PURE_INPUT) { g.setColor(Color.green); } else { g.setColor(Color.orange); } FontMetrics fm = g.getFontMetrics(); int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2; int t = (int)(m_y * h) - fm.getHeight() / 2; g.fill3DRect(l, t, fm.stringWidth(m_id) + 4 , fm.getHeight() + fm.getDescent() + 4 , true); g.setColor(Color.black); g.drawString(m_id, l + 2, t + fm.getHeight() + 2); } /** * Call this function to draw the node highlighted. * @param g The graphics context. * @param w The width of the drawing area. * @param h The height of the drawing area. */ public void drawHighlight(Graphics g, int w, int h) { g.setColor(Color.black); FontMetrics fm = g.getFontMetrics(); int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2; int t = (int)(m_y * h) - fm.getHeight() / 2; g.fillRect(l - 2, t - 2, fm.stringWidth(m_id) + 8 , fm.getHeight() + fm.getDescent() + 8); drawNode(g, w, h); } /** * Call this to get the output value of this unit. * @param calculate True if the value should be calculated if it hasn't * been already. * @return The output value, or NaN, if the value has not been calculated. */ public double outputValue(boolean calculate) { if (Double.isNaN(m_unitValue) && calculate) { if (m_input) { if (m_currentInstance.isMissing(m_link)) { m_unitValue = 0; } else { m_unitValue = m_currentInstance.value(m_link); } } else { //node is an output. m_unitValue = 0; for (int noa = 0; noa < m_numInputs; noa++) { m_unitValue += m_inputList[noa].outputValue(true); } if (m_numeric && m_normalizeClass) { //then scale the value; //this scales linearly from between -1 and 1 m_unitValue = m_unitValue * m_attributeRanges[m_instances.classIndex()] + m_attributeBases[m_instances.classIndex()]; } } } return m_unitValue; } /** * Call this to get the error value of this unit, which in this case is * the difference between the predicted class, and the actual class. * @param calculate True if the value should be calculated if it hasn't * been already. * @return The error value, or NaN, if the value has not been calculated. */ public double errorValue(boolean calculate) { if (!Double.isNaN(m_unitValue) && Double.isNaN(m_unitError) && calculate) { if (m_input) { m_unitError = 0; for (int noa = 0; noa < m_numOutputs; noa++) { m_unitError += m_outputList[noa].errorValue(true); } } else { if (m_currentInstance.classIsMissing()) { m_unitError = .1; } else if (m_instances.classAttribute().isNominal()) { if (m_currentInstance.classValue() == m_link) { m_unitError = 1 - m_unitValue; } else { m_unitError = 0 - m_unitValue; } } else if (m_numeric) { if (m_normalizeClass) { if (m_attributeRanges[m_instances.classIndex()] == 0) { m_unitError = 0; } else { m_unitError = (m_currentInstance.classValue() - m_unitValue ) / m_attributeRanges[m_instances.classIndex()]; //m_numericRange; } } else { m_unitError = m_currentInstance.classValue() - m_unitValue; } } } } return m_unitError; } /** * Call this to reset the value and error for this unit, ready for the next * run. This will also call the reset function of all units that are * connected as inputs to this one. * This is also the time that the update for the listeners will be * performed. */ public void reset() { if (!Double.isNaN(m_unitValue) || !Double.isNaN(m_unitError)) { m_unitValue = Double.NaN; m_unitError = Double.NaN; m_weightsUpdated = false; for (int noa = 0; noa < m_numInputs; noa++) { m_inputList[noa].reset(); } } } /** * Call this function to set What this end unit represents. * @param input True if this unit is used for entering an attribute, * False if it's used for determining a class value. * @param val The attribute number or class type that this unit represents. * (for nominal attributes). */ public void setLink(boolean input, int val) throws Exception { m_input = input; if (input) { m_type = PURE_INPUT; } else { m_type = PURE_OUTPUT; } if (val < 0 || (input && val > m_instances.numAttributes()) || (!input && m_instances.classAttribute().isNominal() && val > m_instances.classAttribute().numValues())) { m_link = 0; } else { m_link = val; } } /** * @return link for this node. */ public int getLink() { return m_link; } } /** Inner class used to draw the nodes onto.(uses the node lists!!) * This will also handle the user input. */ private class NodePanel extends JPanel { /** * The constructor. */ public NodePanel() { addMouseListener(new MouseAdapter() { public void mousePressed(MouseEvent e) { if (!m_stopped) { return; } if ((e.getModifiers() & e.BUTTON1_MASK) == e.BUTTON1_MASK) { Graphics g = NodePanel.this.getGraphics(); int x = e.getX(); int y = e.getY(); int w = NodePanel.this.getWidth(); int h = NodePanel.this.getHeight(); int u = 0; FastVector tmp = new FastVector(4); for (int noa = 0; noa < m_numAttributes; noa++) { if (m_inputs[noa].onUnit(g, x, y, w, h)) { tmp.addElement(m_inputs[noa]); selection(tmp, (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK , true); return; } } for (int noa = 0; noa < m_numClasses; noa++) { if (m_outputs[noa].onUnit(g, x, y, w, h)) { tmp.addElement(m_outputs[noa]); selection(tmp, (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK , true); return; } } for (int noa = 0; noa < m_neuralNodes.length; noa++) { if (m_neuralNodes[noa].onUnit(g, x, y, w, h)) { tmp.addElement(m_neuralNodes[noa]); selection(tmp, (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK , true); return; } } NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random, m_sigmoidUnit); m_nextId++; temp.setX((double)e.getX() / w); temp.setY((double)e.getY() / h); tmp.addElement(temp); addNode(temp); selection(tmp, (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK , true); } else { //then right click Graphics g = NodePanel.this.getGraphics(); int x = e.getX(); int y = e.getY(); int w = NodePanel.this.getWidth(); int h = NodePanel.this.getHeight(); int u = 0; FastVector tmp = new FastVector(4); for (int noa = 0; noa < m_numAttributes; noa++) { if (m_inputs[noa].onUnit(g, x, y, w, h)) { tmp.addElement(m_inputs[noa]); selection(tmp, (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK , false); return; } } for (int noa = 0; noa < m_numClasses; noa++) { if (m_outputs[noa].onUnit(g, x, y, w, h)) { tmp.addElement(m_outputs[noa]); selection(tmp, (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK , false); return; } } for (int noa = 0; noa < m_neuralNodes.length; noa++) { if (m_neuralNodes[noa].onUnit(g, x, y, w, h)) { tmp.addElement(m_neuralNodes[noa]); selection(tmp, (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK , false); return; } } selection(null, (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK , false); } } }); } /** * This function gets called when the user has clicked something * It will amend the current selection or connect the current selection * to the new selection. * Or if nothing was selected and the right button was used it will * delete the node. * @param v The units that were selected. * @param ctrl True if ctrl was held down. * @param left True if it was the left mouse button. */ private void selection(FastVector v, boolean ctrl, boolean left) { if (v == null) { //then unselect all. m_selected.removeAllElements(); repaint(); return; } //then exclusive or the new selection with the current one. if ((ctrl || m_selected.size() == 0) && left) { boolean removed = false; for (int noa = 0; noa < v.size(); noa++) { removed = false; for (int nob = 0; nob < m_selected.size(); nob++) { if (v.elementAt(noa) == m_selected.elementAt(nob)) { //then remove that element m_selected.removeElementAt(nob); removed = true; break; } } if (!removed) { m_selected.addElement(v.elementAt(noa));
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -