
📄 multilayerperceptron.java

📁 MacroWeka extends the well-known data mining tool Weka
💻 JAVA
📖 Page 1 of 5
  }

  /**
   * A function used to stop the code that called buildClassifier
   * from continuing on before the user has finished the neural network.
   * @param tf True to stop the thread, False to release the thread that is
   * waiting there (if any).
   */
  public synchronized void blocker(boolean tf) {
    if (tf) {
      try {
	wait();
      } catch(InterruptedException e) {
      }
    }
    else {
      notifyAll();
    }
  }
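  // Note on how blocker() is used: the training thread parks itself with
  // blocker(true), which wait()s on this object's monitor, and GUI callbacks
  // release it with blocker(false), which notifyAll()s.  Both calls appear
  // further down in buildClassifier(): blocker(true) right after the window
  // is shown, and blocker(false) in the windowClosing handler once the
  // network is accepted.  (The control-panel buttons presumably do the same,
  // but that code is not on this page.)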

  /**
   * Call this function to update the control panel for the gui.
   */
  private void updateDisplay() {
    
    if (m_gui) {
      m_controlPanel.m_errorLabel.repaint();
      m_controlPanel.m_epochsLabel.repaint();
    }
  }
  

  /**
   * This will reset all the nodes in the network.
   */
  private void resetNetwork() {
    for (int noc = 0; noc < m_numClasses; noc++) {
      m_outputs[noc].reset();
    }
  }
  
  /**
   * This will cause the output values of all the nodes to be calculated.
   * Note that the m_currentInstance is used to calculate these values.
   */
  private void calculateOutputs() {
    for (int noc = 0; noc < m_numClasses; noc++) {	
      //get the values. 
      m_outputs[noc].outputValue(true);
    }
  }

  /**
   * This will cause the error values to be calculated for all nodes.
   * Note that the m_currentInstance is used to calculate these values.
   * Also the output values should have been calculated first.
   * @return The squared error.
   */
  private double calculateErrors() throws Exception {
    double ret = 0, temp = 0; 
    for (int noc = 0; noc < m_numAttributes; noc++) {
      //get the errors.
      m_inputs[noc].errorValue(true);
      
    }
    for (int noc = 0; noc < m_numClasses; noc++) {
      temp = m_outputs[noc].errorValue(false);
      ret += temp * temp;
    }    
    return ret;
    
  }
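  // In formula form, the value returned above is the squared error of the
  // current instance over the output units:
  //     E = sum_{c = 0 .. m_numClasses-1} e_c^2,  where e_c = m_outputs[c].errorValue(false)
  // The preceding errorValue(true) calls on the input ends cause the error
  // terms of the hidden nodes to be computed as well; the details live in
  // NeuralEnd/NeuralNode, which are not shown on this page.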

  /**
   * This will cause the weight values to be updated based on the learning
   * rate, momentum and the errors that have been calculated for each node.
   * @param l The learning rate to update with.
   * @param m The momentum to update with.
   */
  private void updateNetworkWeights(double l, double m) {
    for (int noc = 0; noc < m_numClasses; noc++) {
      //update weights
      m_outputs[noc].updateWeights(l, m);
    }

  }
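  // l and m above are the usual backpropagation parameters.  The
  // per-connection update performed inside NeuralNode.updateWeights() (that
  // class is not on this page, so take the exact form as an assumption) is
  // the textbook rule with momentum:
  //     delta_w(t) = l * delta_j * o_i + m * delta_w(t-1)
  //     w(t+1)     = w(t) + delta_w(t)
  // where delta_j is the receiving node's error term and o_i the activation
  // coming in over the connection.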
  
  /**
   * This creates the required input units.
   */
  private void setupInputs() throws Exception {
    m_inputs = new NeuralEnd[m_numAttributes];
    int now = 0;
    for (int noa = 0; noa < m_numAttributes+1; noa++) {
      if (m_instances.classIndex() != noa) {
	m_inputs[noa - now] = new NeuralEnd(m_instances.attribute(noa).name());
	
	m_inputs[noa - now].setX(.1);
	m_inputs[noa - now].setY((noa - now + 1.0) / (m_numAttributes + 1));
	m_inputs[noa - now].setLink(true, noa);
      }    
      else {
	now = 1;
      }
    }

  }

  /**
   * This creates the required output units.
   */
  private void setupOutputs() throws Exception {
  
    m_outputs = new NeuralEnd[m_numClasses];
    for (int noa = 0; noa < m_numClasses; noa++) {
      if (m_numeric) {
	m_outputs[noa] = new NeuralEnd(m_instances.classAttribute().name());
      }
      else {
	m_outputs[noa]= new NeuralEnd(m_instances.classAttribute().value(noa));
      }
      
      m_outputs[noa].setX(.9);
      m_outputs[noa].setY((noa + 1.0) / (m_numClasses + 1));
      m_outputs[noa].setLink(false, noa);
      NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random,
				       m_sigmoidUnit);
      m_nextId++;
      temp.setX(.75);
      temp.setY((noa + 1.0) / (m_numClasses + 1));
      addNode(temp);
      NeuralConnection.connect(temp, m_outputs[noa]);
    }
 
  }
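  // The setX()/setY() values used here and in setupInputs() appear to be
  // relative (0..1) coordinates for the NodePanel display: inputs sit at
  // x = .1, the per-class output ends at x = .9, and each class additionally
  // gets one sigmoid NeuralNode at x = .75 that feeds its output end.
  // Hidden layers (added next) are spread between x = .25 and x = .75.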
  
  /**
   * Call this function to automatically generate the hidden units.
   */
  private void setupHiddenLayer()
  {
    StringTokenizer tok = new StringTokenizer(m_hiddenLayers, ",");
    int val = 0;  //num of nodes in a layer
    int prev = 0; //used to remember the previous layer
    int num = tok.countTokens(); //number of layers
    String c;
    for (int noa = 0; noa < num; noa++) {
      //Double.valueOf() is used here rather than Integer.valueOf() because
      //the Double implementation tolerates leading whitespace while the
      //Integer one does not.
      c = tok.nextToken().trim();
      if (c.equals("a")) {
	val = (m_numAttributes + m_numClasses) / 2;
      }
      else if (c.equals("i")) {
	val = m_numAttributes;
      }
      else if (c.equals("o")) {
	val = m_numClasses;
      }
      else if (c.equals("t")) {
	val = m_numAttributes + m_numClasses;
      }
      else {
	val = Double.valueOf(c).intValue();
      }
      for (int nob = 0; nob < val; nob++) {
	NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random,
					 m_sigmoidUnit);
	m_nextId++;
	temp.setX(.5 / (num) * noa + .25);
	temp.setY((nob + 1.0) / (val + 1));
	addNode(temp);
	if (noa > 0) {
	  //then do connections
	  for (int noc = m_neuralNodes.length - nob - 1 - prev;
	       noc < m_neuralNodes.length - nob - 1; noc++) {
	    NeuralConnection.connect(m_neuralNodes[noc], temp);
	  }
	}
      }      
      prev = val;
    }
    tok = new StringTokenizer(m_hiddenLayers, ",");
    c = tok.nextToken();
    if (c.equals("a")) {
      val = (m_numAttributes + m_numClasses) / 2;
    }
    else if (c.equals("i")) {
      val = m_numAttributes;
    }
    else if (c.equals("o")) {
      val = m_numClasses;
    }
    else if (c.equals("t")) {
      val = m_numAttributes + m_numClasses;
    }
    else {
      val = Double.valueOf(c).intValue();
    }
    
    if (val == 0) {
      for (int noa = 0; noa < m_numAttributes; noa++) {
	for (int nob = 0; nob < m_numClasses; nob++) {
	  NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]);
	}
      }
    }
    else {
      for (int noa = 0; noa < m_numAttributes; noa++) {
	for (int nob = m_numClasses; nob < m_numClasses + val; nob++) {
	  NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]);
	}
      }
      for (int noa = m_neuralNodes.length - prev; noa < m_neuralNodes.length;
	   noa++) {
	for (int nob = 0; nob < m_numClasses; nob++) {
	  NeuralConnection.connect(m_neuralNodes[noa], m_neuralNodes[nob]);
	}
      }
    }
    
  }
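  // Summary of the layer specification parsed above: m_hiddenLayers is a
  // comma-separated list of layer sizes read from the input side to the
  // output side, with the wildcards
  //     a -> (attributes + classes) / 2
  //     i -> attributes
  //     o -> classes
  //     t -> attributes + classes
  // A single "0" means no hidden layer, i.e. the inputs are wired straight
  // to the per-class output nodes.  For example (hypothetical sizes), with
  // 4 attributes and 3 classes the string "a,2" builds a first hidden layer
  // of (4 + 3) / 2 = 3 nodes and a second of 2 nodes, each layer fully
  // connected to the previous one.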
  
  /**
   * This will go through all the nodes and check if they are connected
   * to a pure output unit. If so they will be set to be linear units.
   * If not they will be set to be sigmoid units.
   */
  private void setEndsToLinear() {
    for (int noa = 0; noa < m_neuralNodes.length; noa++) {
      if ((m_neuralNodes[noa].getType() & NeuralConnection.OUTPUT) ==
	  NeuralConnection.OUTPUT) {
	((NeuralNode)m_neuralNodes[noa]).setMethod(m_linearUnit);
      }
      else {
	((NeuralNode)m_neuralNodes[noa]).setMethod(m_sigmoidUnit);
      }
    }
  }
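  // Rationale: for a numeric class (regression) the output must not be
  // squashed into (0, 1), so the nodes carrying the OUTPUT flag are switched
  // to the linear unit, while for nominal classes they stay sigmoid.  This
  // matches the call to setEndsToLinear() guarded by m_numeric in
  // buildClassifier() below.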
  
  /**
   * Call this function to build and train a neural network for the training
   * data provided.
   * @param i The training data.
   * @exception Exception if the classifier can't be built successfully.
   */
  public void buildClassifier(Instances i) throws Exception {

    if (i.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
    }

    if (i.numInstances() == 0) {
      throw new IllegalArgumentException("No training instances.");
    }
    m_epoch = 0;
    m_error = 0;
    m_instances = null;
    m_currentInstance = null;
    m_controlPanel = null;
    m_nodePanel = null;
    
    
    m_outputs = new NeuralEnd[0];
    m_inputs = new NeuralEnd[0];
    m_numAttributes = 0;
    m_numClasses = 0;
    m_neuralNodes = new NeuralConnection[0];
    
    m_selected = new FastVector(4);
    m_graphers = new FastVector(2);
    m_nextId = 0;
    m_stopIt = true;
    m_stopped = true;
    m_accepted = false;    
    m_instances = new Instances(i);
    m_instances.deleteWithMissingClass();
    if (m_instances.numInstances() == 0) {
      m_instances = null;
      throw new IllegalArgumentException("All class values missing.");
    }
    m_random = new Random(m_randomSeed);
    m_instances.randomize(m_random);

    if (m_useNomToBin) {
      m_nominalToBinaryFilter = new NominalToBinary();
      m_nominalToBinaryFilter.setInputFormat(m_instances);
      m_instances = Filter.useFilter(m_instances,
				     m_nominalToBinaryFilter);
    }
    m_numAttributes = m_instances.numAttributes() - 1;
    m_numClasses = m_instances.numClasses();
 
    
    setClassType(m_instances);
    

   
    //this sets up the validation set.
    Instances valSet = null;
    //numInVal is needed later
    int numInVal = (int)(m_valSize / 100.0 * m_instances.numInstances());
    if (m_valSize > 0) {
      if (numInVal == 0) {
	numInVal = 1;
      }
      valSet = new Instances(m_instances, 0, numInVal);
    }
    ///////////
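    // Worked example (hypothetical numbers): with m_valSize = 20 and 150
    // training instances, numInVal = (int)(20 / 100.0 * 150) = 30, so after
    // the randomize() call above the first 30 instances form the validation
    // set and the training loops below start at index numInVal.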

    setupInputs();
      
    setupOutputs();    
    if (m_autoBuild) {
      setupHiddenLayer();
    }
    
    /////////////////////////////
    //this sets up the gui for usage
    if (m_gui) {
      m_win = new JFrame();
      
      m_win.addWindowListener(new WindowAdapter() {
	  public void windowClosing(WindowEvent e) {
	    boolean k = m_stopIt;
	    m_stopIt = true;
	    int well = JOptionPane.showConfirmDialog(m_win, 
						    "Are You Sure...\n"
						    + "Click Yes To Accept"
						    + " The Neural Network" 
						    + "\n Click No To Return",
						    "Accept Neural Network", 
						    JOptionPane.YES_NO_OPTION);
	    
	    if (well == JOptionPane.YES_OPTION) {
	      m_win.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
	      m_accepted = true;
	      blocker(false);
	    }
	    else {
	      m_win.setDefaultCloseOperation(JFrame.DO_NOTHING_ON_CLOSE);
	    }
	    m_stopIt = k;
	  }
	});
      
      m_win.getContentPane().setLayout(new BorderLayout());
      m_win.setTitle("Neural Network");
      m_nodePanel = new NodePanel();

      JScrollPane sp = new JScrollPane(m_nodePanel,
				       JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, 
				       JScrollPane.HORIZONTAL_SCROLLBAR_NEVER);
      m_controlPanel = new ControlPanel();
           
      m_win.getContentPane().add(sp, BorderLayout.CENTER);
      m_win.getContentPane().add(m_controlPanel, BorderLayout.SOUTH);
      m_win.setSize(640, 480);
      m_win.setVisible(true);
    }
   
    //This sets up the initial state of the gui
    if (m_gui) {
      blocker(true);
      m_controlPanel.m_changeEpochs.setEnabled(false);
      m_controlPanel.m_changeLearning.setEnabled(false);
      m_controlPanel.m_changeMomentum.setEnabled(false);
    } 
    
    //For silly situations in which the network gets accepted before training
    //commences
    if (m_numeric) {
      setEndsToLinear();
    }
    if (m_accepted) {
      m_win.dispose();
      m_controlPanel = null;
      m_nodePanel = null;
      m_instances = new Instances(m_instances, 0);
      return;
    }

    //connections done.
    double right = 0;
    double driftOff = 0;
    double lastRight = Double.POSITIVE_INFINITY;
    double tempRate;
    double totalWeight = 0;
    double totalValWeight = 0;
    double origRate = m_learningRate; //only used when the network is reset
    
    //ensure that at least 1 instance is trained through.
    if (numInVal == m_instances.numInstances()) {
      numInVal--;
    }
    if (numInVal < 0) {
      numInVal = 0;
    }
    for (int noa = numInVal; noa < m_instances.numInstances(); noa++) {
      if (!m_instances.instance(noa).classIsMissing()) {
	totalWeight += m_instances.instance(noa).weight();
      }
    }
    if (m_valSize != 0) {
      for (int noa = 0; noa < valSet.numInstances(); noa++) {
	if (!valSet.instance(noa).classIsMissing()) {
	  totalValWeight += valSet.instance(noa).weight();
	}
      }
    }
    m_stopped = false;
     

    for (int noa = 1; noa < m_numEpochs + 1; noa++) {
      right = 0;
      for (int nob = numInVal; nob < m_instances.numInstances(); nob++) {
	m_currentInstance = m_instances.instance(nob);
	
	if (!m_currentInstance.classIsMissing()) {
	   
	  //this is where the network updating (and training) occurs for the
	  //training set
	  resetNetwork();
	  calculateOutputs();
	  tempRate = m_learningRate * m_currentInstance.weight();  
	  if (m_decay) {
	    tempRate /= noa;
	  }

	  right += (calculateErrors() / m_instances.numClasses()) *
	    m_currentInstance.weight();
	  updateNetworkWeights(tempRate, m_momentum);
	  
	}
	
      }
      right /= totalWeight;
      if (Double.isInfinite(right) || Double.isNaN(right)) {
	if (!m_reset) {
	  m_instances = null;
