⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 multilayerperceptron.java

📁 MacroWeka扩展了著名数据挖掘工具weka
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
	  throw new Exception("Network cannot train. Try restarting with a" +
			      " smaller learning rate.");
	}
	else {
	  //reset the network
	  m_learningRate /= 2;
	  buildClassifier(i);
	  m_learningRate = origRate;
	  m_instances = new Instances(m_instances, 0);	  
	  return;
	}
      }

      ////////////////////////do validation testing if applicable
      if (m_valSize != 0) {
	right = 0;
	for (int nob = 0; nob < valSet.numInstances(); nob++) {
	  m_currentInstance = valSet.instance(nob);
	  if (!m_currentInstance.classIsMissing()) {
	    //this is where the network updating occurs, for the validation set
	    resetNetwork();
	    calculateOutputs();
	    right += (calculateErrors() / valSet.numClasses()) 
	      * m_currentInstance.weight();
	    //note 'right' could be calculated here just using
	    //the calculate output values. This would be faster.
	    //be less modular
	  }
	  
	}
	
	if (right < lastRight) {
	  driftOff = 0;
	}
	else {
	  driftOff++;
	}
	lastRight = right;
	if (driftOff > m_driftThreshold || noa + 1 >= m_numEpochs) {
	  m_accepted = true;
	}
	right /= totalValWeight;
      }
      m_epoch = noa;
      m_error = right;
      //shows what the neuralnet is upto if a gui exists. 
      updateDisplay();
      //This junction controls what state the gui is in at the end of each
      //epoch, Such as if it is paused, if it is resumable etc...
      if (m_gui) {
	while ((m_stopIt || (m_epoch >= m_numEpochs && m_valSize == 0)) && 
		!m_accepted) {
	  m_stopIt = true;
	  m_stopped = true;
	  if (m_epoch >= m_numEpochs && m_valSize == 0) {
	    
	    m_controlPanel.m_startStop.setEnabled(false);
	  }
	  else {
	    m_controlPanel.m_startStop.setEnabled(true);
	  }
	  m_controlPanel.m_startStop.setText("Start");
	  m_controlPanel.m_startStop.setActionCommand("Start");
	  m_controlPanel.m_changeEpochs.setEnabled(true);
	  m_controlPanel.m_changeLearning.setEnabled(true);
	  m_controlPanel.m_changeMomentum.setEnabled(true);
	  
	  blocker(true);
	  if (m_numeric) {
	    setEndsToLinear();
	  }
	}
	m_controlPanel.m_changeEpochs.setEnabled(false);
	m_controlPanel.m_changeLearning.setEnabled(false);
	m_controlPanel.m_changeMomentum.setEnabled(false);
	
	m_stopped = false;
	//if the network has been accepted stop the training loop
	if (m_accepted) {
	  m_win.dispose();
	  m_controlPanel = null;
	  m_nodePanel = null;
	  m_instances = new Instances(m_instances, 0);
	  return;
	}
      }
      if (m_accepted) {
	m_instances = new Instances(m_instances, 0);
	return;
      }
    }
    if (m_gui) {
      m_win.dispose();
      m_controlPanel = null;
      m_nodePanel = null;
    }
    m_instances = new Instances(m_instances, 0);  
  }

  /**
   * Call this function to predict the class of an instance once a
   * classification model has been built with the buildClassifier call.
   *
   * The instance passed in is never modified: when attribute normalization
   * is enabled the scaling is applied to a private copy.
   *
   * @param i The instance to classify.
   * @return A double array with one entry per class. For a numeric class
   * the raw network outputs are returned; otherwise the outputs are scaled
   * so that they sum to 1. Returns null if the outputs of a nominal class
   * sum to zero or less (no distribution can be formed).
   * @exception Exception if can't classify instance.
   */
  public double[] distributionForInstance(Instance i) throws Exception {

    if (m_useNomToBin) {
      m_nominalToBinaryFilter.input(i);
      m_currentInstance = m_nominalToBinaryFilter.output();
    }
    else {
      m_currentInstance = i;
    }

    if (m_normalizeAttributes) {
      //The loop below overwrites attribute values in place. Without the
      //filter m_currentInstance is the caller's own object, so work on a
      //copy to avoid corrupting the caller's data (and re-normalizing it
      //again if the same instance is classified twice).
      m_currentInstance = (Instance) m_currentInstance.copy();
      for (int noa = 0; noa < m_instances.numAttributes(); noa++) {
        if (noa != m_instances.classIndex()) {
          if (m_attributeRanges[noa] != 0) {
            m_currentInstance.setValue(noa, (m_currentInstance.value(noa) -
                                             m_attributeBases[noa]) /
                                       m_attributeRanges[noa]);
          }
          else {
            //a zero range cannot be divided by, so only shift by the base
            m_currentInstance.setValue(noa, m_currentInstance.value(noa) -
                                       m_attributeBases[noa]);
          }
        }
      }
    }
    resetNetwork();

    //since all the output values are needed.
    //They are calculated manually here and the values collected.
    double[] theArray = new double[m_numClasses];
    for (int noa = 0; noa < m_numClasses; noa++) {
      theArray[noa] = m_outputs[noa].outputValue(true);
    }
    if (m_instances.classAttribute().isNumeric()) {
      //numeric prediction: nothing to normalize
      return theArray;
    }

    //now normalize the array
    double count = 0;
    for (int noa = 0; noa < m_numClasses; noa++) {
      count += theArray[noa];
    }
    if (count <= 0) {
      //kept for backward compatibility: callers receive null here
      return null;
    }
    for (int noa = 0; noa < m_numClasses; noa++) {
      theArray[noa] /= count;
    }
    return theArray;
  }
  


  /**
   * Returns an enumeration describing the available options.
   *
   * Each entry is an Option holding the help text, the option letter, the
   * number of arguments it takes (1 for valued options, 0 for flags) and
   * a synopsis. The letters listed here are the ones read by setOptions.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(14);

    newVector.addElement(new Option(
              "\tLearning Rate for the backpropagation algorithm.\n"
              +"\t(Value should be between 0 - 1, Default = 0.3).",
              "L", 1, "-L <learning rate>"));
    newVector.addElement(new Option(
              "\tMomentum Rate for the backpropagation algorithm.\n"
              +"\t(Value should be between 0 - 1, Default = 0.2).",
              "M", 1, "-M <momentum>"));
    newVector.addElement(new Option(
              "\tNumber of epochs to train through.\n"
              +"\t(Default = 500).",
              "N", 1,"-N <number of epochs>"));
    newVector.addElement(new Option(
              "\tPercentage size of validation set to use to terminate" +
              " training (if this is non zero it can pre-empt num of epochs).\n"
              +"\t(Value should be between 0 - 100, Default = 0).",
              "V", 1, "-V <percentage size of validation set>"));
    //the description and the value hint go on separate lines, like every
    //other option in this list
    newVector.addElement(new Option(
              "\tThe value used to seed the random number generator.\n"
              +"\t(Value should be >= 0 and a long, Default = 0).",
              "S", 1, "-S <seed>"));
    newVector.addElement(new Option(
              "\tThe consecutive number of errors allowed for validation" +
              " testing before the network terminates.\n"
              +"\t(Value should be > 0, Default = 20).",
              "E", 1, "-E <threshold for number of consecutive errors>"));
    newVector.addElement(new Option(
              "\tGUI will be opened.\n"
              +"\t(Use this to bring up a GUI).",
              "G", 0,"-G"));
    newVector.addElement(new Option(
              "\tAutocreation of the network connections will NOT be done.\n"
              +"\t(This will be ignored if -G is NOT set)",
              "A", 0,"-A"));
    newVector.addElement(new Option(
              "\tA NominalToBinary filter will NOT automatically be used.\n"
              +"\t(Set this to not use a NominalToBinary filter).",
              "B", 0,"-B"));
    newVector.addElement(new Option(
              "\tThe hidden layers to be created for the network.\n"
              +"\t(Value should be a list of comma separated Natural numbers" +
              " or the letters 'a' = (attribs + classes) / 2, 'i'" +
              " = attribs, 'o' = classes, 't' = attribs + classes)" +
              " For wildcard values" +
              ",Default = a).",
              "H", 1, "-H <comma separated numbers for nodes on each layer>"));
    newVector.addElement(new Option(
              "\tNormalizing a numeric class will NOT be done.\n"
              +"\t(Set this to not normalize the class if it's numeric).",
              "C", 0,"-C"));
    newVector.addElement(new Option(
              "\tNormalizing the attributes will NOT be done.\n"
              +"\t(Set this to not normalize the attributes).",
              "I", 0,"-I"));
    newVector.addElement(new Option(
              "\tResetting the network will NOT be allowed.\n"
              +"\t(Set this to not allow the network to reset).",
              "R", 0,"-R"));
    newVector.addElement(new Option(
              "\tLearning rate decay will occur.\n"
              +"\t(Set this to cause the learning rate to decay).",
              "D", 0,"-D"));


    return newVector.elements();
  }

  /**
   * Parses a given list of options and applies them through the
   * corresponding setters. Any option that is absent is reset to its
   * default, so the defaults below are defined by this method. Valid
   * options are:<p>
   *
   * -L num <br>
   * Learning rate. (default 0.3) <p>
   *
   * -M num <br>
   * Momentum. (default 0.2) <p>
   *
   * -N num <br>
   * Number of epochs to train through. (default 500) <p>
   *
   * -V num <br>
   * Percentage size of the validation set taken from the training data.
   * (default 0) <p>
   *
   * -S num <br>
   * Seed for the random number generator. (default 0) <p>
   *
   * -E num <br>
   * Threshold for the number of consecutive errors allowed during
   * validation testing. (default 20) <p>
   *
   * -H str <br>
   * Comma separated list with the number of nodes for each hidden layer;
   * the wildcards 'a', 'i', 'o', 't' are also accepted. (default a) <p>
   *
   * -G <br>
   * Bring up a GUI for the neural net. <p>
   *
   * -A <br>
   * Do not automatically create the connections in the net.
   * (can only be used if -G is specified) <p>
   *
   * -B <br>
   * Do not automatically preprocess the instances with a nominal to binary
   * filter. <p>
   *
   * -C <br>
   * Do not automatically normalize the class if it's numeric. <p>
   *
   * -I <br>
   * Do not automatically normalize the attributes. <p>
   *
   * -R <br>
   * Do not allow the network to be automatically reset. <p>
   *
   * -D <br>
   * Cause the learning rate to decay as training is done. <p>
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    //Valued options: an empty string means the option was not supplied,
    //in which case the default is applied.
    String given = Utils.getOption('L', options);
    setLearningRate(given.length() == 0 ? 0.3 : Double.parseDouble(given));

    given = Utils.getOption('M', options);
    setMomentum(given.length() == 0 ? 0.2 : Double.parseDouble(given));

    given = Utils.getOption('N', options);
    setTrainingTime(given.length() == 0 ? 500 : Integer.parseInt(given));

    given = Utils.getOption('V', options);
    setValidationSetSize(given.length() == 0 ? 0 : Integer.parseInt(given));

    given = Utils.getOption('S', options);
    setRandomSeed(given.length() == 0 ? 0L : Long.parseLong(given));

    given = Utils.getOption('E', options);
    setValidationThreshold(given.length() == 0 ? 20 : Integer.parseInt(given));

    given = Utils.getOption('H', options);
    setHiddenLayers(given.length() == 0 ? "a" : given);

    //The gui flag is applied ahead of the remaining flags: it is the only
    //option that can change the other options, so it has to be in place
    //before they are set.
    setGUI(Utils.getFlag('G', options));

    //These flags switch a feature OFF when present, hence the negation.
    setAutoBuild(!Utils.getFlag('A', options));
    setNominalToBinaryFilter(!Utils.getFlag('B', options));
    setNormalizeNumericClass(!Utils.getFlag('C', options));
    setNormalizeAttributes(!Utils.getFlag('I', options));
    setReset(!Utils.getFlag('R', options));

    //Decay is switched ON by its flag.
    setDecay(Utils.getFlag('D', options));

    Utils.checkForRemainingOptions(options);
  }
  
  /**
   * Gets the current settings of NeuralNet.
   *
   * The array always has 21 entries: seven valued options taking two
   * slots each, followed by whichever of the seven flags are active.
   * Unused trailing slots are filled with empty strings.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String [] getOptions() {

    final String[] result = new String[21];
    int next = 0;

    //valued options, always present
    result[next++] = "-L";
    result[next++] = String.valueOf(getLearningRate());
    result[next++] = "-M";
    result[next++] = String.valueOf(getMomentum());
    result[next++] = "-N";
    result[next++] = String.valueOf(getTrainingTime());
    result[next++] = "-V";
    result[next++] = String.valueOf(getValidationSetSize());
    result[next++] = "-S";
    result[next++] = String.valueOf(getRandomSeed());
    result[next++] = "-E";
    result[next++] = String.valueOf(getValidationThreshold());
    result[next++] = "-H";
    result[next++] = getHiddenLayers();

    //flags, emitted only when they differ from the default state
    if (getGUI()) {
      result[next++] = "-G";
    }
    if (!getAutoBuild()) {
      result[next++] = "-A";
    }
    if (!getNominalToBinaryFilter()) {
      result[next++] = "-B";
    }
    if (!getNormalizeNumericClass()) {
      result[next++] = "-C";
    }
    if (!getNormalizeAttributes()) {
      result[next++] = "-I";
    }
    if (!getReset()) {
      result[next++] = "-R";
    }
    if (getDecay()) {
      result[next++] = "-D";
    }

    //pad the remainder so no slot is left null
    for (; next < result.length; next++) {
      result[next] = "";
    }
    return result;
  }
  
  /**
   * @return string describing the model.
   */
  public String toString() {
    StringBuffer model = new StringBuffer(m_neuralNodes.length * 100); 
    //just a rough size guess
    NeuralNode con;
    double[] weights;
    NeuralConnection[] inputs;
    for (int noa = 0; noa < m_neuralNodes.len

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -