📄 multilayerperceptron.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * MultilayerPerceptron.java * Copyright (C) 2000 Malcolm Ware */package weka.classifiers.functions;import weka.classifiers.Classifier;import weka.classifiers.Evaluation;import weka.classifiers.functions.neural.LinearUnit;import weka.classifiers.functions.neural.NeuralConnection;import weka.classifiers.functions.neural.NeuralNode;import weka.classifiers.functions.neural.SigmoidUnit;import weka.core.Capabilities;import weka.core.FastVector;import weka.core.Instance;import weka.core.Instances;import weka.core.Option;import weka.core.OptionHandler;import weka.core.Utils;import weka.core.WeightedInstancesHandler;import weka.core.Capabilities.Capability;import weka.filters.Filter;import weka.filters.unsupervised.attribute.NominalToBinary;import java.awt.BorderLayout;import java.awt.Color;import java.awt.Component;import java.awt.Dimension;import java.awt.FontMetrics;import java.awt.Graphics;import java.awt.event.ActionEvent;import java.awt.event.ActionListener;import java.awt.event.MouseAdapter;import java.awt.event.MouseEvent;import java.awt.event.WindowAdapter;import java.awt.event.WindowEvent;import java.util.Enumeration;import java.util.Random;import java.util.StringTokenizer;import java.util.Vector;import javax.swing.BorderFactory;import javax.swing.Box;import javax.swing.BoxLayout;import javax.swing.JButton;import javax.swing.JFrame;import javax.swing.JLabel;import javax.swing.JOptionPane;import javax.swing.JPanel;import javax.swing.JScrollPane;import javax.swing.JTextField;/** <!-- globalinfo-start --> * A Classifier that uses backpropagation to classify instances.<br/> * This network can be built by hand, created by an algorithm or both. The network can also be monitored and modified during training time. The nodes in this network are all sigmoid (except for when the class is numeric in which case the the output nodes become unthresholded linear units). * <p/> <!-- globalinfo-end --> * <!-- options-start --> * Valid options are: <p/> * * <pre> -L <learning rate> * Learning Rate for the backpropagation algorithm. * (Value should be between 0 - 1, Default = 0.3).</pre> * * <pre> -M <momentum> * Momentum Rate for the backpropagation algorithm. * (Value should be between 0 - 1, Default = 0.2).</pre> * * <pre> -N <number of epochs> * Number of epochs to train through. * (Default = 500).</pre> * * <pre> -V <percentage size of validation set> * Percentage size of validation set to use to terminate training (if this is non zero it can pre-empt num of epochs. * (Value should be between 0 - 100, Default = 0).</pre> * * <pre> -S <seed> * The value used to seed the random number generator (Value should be >= 0 and and a long, Default = 0).</pre> * * <pre> -E <threshold for number of consequetive errors> * The consequetive number of errors allowed for validation testing before the netwrok terminates. (Value should be > 0, Default = 20).</pre> * * <pre> -G * GUI will be opened. * (Use this to bring up a GUI).</pre> * * <pre> -A * Autocreation of the network connections will NOT be done. * (This will be ignored if -G is NOT set)</pre> * * <pre> -B * A NominalToBinary filter will NOT automatically be used. * (Set this to not use a NominalToBinary filter).</pre> * * <pre> -H <comma seperated numbers for nodes on each layer> * The hidden layers to be created for the network. * (Value should be a list of comma seperated Natural numbers or the letters 'a' = (attribs + classes) / 2, 'i' = attribs, 'o' = classes, 't' = attribs .+ classes) For wildcard values,Default = a).</pre> * * <pre> -C * Normalizing a numeric class will NOT be done. * (Set this to not normalize the class if it's numeric).</pre> * * <pre> -I * Normalizing the attributes will NOT be done. * (Set this to not normalize the attributes).</pre> * * <pre> -R * Reseting the network will NOT be allowed. * (Set this to not allow the network to reset).</pre> * * <pre> -D * Learning rate decay will occur. * (Set this to cause the learning rate to decay).</pre> * <!-- options-end --> * * @author Malcolm Ware (mfw4@cs.waikato.ac.nz) * @version $Revision: 1.4 $ */public class MultilayerPerceptron extends Classifier implements OptionHandler, WeightedInstancesHandler { /** for serialization */ static final long serialVersionUID = 572250905027665169L; /** * Main method for testing this class. * * @param argv should contain command line options (see setOptions) */ public static void main(String [] argv) { try { System.out.println(Evaluation.evaluateModel(new MultilayerPerceptron(), argv)); } catch (Exception e) { System.err.println(e.getMessage()); e.printStackTrace(); } System.exit(0); } /** * This inner class is used to connect the nodes in the network up to * the data that they are classifying, Note that objects of this class are * only suitable to go on the attribute side or class side of the network * and not both. */ protected class NeuralEnd extends NeuralConnection { /** for serialization */ static final long serialVersionUID = 7305185603191183338L; /** * the value that represents the instance value this node represents. * For an input it is the attribute number, for an output, if nominal * it is the class value. */ private int m_link; /** True if node is an input, False if it's an output. */ private boolean m_input; /** * Constructor */ public NeuralEnd(String id) { super(id); m_link = 0; m_input = true; } /** * Call this function to determine if the point at x,y is on the unit. * @param g The graphics context for font size info. * @param x The x coord. * @param y The y coord. * @param w The width of the display. * @param h The height of the display. * @return True if the point is on the unit, false otherwise. */ public boolean onUnit(Graphics g, int x, int y, int w, int h) { FontMetrics fm = g.getFontMetrics(); int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2; int t = (int)(m_y * h) - fm.getHeight() / 2; if (x < l || x > l + fm.stringWidth(m_id) + 4 || y < t || y > t + fm.getHeight() + fm.getDescent() + 4) { return false; } return true; } /** * This will draw the node id to the graphics context. * @param g The graphics context. * @param w The width of the drawing area. * @param h The height of the drawing area. */ public void drawNode(Graphics g, int w, int h) { if ((m_type & PURE_INPUT) == PURE_INPUT) { g.setColor(Color.green); } else { g.setColor(Color.orange); } FontMetrics fm = g.getFontMetrics(); int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2; int t = (int)(m_y * h) - fm.getHeight() / 2; g.fill3DRect(l, t, fm.stringWidth(m_id) + 4 , fm.getHeight() + fm.getDescent() + 4 , true); g.setColor(Color.black); g.drawString(m_id, l + 2, t + fm.getHeight() + 2); } /** * Call this function to draw the node highlighted. * @param g The graphics context. * @param w The width of the drawing area. * @param h The height of the drawing area. */ public void drawHighlight(Graphics g, int w, int h) { g.setColor(Color.black); FontMetrics fm = g.getFontMetrics(); int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2; int t = (int)(m_y * h) - fm.getHeight() / 2; g.fillRect(l - 2, t - 2, fm.stringWidth(m_id) + 8 , fm.getHeight() + fm.getDescent() + 8); drawNode(g, w, h); } /** * Call this to get the output value of this unit. * @param calculate True if the value should be calculated if it hasn't * been already. * @return The output value, or NaN, if the value has not been calculated. */ public double outputValue(boolean calculate) { if (Double.isNaN(m_unitValue) && calculate) { if (m_input) { if (m_currentInstance.isMissing(m_link)) { m_unitValue = 0; } else { m_unitValue = m_currentInstance.value(m_link); } } else { //node is an output. m_unitValue = 0; for (int noa = 0; noa < m_numInputs; noa++) { m_unitValue += m_inputList[noa].outputValue(true); } if (m_numeric && m_normalizeClass) { //then scale the value; //this scales linearly from between -1 and 1 m_unitValue = m_unitValue * m_attributeRanges[m_instances.classIndex()] + m_attributeBases[m_instances.classIndex()]; } } } return m_unitValue; } /** * Call this to get the error value of this unit, which in this case is * the difference between the predicted class, and the actual class. * @param calculate True if the value should be calculated if it hasn't * been already. * @return The error value, or NaN, if the value has not been calculated. */ public double errorValue(boolean calculate) { if (!Double.isNaN(m_unitValue) && Double.isNaN(m_unitError) && calculate) { if (m_input) { m_unitError = 0; for (int noa = 0; noa < m_numOutputs; noa++) { m_unitError += m_outputList[noa].errorValue(true); } } else { if (m_currentInstance.classIsMissing()) { m_unitError = .1; } else if (m_instances.classAttribute().isNominal()) { if (m_currentInstance.classValue() == m_link) { m_unitError = 1 - m_unitValue; } else { m_unitError = 0 - m_unitValue; } } else if (m_numeric) { if (m_normalizeClass) { if (m_attributeRanges[m_instances.classIndex()] == 0) { m_unitError = 0; } else { m_unitError = (m_currentInstance.classValue() - m_unitValue ) / m_attributeRanges[m_instances.classIndex()]; //m_numericRange; } } else { m_unitError = m_currentInstance.classValue() - m_unitValue; } } } } return m_unitError; } /** * Call this to reset the value and error for this unit, ready for the next * run. This will also call the reset function of all units that are * connected as inputs to this one. * This is also the time that the update for the listeners will be * performed. */ public void reset() { if (!Double.isNaN(m_unitValue) || !Double.isNaN(m_unitError)) { m_unitValue = Double.NaN; m_unitError = Double.NaN; m_weightsUpdated = false; for (int noa = 0; noa < m_numInputs; noa++) { m_inputList[noa].reset(); } } } /** * Call this function to set What this end unit represents. * @param input True if this unit is used for entering an attribute, * False if it's used for determining a class value. * @param val The attribute number or class type that this unit represents. * (for nominal attributes). */ public void setLink(boolean input, int val) throws Exception { m_input = input; if (input) { m_type = PURE_INPUT; } else { m_type = PURE_OUTPUT; } if (val < 0 || (input && val > m_instances.numAttributes()) || (!input && m_instances.classAttribute().isNominal() && val > m_instances.classAttribute().numValues())) { m_link = 0; } else { m_link = val; } } /** * @return link for this node. */ public int getLink() { return m_link; } } /** Inner class used to draw the nodes onto.(uses the node lists!!) * This will also handle the user input. */ private class NodePanel extends JPanel { /** for serialization */ static final long serialVersionUID = -3067621833388149984L; /** * The constructor. */ public NodePanel() { addMouseListener(new MouseAdapter() { public void mousePressed(MouseEvent e) { if (!m_stopped) { return; } if ((e.getModifiers() & MouseEvent.BUTTON1_MASK) == MouseEvent.BUTTON1_MASK && !e.isAltDown()) { Graphics g = NodePanel.this.getGraphics(); int x = e.getX(); int y = e.getY(); int w = NodePanel.this.getWidth(); int h = NodePanel.this.getHeight(); FastVector tmp = new FastVector(4); for (int noa = 0; noa < m_numAttributes; noa++) { if (m_inputs[noa].onUnit(g, x, y, w, h)) { tmp.addElement(m_inputs[noa]);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -