⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 multilayerperceptron.java

📁 MacroWeka扩展了著名数据挖掘工具weka
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    MultilayerPerceptron.java
 *    Copyright (C) 2000 Malcolm Ware
 */

package weka.classifiers.functions;

import java.util.*;
import java.awt.*;
import java.awt.event.*;
import javax.swing.*;

import weka.classifiers.functions.neural.*;
import weka.classifiers.*;
import weka.core.*;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.Filter;

/** 
 * A Classifier that uses backpropagation to classify instances.
 * This network can be built by hand, created by an algorithm or both.
 * The network can also be monitored and modified during training time.
 * The nodes in this network are all sigmoid (except for when the class
 * is numeric in which case the the output nodes become unthresholded linear
 * units).
 *
 * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
 * @version $Revision: 1.1 $
 */
public class MultilayerPerceptron extends Classifier 
  implements OptionHandler, WeightedInstancesHandler {
  
  /**
   * Command-line entry point for this classifier.
   *
   * The arguments are handed unchanged to Evaluation.evaluateModel together
   * with a fresh MultilayerPerceptron, and the result is printed to standard
   * output. If anything throws, the exception message and stack trace go to
   * standard error instead. The JVM is then terminated with exit status 0
   * in every case.
   *
   * @param argv should contain command line options (see setOptions)
   */
  public static void main(String [] argv) {
    try {
      MultilayerPerceptron classifier = new MultilayerPerceptron();
      System.out.println(Evaluation.evaluateModel(classifier, argv));
    } catch (Exception ex) {
      System.err.println(ex.getMessage());
      ex.printStackTrace();
    }
    System.exit(0);
  }
  

  /** 
   * This inner class connects the nodes in the network to the data they
   * are classifying. An instance is either an input unit (it reads one
   * attribute value of the current instance) or an output unit (it sums its
   * inputs and compares the result with the class value). Note that objects
   * of this class are only suitable to go on the attribute side or the class
   * side of the network, not both.
   */
  protected class NeuralEnd extends NeuralConnection {
    
    /** 
     * The index of the instance value this node represents.
     * For an input it is the attribute index; for an output with a nominal
     * class it is the index of the class value this unit stands for.
     */
    private int m_link;
    
    /** True if this node is an input, false if it is an output. */
    private boolean m_input;


    /**
     * Creates an end unit. It starts out as an input linked to index 0;
     * call setLink to configure what it actually represents.
     * @param id The identifier for this unit (passed on to NeuralConnection).
     */
    public NeuralEnd(String id) {
      super(id);

      m_link = 0;
      m_input = true;
    }
  
    /**
     * Call this function to determine if the point at x,y is on the unit.
     * The tested box is the same one drawNode fills: it starts half the
     * label width left of and half the font height above the unit's scaled
     * position, and is (label width + 4) by (font height + descent + 4).
     * @param g The graphics context for font size info.
     * @param x The x coord.
     * @param y The y coord.
     * @param w The width of the display.
     * @param h The height of the display.
     * @return True if the point is on the unit, false otherwise.
     */
    public boolean onUnit(Graphics g, int x, int y, int w, int h) {
      
      FontMetrics fm = g.getFontMetrics();
      int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;
      int t = (int)(m_y * h) - fm.getHeight() / 2;
      return x >= l && x <= l + fm.stringWidth(m_id) + 4
        && y >= t && y <= t + fm.getHeight() + fm.getDescent() + 4;
    }
   

    /**
     * This will draw the node id to the graphics context, on a raised box
     * that is green for pure input units and orange otherwise.
     * @param g The graphics context.
     * @param w The width of the drawing area.
     * @param h The height of the drawing area.
     */
    public void drawNode(Graphics g, int w, int h) {
      
      if ((m_type & PURE_INPUT) == PURE_INPUT) {
        g.setColor(Color.green);
      }
      else {
        g.setColor(Color.orange);
      }
      
      FontMetrics fm = g.getFontMetrics();
      int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;
      int t = (int)(m_y * h) - fm.getHeight() / 2;
      g.fill3DRect(l, t, fm.stringWidth(m_id) + 4
                   , fm.getHeight() + fm.getDescent() + 4
                   , true);
      g.setColor(Color.black);
      
      // The label is inset by 2 pixels inside the box.
      g.drawString(m_id, l + 2, t + fm.getHeight() + 2);

    }


    /**
     * Call this function to draw the node highlighted: a black rectangle
     * 2 pixels larger on every side than the node box, with the node drawn
     * over it.
     * @param g The graphics context.
     * @param w The width of the drawing area.
     * @param h The height of the drawing area.
     */
    public void drawHighlight(Graphics g, int w, int h) {
      
      g.setColor(Color.black);
      FontMetrics fm = g.getFontMetrics();
      int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;
      int t = (int)(m_y * h) - fm.getHeight() / 2;
      g.fillRect(l - 2, t - 2, fm.stringWidth(m_id) + 8
                 , fm.getHeight() + fm.getDescent() + 8); 
      drawNode(g, w, h);
    }
    
    /**
     * Call this to get the output value of this unit. The value is cached
     * in m_unitValue until reset() is called.
     *
     * An input unit outputs the linked attribute value of the current
     * instance, or 0 if that value is missing. An output unit outputs the
     * sum of its inputs' outputs. If the class is numeric and class
     * normalization is on, that sum is mapped back with
     * sum * range + base of the class attribute.
     *
     * @param calculate True if the value should be calculated if it hasn't 
     * been already.
     * @return The output value, or NaN, if the value has not been calculated.
     */
    public double outputValue(boolean calculate) {
     
      if (Double.isNaN(m_unitValue) && calculate) {
        if (m_input) {
          if (m_currentInstance.isMissing(m_link)) {
            // Missing attribute values are fed into the network as 0.
            m_unitValue = 0;
          }
          else {
            m_unitValue = m_currentInstance.value(m_link);
          }
        }
        else {
          //node is an output.
          m_unitValue = 0;
          for (int noa = 0; noa < m_numInputs; noa++) {
            m_unitValue += m_inputList[noa].outputValue(true);
          }
          if (m_numeric && m_normalizeClass) {
            // Undo the class normalization: scale by the class range and
            // shift by the class base.
            m_unitValue = m_unitValue * 
              m_attributeRanges[m_instances.classIndex()] + 
              m_attributeBases[m_instances.classIndex()];
          }
        }
      }
      return m_unitValue;
    }
    
    /**
     * Call this to get the error value of this unit. The error is only
     * calculated once the output value is known, and is cached until reset().
     *
     * For an input unit this is the sum of the errors of the units it feeds.
     * For an output unit it is the difference between the actual class and
     * the predicted value:
     * <ul>
     * <li>missing class: a fixed error of 0.1 is used
     *     (NOTE(review): the rationale for this constant is not visible here);</li>
     * <li>nominal class: target is 1 if the instance's class value equals
     *     this unit's link, otherwise 0;</li>
     * <li>numeric class: (actual - predicted), divided by the class range
     *     when normalization is on (0 if that range is 0).</li>
     * </ul>
     * @param calculate True if the value should be calculated if it hasn't 
     * been already.
     * @return The error value, or NaN, if the value has not been calculated.
     */
    public double errorValue(boolean calculate) {
      
      if (!Double.isNaN(m_unitValue) && Double.isNaN(m_unitError) 
          && calculate) {
        
        if (m_input) {
          m_unitError = 0;
          for (int noa = 0; noa < m_numOutputs; noa++) {
            m_unitError += m_outputList[noa].errorValue(true);
          }
        }
        else {
          if (m_currentInstance.classIsMissing()) {
            m_unitError = .1;  
          }
          else if (m_instances.classAttribute().isNominal()) {
            if (m_currentInstance.classValue() == m_link) {
              m_unitError = 1 - m_unitValue;
            }
            else {
              m_unitError = 0 - m_unitValue;
            }
          }
          else if (m_numeric) {
            
            if (m_normalizeClass) {
              if (m_attributeRanges[m_instances.classIndex()] == 0) {
                // Constant class: avoid dividing by a zero range.
                m_unitError = 0;
              }
              else {
                m_unitError = (m_currentInstance.classValue() - m_unitValue ) /
                  m_attributeRanges[m_instances.classIndex()];
              }
            }
            else {
              m_unitError = m_currentInstance.classValue() - m_unitValue;
            }
          }
        }
      }
      return m_unitError;
    }
    
    
    /**
     * Call this to reset the value and error for this unit, ready for the next
     * run. If either value is currently set, both are cleared, the
     * weights-updated flag is cleared, and reset() is called on every unit
     * connected as an input to this one. A unit that is already reset does
     * nothing, which stops the recursion.
     */
    public void reset() {
      
      if (!Double.isNaN(m_unitValue) || !Double.isNaN(m_unitError)) {
        m_unitValue = Double.NaN;
        m_unitError = Double.NaN;
        m_weightsUpdated = false;
        for (int noa = 0; noa < m_numInputs; noa++) {
          m_inputList[noa].reset();
        }
      }
    }
    
    
    /** 
     * Call this function to set what this end unit represents.
     * An index that is out of range falls back to 0. For an input the valid
     * range is [0, numAttributes). For an output with a nominal class it is
     * [0, numValues of the class). For an output with a non-nominal class
     * only negative values are rejected.
     * @param input True if this unit is used for entering an attribute,
     * false if it's used for determining a class value.
     * @param val The attribute index, or (for a nominal class) the class
     * value index, that this unit represents.
     */
    public void setLink(boolean input, int val) throws Exception {
      m_input = input;
      
      if (input) {
        m_type = PURE_INPUT;
      }
      else {
        m_type = PURE_OUTPUT;
      }
      // Indices are zero-based, so a value equal to the count is already
      // out of range.
      if (val < 0 || (input && val >= m_instances.numAttributes()) 
          || (!input && m_instances.classAttribute().isNominal() 
              && val >= m_instances.classAttribute().numValues())) {
        m_link = 0;
      }
      else {
        m_link = val;
      }
    }
    
    /**
     * @return The attribute or class value index this node is linked to.
     */
    public int getLink() {
      return m_link;
    }
    

  }
  

 
  /** Inner class used to draw the nodes onto.(uses the node lists!!) 
   * This will also handle the user input. */
  private class NodePanel extends JPanel {
    


    /**
     * The constructor. Registers a mouse listener that lets the user edit
     * the network; clicks are ignored unless m_stopped is true.
     *
     * A button-1 press without Alt counts as a "left" click; anything else
     * counts as a "right" click. If the click lands on a unit (inputs are
     * checked first, then outputs, then hidden nodes), that unit is passed to
     * selection(). A left click on empty space instead creates a new
     * NeuralNode at the click position, adds it to the network and selects
     * it. A right click on empty space calls selection(null, ...). Whether
     * Ctrl was held is forwarded to selection() in every case.
     */
    public NodePanel() {

      addMouseListener(new MouseAdapter() {

          public void mousePressed(MouseEvent e) {

            if (!m_stopped) {
              return;
            }
            int x = e.getX();
            int y = e.getY();
            int w = NodePanel.this.getWidth();
            int h = NodePanel.this.getHeight();
            int modifiers = e.getModifiers();
            boolean ctrl = (modifiers & InputEvent.CTRL_MASK)
              == InputEvent.CTRL_MASK;
            boolean left = (modifiers & InputEvent.BUTTON1_MASK)
              == InputEvent.BUTTON1_MASK && !e.isAltDown();

            // getGraphics() creates a new context (needed only for font
            // metrics); release it as soon as the hit test is done.
            Graphics g = NodePanel.this.getGraphics();
            Object hit;
            try {
              hit = findUnitAt(g, x, y, w, h);
            } finally {
              if (g != null) {
                g.dispose();
              }
            }

            FastVector tmp = new FastVector(4);
            if (hit != null) {
              tmp.addElement(hit);
              selection(tmp, ctrl, left);
            }
            else if (left) {
              // Left click on empty space: create a new hidden node there,
              // with its position stored as a fraction of the panel size.
              NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), 
                                               m_random, m_sigmoidUnit);
              m_nextId++;
              temp.setX((double)x / w);
              temp.setY((double)y / h);
              tmp.addElement(temp);
              addNode(temp);
              selection(tmp, ctrl, true);
            }
            else {
              //then right click on empty space
              selection(null, ctrl, false);
            }
          }
        });
    }

    /**
     * Finds the unit under a point. Input units are checked first, then
     * output units, then hidden nodes; the first unit whose onUnit test
     * accepts the point is returned.
     * @param g The graphics context used for font size info.
     * @param x The x coord.
     * @param y The y coord.
     * @param w The width of the panel.
     * @param h The height of the panel.
     * @return The unit under the point, or null if there is none.
     */
    private Object findUnitAt(Graphics g, int x, int y, int w, int h) {
      for (int noa = 0; noa < m_numAttributes; noa++) {
        if (m_inputs[noa].onUnit(g, x, y, w, h)) {
          return m_inputs[noa];
        }
      }
      for (int noa = 0; noa < m_numClasses; noa++) {
        if (m_outputs[noa].onUnit(g, x, y, w, h)) {
          return m_outputs[noa];
        }
      }
      for (int noa = 0; noa < m_neuralNodes.length; noa++) {
        if (m_neuralNodes[noa].onUnit(g, x, y, w, h)) {
          return m_neuralNodes[noa];
        }
      }
      return null;
    }
    
    
    /**
     * This function gets called when the user has clicked something
     * It will amend the current selection or connect the current selection
     * to the new selection.
     * Or if nothing was selected and the right button was used it will 
     * delete the node.
     * @param v The units that were selected.
     * @param ctrl True if ctrl was held down.
     * @param left True if it was the left mouse button.
     */
    private void selection(FastVector v, boolean ctrl, boolean left) {
      
      if (v == null) {
	//then unselect all.
	m_selected.removeAllElements();
	repaint();
	return;
      }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -