⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 neuralnetwork.java

📁 :<<数据挖掘--实用机器学习技术及java实现>>一书的配套源程序
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *    NeuralConnection.java *    Copyright (C) 2000 Malcolm Ware */package weka.classifiers.neural;import java.util.*;import java.awt.*;import java.awt.event.*;import javax.swing.*;import weka.classifiers.*;import weka.core.*;import weka.filters.*;/**  * A Classifier that uses backpropagation to classify instances. * This network can be built by hand, created by an algorithm or both. * The network can also be monitored and modified during training time. * The nodes in this network are all sigmoid (except for when the class * is numeric in which case the the output nodes become unthresholded linear * units). * * @author Malcolm Ware (mfw4@cs.waikato.ac.nz) * @version $Revision: 1.4.2.2 $ */public class NeuralNetwork extends DistributionClassifier   implements OptionHandler, WeightedInstancesHandler {    /**   * Main method for testing this class.   *   * @param argv should contain command line options (see setOptions)   */  public static void main(String [] argv) {        try {      System.out.println(Evaluation.evaluateModel(new NeuralNetwork(), argv));    } catch (Exception e) {      System.err.println(e.getMessage());      e.printStackTrace();    }    System.exit(0);  }    /**    * This inner class is used to connect the nodes in the network up to   * the data that they are classifying, Note that objects of this class are   * only suitable to go on the attribute side or class side of the network   * and not both.   */  protected class NeuralEnd extends NeuralConnection {            /**      * the value that represents the instance value this node represents.      * For an input it is the attribute number, for an output, if nominal     * it is the class value.      */    private int m_link;        /** True if node is an input, False if it's an output. */    private boolean m_input;    public NeuralEnd(String id) {      super(id);      m_link = 0;      m_input = true;          }      /**     * Call this function to determine if the point at x,y is on the unit.     * @param g The graphics context for font size info.     * @param x The x coord.     * @param y The y coord.     * @param w The width of the display.     * @param h The height of the display.     * @return True if the point is on the unit, false otherwise.     */    public boolean onUnit(Graphics g, int x, int y, int w, int h) {            FontMetrics fm = g.getFontMetrics();      int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;      int t = (int)(m_y * h) - fm.getHeight() / 2;      if (x < l || x > l + fm.stringWidth(m_id) + 4 	  || y < t || y > t + fm.getHeight() + fm.getDescent() + 4) {	return false;      }      return true;          }       /**     * This will draw the node id to the graphics context.     * @param g The graphics context.     * @param w The width of the drawing area.     * @param h The height of the drawing area.     */    public void drawNode(Graphics g, int w, int h) {            if ((m_type & PURE_INPUT) == PURE_INPUT) {	g.setColor(Color.green);      }      else {	g.setColor(Color.orange);      }            FontMetrics fm = g.getFontMetrics();      int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;      int t = (int)(m_y * h) - fm.getHeight() / 2;      g.fill3DRect(l, t, fm.stringWidth(m_id) + 4		   , fm.getHeight() + fm.getDescent() + 4		   , true);      g.setColor(Color.black);            g.drawString(m_id, l + 2, t + fm.getHeight() + 2);    }    /**     * Call this function to draw the node highlighted.     * @param g The graphics context.     * @param w The width of the drawing area.     * @param h The height of the drawing area.     */    public void drawHighlight(Graphics g, int w, int h) {            g.setColor(Color.black);      FontMetrics fm = g.getFontMetrics();      int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;      int t = (int)(m_y * h) - fm.getHeight() / 2;      g.fillRect(l - 2, t - 2, fm.stringWidth(m_id) + 8		 , fm.getHeight() + fm.getDescent() + 8);       drawNode(g, w, h);    }        /**     * Call this to get the output value of this unit.      * @param calculate True if the value should be calculated if it hasn't      * been already.     * @return The output value, or NaN, if the value has not been calculated.     */    public double outputValue(boolean calculate) {           if (Double.isNaN(m_unitValue) && calculate) {	if (m_input) {	  if (m_currentInstance.isMissing(m_link)) {	    m_unitValue = 0;	  }	  else {	    	    m_unitValue = m_currentInstance.value(m_link);	  }	}	else {	  //node is an output.	  m_unitValue = 0;	  for (int noa = 0; noa < m_numInputs; noa++) {	    m_unitValue += m_inputList[noa].outputValue(true);	   	  }	  if (m_numeric && m_normalizeClass) {	    //then scale the value;	    //this scales linearly from between -1 and 1	    m_unitValue = m_unitValue * 	      m_attributeRanges[m_instances.classIndex()] + 	      m_attributeBases[m_instances.classIndex()];	  }	}      }      return m_unitValue;                }        /**     * Call this to get the error value of this unit, which in this case is     * the difference between the predicted class, and the actual class.     * @param calculate True if the value should be calculated if it hasn't      * been already.     * @return The error value, or NaN, if the value has not been calculated.     */    public double errorValue(boolean calculate) {            if (!Double.isNaN(m_unitValue) && Double.isNaN(m_unitError) 	  && calculate) {		if (m_input) {	  m_unitError = 0;	  for (int noa = 0; noa < m_numOutputs; noa++) {	    m_unitError += m_outputList[noa].errorValue(true);	  }	}	else {	  if (m_currentInstance.classIsMissing()) {	    m_unitError = .1;  	  }	  else if (m_instances.classAttribute().isNominal()) {	    if (m_currentInstance.classValue() == m_link) {	      m_unitError = 1 - m_unitValue;	    }	    else {	      m_unitError = 0 - m_unitValue;	    }	  }	  else if (m_numeric) {	    	    if (m_normalizeClass) {	      if (m_attributeRanges[m_instances.classIndex()] == 0) {		m_unitError = 0;	      }	      else {		m_unitError = (m_currentInstance.classValue() - m_unitValue ) /		  m_attributeRanges[m_instances.classIndex()];		//m_numericRange;			      }	    }	    else {	      m_unitError = m_currentInstance.classValue() - m_unitValue;	    }	  }	}      }      return m_unitError;    }            /**     * Call this to reset the value and error for this unit, ready for the next     * run. This will also call the reset function of all units that are      * connected as inputs to this one.     * This is also the time that the update for the listeners will be      * performed.     */    public void reset() {            if (!Double.isNaN(m_unitValue) || !Double.isNaN(m_unitError)) {	m_unitValue = Double.NaN;	m_unitError = Double.NaN;	m_weightsUpdated = false;	for (int noa = 0; noa < m_numInputs; noa++) {	  m_inputList[noa].reset();	}      }    }            /**      * Call this function to set What this end unit represents.     * @param input True if this unit is used for entering an attribute,     * False if it's used for determining a class value.     * @param val The attribute number or class type that this unit represents.     * (for nominal attributes).     */    public void setLink(boolean input, int val) throws Exception {      m_input = input;            if (input) {	m_type = PURE_INPUT;      }      else {	m_type = PURE_OUTPUT;      }      if (val < 0 || (input && val > m_instances.numAttributes()) 	  || (!input && m_instances.classAttribute().isNominal() 	      && val > m_instances.classAttribute().numValues())) {	m_link = 0;      }      else {	m_link = val;      }    }        /**     * @return link for this node.     */    public int getLink() {      return m_link;    }      }     /** Inner class used to draw the nodes onto.(uses the node lists!!)    * This will also handle the user input. */  private class NodePanel extends JPanel {        /**     * The constructor.     */    public NodePanel() {            addMouseListener(new MouseAdapter() {	  	  public void mousePressed(MouseEvent e) {	    	    if (!m_stopped) {	      return;	    }	    if ((e.getModifiers() & e.BUTTON1_MASK) == e.BUTTON1_MASK) {	      Graphics g = NodePanel.this.getGraphics();	      int x = e.getX();	      int y = e.getY();	      int w = NodePanel.this.getWidth();	      int h = NodePanel.this.getHeight();	      int u = 0;	      FastVector tmp = new FastVector(4);	      for (int noa = 0; noa < m_numAttributes; noa++) {		if (m_inputs[noa].onUnit(g, x, y, w, h)) {		  tmp.addElement(m_inputs[noa]);		  selection(tmp, 			    (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK			    , true);		  return;		}	      }	      for (int noa = 0; noa < m_numClasses; noa++) {		if (m_outputs[noa].onUnit(g, x, y, w, h)) {		  tmp.addElement(m_outputs[noa]);		  selection(tmp,			    (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK			    , true);		  return;		}	      }	      for (int noa = 0; noa < m_neuralNodes.length; noa++) {		if (m_neuralNodes[noa].onUnit(g, x, y, w, h)) {		  tmp.addElement(m_neuralNodes[noa]);		  selection(tmp,			    (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK			    , true);		  return;		}	      }	      NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), 					       m_random, m_sigmoidUnit);	      m_nextId++;	      temp.setX((double)e.getX() / w);	      temp.setY((double)e.getY() / h);	      tmp.addElement(temp);	      addNode(temp);	      selection(tmp, (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK			, true);	    }	    else {	      //then right click	      Graphics g = NodePanel.this.getGraphics();	      int x = e.getX();	      int y = e.getY();	      int w = NodePanel.this.getWidth();	      int h = NodePanel.this.getHeight();	      int u = 0;	      FastVector tmp = new FastVector(4);	      for (int noa = 0; noa < m_numAttributes; noa++) {		if (m_inputs[noa].onUnit(g, x, y, w, h)) {		  tmp.addElement(m_inputs[noa]);		  selection(tmp, 			    (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK			    , false);		  return;		}					      }	      for (int noa = 0; noa < m_numClasses; noa++) {		if (m_outputs[noa].onUnit(g, x, y, w, h)) {		  tmp.addElement(m_outputs[noa]);		  selection(tmp,			    (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK			    , false);		  return;		}	      }	      for (int noa = 0; noa < m_neuralNodes.length; noa++) {		if (m_neuralNodes[noa].onUnit(g, x, y, w, h)) {		  tmp.addElement(m_neuralNodes[noa]);		  selection(tmp,			    (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK			    , false);		  return;		}	      }	      selection(null, (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK			, false);	    }	  }	});    }            /**     * This function gets called when the user has clicked something     * It will amend the current selection or connect the current selection     * to the new selection.     * Or if nothing was selected and the right button was used it will      * delete the node.     * @param v The units that were selected.     * @param ctrl True if ctrl was held down.     * @param left True if it was the left mouse button.     */    private void selection(FastVector v, boolean ctrl, boolean left) {            if (v == null) {	//then unselect all.	m_selected.removeAllElements();	repaint();	return;      }            //then exclusive or the new selection with the current one.      if ((ctrl || m_selected.size() == 0) && left) {	boolean removed = false;	for (int noa = 0; noa < v.size(); noa++) {	  removed = false;	  for (int nob = 0; nob < m_selected.size(); nob++) {	    if (v.elementAt(noa) == m_selected.elementAt(nob)) {	      //then remove that element	      m_selected.removeElementAt(nob);	      removed = true;	      break;	    }	  }	  if (!removed) {	    m_selected.addElement(v.elementAt(noa));

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -