📄 cvparameterselection.java

  /**
   * Parses a given list of options. Valid options are:<p>
   *
   * -D <br>
   * Turn on debugging output.<p>
   *
   * -W classname <br>
   * Specify the full class name of the classifier to perform cross-validation
   * selection on.<p>
   *
   * -X num <br>
   * Number of folds used for cross validation (default 10). <p>
   *
   * -S seed <br>
   * Random number seed (default 1).<p>
   *
   * -P "N 1 5 10" <br>
   * Sets an optimisation parameter for the classifier with name -N,
   * lower bound 1, upper bound 5, and 10 optimisation steps.
   * The upper bound may be the character 'A' or 'I' to substitute 
   * the number of attributes or instances in the training data,
   * respectively.
   * This parameter may be supplied more than once to optimise over
   * several classifier options simultaneously. <p>
   *
   * Options after -- are passed to the designated sub-classifier. <p>
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String foldsString = Utils.getOption('X', options);
    if (foldsString.length() != 0) {
      setNumFolds(Integer.parseInt(foldsString));
    } else {
      setNumFolds(10);
    }

    String cvParam;
    m_CVParams = new FastVector();
    do {
      cvParam = Utils.getOption('P', options);
      if (cvParam.length() != 0) {
	addCVParameter(cvParam);
      }
    } while (cvParam.length() != 0);

    super.setOptions(options);
  }
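
  // Illustrative sketch, not part of the original source: why the do/while
  // loop in setOptions() terminates. It assumes that Utils.getOption blanks out
  // a flag and its value once it has returned them. The loop above would never
  // end otherwise. Each call therefore finds the next -P, and the last call
  // returns "". OptionScanSketch and its methods are hypothetical names.
  // Static interface methods need Java 8 or later.
  interface OptionScanSketch {

    /** Returns the value after "-flag" and blanks both entries, or "" if absent. */
    static String takeOption(char flag, String[] options) {
      String key = "-" + flag;
      for (int i = 0; i < options.length; i++) {
        if (key.equals(options[i])) {
          if (i + 1 >= options.length) {
            throw new IllegalArgumentException("No value given for " + key);
          }
          String value = options[i + 1];
          options[i] = "";     // consume the flag
          options[i + 1] = ""; // and its value
          return value;
        }
      }
      return "";
    }

    /** Collects every -P value in order, mirroring the loop in setOptions(). */
    static java.util.List<String> takeAllParams(String[] options) {
      java.util.List<String> specs = new java.util.ArrayList<String>();
      String spec;
      do {
        spec = takeOption('P', options);
        if (spec.length() != 0) {
          specs.add(spec);
        }
      } while (spec.length() != 0);
      return specs;
    }
  }

  // Usage: for {"-P", "N 1 5 10", "-X", "5", "-P", "C 1 3 3"}, takeAllParams
  // returns ["N 1 5 10", "C 1 3 3"] and leaves only "-X" and "5" non-empty.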

  /**
   * Gets the current settings of the Classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String [] getOptions() {

    String[] superOptions;

    if (m_InitOptions != null) {
      try {
	m_Classifier.setOptions((String[])m_InitOptions.clone());
	superOptions = super.getOptions();
	m_Classifier.setOptions((String[])m_BestClassifierOptions.clone());
      } catch (Exception e) {
	throw new RuntimeException("CVParameterSelection: could not set options " +
				   "in getOptions().");
      } 
    } else {
      superOptions = super.getOptions();
    }
    String [] options = new String [superOptions.length + m_CVParams.size() * 2 + 2];

    int current = 0;
    for (int i = 0; i < m_CVParams.size(); i++) {
      options[current++] = "-P"; options[current++] = "" + getCVParameter(i);
    }
    options[current++] = "-X"; options[current++] = "" + getNumFolds();

    System.arraycopy(superOptions, 0, options, current, 
		     superOptions.length);

    return options;
  }

  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data 
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    if (instances.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
    }
    Instances trainData = new Instances(instances);
    trainData.deleteWithMissingClass();
    if (trainData.numInstances() == 0) {
      throw new IllegalArgumentException("No training instances without " +
					 "missing class.");
    }
    if (trainData.numInstances() < m_NumFolds) {
      throw new IllegalArgumentException("Number of training instances " +
					 "smaller than number of folds.");
    }
    if (!(m_Classifier instanceof OptionHandler)) {
      throw new IllegalArgumentException("Base classifier should be OptionHandler.");
    }
    m_InitOptions = ((OptionHandler)m_Classifier).getOptions();
    m_BestPerformance = -99;
    m_NumAttributes = trainData.numAttributes();
    Random random = new Random(m_Seed);
    trainData.randomize(random);
    m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0, random).numInstances();

    // Check whether there are any parameters to optimize
    if (m_CVParams.size() == 0) {
       m_Classifier.buildClassifier(trainData);
       m_BestClassifierOptions = m_InitOptions;
       return;
    }

    if (trainData.classAttribute().isNominal()) {
      trainData.stratify(m_NumFolds);
    }
    m_BestClassifierOptions = null;
    
    // Set up m_ClassifierOptions -- take getOptions() and remove
    // those being optimised.
    m_ClassifierOptions = ((OptionHandler)m_Classifier).getOptions();
    for (int i = 0; i < m_CVParams.size(); i++) {
      Utils.getOption(((CVParameter)m_CVParams.elementAt(i)).m_ParamChar,
		      m_ClassifierOptions);
    }
    findParamsByCrossValidation(0, trainData, random);

    String [] options = (String [])m_BestClassifierOptions.clone();
    ((OptionHandler)m_Classifier).setOptions(options);
    m_Classifier.buildClassifier(trainData);
  }
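
  // Illustrative sketch, not part of the original source: the value stored in
  // m_TrainFoldSize, which is the training part of fold 0. It assumes Weka's
  // usual split. Each test fold receives numInstances / numFolds instances, and
  // the first numInstances % numFolds folds receive one extra instance.
  // TrainFoldSizeSketch is a hypothetical name. Java 8 or later is needed.
  interface TrainFoldSizeSketch {

    /** Number of test instances in the given fold. */
    static int testFoldSize(int numInstances, int numFolds, int fold) {
      int size = numInstances / numFolds;
      if (fold < numInstances % numFolds) {
        size++;
      }
      return size;
    }

    /** Number of training instances for the given fold. */
    static int trainFoldSize(int numInstances, int numFolds, int fold) {
      // the upper limit mirrors the instance-count check in buildClassifier()
      if (numFolds < 2 || numFolds > numInstances) {
        throw new IllegalArgumentException("need 2 <= numFolds <= numInstances");
      }
      return numInstances - testFoldSize(numInstances, numFolds, fold);
    }
  }

  // Usage: with 23 instances and 10 folds, fold 0 tests on 3 instances and
  // trains on 20. Under the assumed split, an 'I' upper bound would resolve to 20.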


  /**
   * Predicts the class distribution for the given test instance.
   *
   * @param instance the instance to be classified
   * @return the predicted class distribution
   * @exception Exception if an error occurred during the prediction
   */
  public double[] distributionForInstance(Instance instance) throws Exception {
    
    return m_Classifier.distributionForInstance(instance);
  }

  /**
   * Adds a scheme parameter to the list of parameters to be set
   * by cross-validation.
   *
   * @param cvParam the string representation of a scheme parameter. The
   * format is: <br>
   * param_char lower_bound upper_bound number_of_steps <br>
   * eg to search a parameter -P from 1 to 10 in 10 steps: <br>
   * P 1 10 10 <br>
   * @exception Exception if the parameter specifier is of the wrong format
   */
  public void addCVParameter(String cvParam) throws Exception {

    CVParameter newCV = new CVParameter(cvParam);
    
    m_CVParams.addElement(newCV);
  }
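
  // Illustrative sketch, not part of the original source: splitting a
  // specification such as "P 1 10 10" into its four fields. The encoding of an
  // 'A' upper bound as lower - 1 and an 'I' upper bound as lower - 2 is an
  // assumption inferred from the switch in toString(). ParamSpecSketch and Spec
  // are hypothetical names, not Weka's CVParameter class. Java 8 or later is needed.
  interface ParamSpecSketch {

    /** Holder for one parsed specification. */
    final class Spec {
      char paramChar;
      double lower;
      double upper;
      int steps;
    }

    static Spec parse(String text) {
      String[] parts = text.trim().split("\\s+");
      if (parts.length != 4 || parts[0].length() != 1) {
        throw new IllegalArgumentException(
            "expected: param_char lower_bound upper_bound number_of_steps");
      }
      Spec spec = new Spec();
      spec.paramChar = parts[0].charAt(0);
      spec.lower = Double.parseDouble(parts[1]);
      if (parts[2].equals("A")) {
        spec.upper = spec.lower - 1; // placeholder for the number of attributes
      } else if (parts[2].equals("I")) {
        spec.upper = spec.lower - 2; // placeholder for the training-fold size
      } else {
        spec.upper = Double.parseDouble(parts[2]);
      }
      spec.steps = Integer.parseInt(parts[3]);
      return spec;
    }
  }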

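  /**
   * Gets the scheme parameter with the given index.
   *
   * @param index the index of the parameter
   * @return the parameter specification, or an empty string if index is
   * past the end of the list
   */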
  public String getCVParameter(int index) {

    if (m_CVParams.size() <= index) {
      return "";
    }
    return ((CVParameter)m_CVParams.elementAt(index)).toString();
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String CVParametersTipText() {
    return "Sets the scheme parameters which are to be set "+
	   "by cross-validation.\n"+
	   "The format for each string should be:\n"+
	   "param_char lower_bound upper_bound number_of_steps\n"+
	   "eg to search a parameter -P from 1 to 10 in 10 steps:\n"+
	   "    \"P 1 10 10\" ";
  }

  /**
   * Get method for CVParameters.
   *
   * @return the parameter specifications as strings
   */
  public Object[] getCVParameters() {

    Object[] CVParams = m_CVParams.toArray();
    String[] params = new String[CVParams.length];
    for (int i = 0; i < CVParams.length; i++) {
      params[i] = CVParams[i].toString();
    }
    return params;
  }
  
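  /**
   * Set method for CVParameters. If any specification is invalid, the
   * previous parameter list is restored and the exception is rethrown.
   *
   * @param params the parameter specifications as strings
   * @exception Exception if a specification is of the wrong format
   */
  public void setCVParameters(Object[] params) throws Exception {

    FastVector backup = m_CVParams;
    m_CVParams = new FastVector();

    for (int i = 0; i < params.length; i++) {
      try {
        addCVParameter((String)params[i]);
      } catch (Exception ex) {
        // restore the previous list so a bad entry leaves no partial state
        m_CVParams = backup;
        throw ex;
      }
    }
  }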

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numFoldsTipText() {
    return "The number of folds used for cross-validation.";
  }

  /** 
   * Gets the number of folds for the cross-validation.
   *
   * @return the number of folds for the cross-validation
   */
  public int getNumFolds() {

    return m_NumFolds;
  }

  /**
   * Sets the number of folds for the cross-validation.
   *
   * @param numFolds the number of folds for the cross-validation (at least 2)
   * @exception Exception if numFolds is less than 2
   */
  public void setNumFolds(int numFolds) throws Exception {
    
    if (numFolds < 2) {
      throw new IllegalArgumentException("CVParameterSelection: Number of " +
					 "cross-validation folds must be at least 2.");
    }
    m_NumFolds = numFolds;
  }
 
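  /**
   * Returns the type of graph this classifier represents.
   *
   * @return the graph type of the base classifier, or Drawable.NOT_DRAWABLE
   * if the base classifier is not Drawable
   */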
  public int graphType() {
    
    if (m_Classifier instanceof Drawable)
      return ((Drawable)m_Classifier).graphType();
    else 
      return Drawable.NOT_DRAWABLE;
  }

  /**
   * Returns graph describing the classifier (if possible).
   *
   * @return the graph of the classifier in dotty format
   * @exception Exception if the classifier cannot be graphed
   */
  public String graph() throws Exception {
    
    if (m_Classifier instanceof Drawable)
      return ((Drawable)m_Classifier).graph();
    else throw new Exception("Classifier: " + 
			     m_Classifier.getClass().getName() + " " +
			     Utils.joinOptions(m_BestClassifierOptions)
			     + " cannot be graphed");
  }

  /**
   * Returns description of the cross-validated classifier.
   *
   * @return description of the cross-validated classifier as a string
   */
  public String toString() {

    if (m_InitOptions == null)
      return "CVParameterSelection: No model built yet.";

    String result = "Cross-validated Parameter selection.\n"
    + "Classifier: " + m_Classifier.getClass().getName() + "\n";
    try {
      for (int i = 0; i < m_CVParams.size(); i++) {
	CVParameter cvParam = (CVParameter)m_CVParams.elementAt(i);
	result += "Cross-validation Parameter: '-" 
	  + cvParam.m_ParamChar + "'"
	  + " ranged from " + cvParam.m_Lower 
	  + " to ";
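	// lower - upper == 1 encodes an 'A' upper bound (number of attributes),
	// lower - upper == 2 encodes an 'I' upper bound (training-fold size)
	switch ((int)(cvParam.m_Lower - cvParam.m_Upper + 0.5)) {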
	case 1:
	  result += m_NumAttributes;
	  break;
	case 2:
	  result += m_TrainFoldSize;
	  break;
	default:
	  result += cvParam.m_Upper;
	  break;
	}
	result += " with " + cvParam.m_Steps + " steps\n";
      }
    } catch (Exception ex) {
      result += ex.getMessage();
    }
    result += "Classifier Options: "
      + Utils.joinOptions(m_BestClassifierOptions)
      + "\n\n" + m_Classifier.toString();
    return result;
  }
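
  // Illustrative sketch, not part of the original source: turning a stored
  // range into the concrete values a search could try. resolveUpper() mirrors
  // the switch in toString(). values() assumes the steps are evenly spaced,
  // with steps - 1 equal gaps from lower to upper. That spacing is an
  // assumption about how the steps field is used. CandidateValuesSketch is a
  // hypothetical name. Java 8 or later is needed.
  interface CandidateValuesSketch {

    /** Returns the real upper bound: 1 means 'A', 2 means 'I', else literal. */
    static double resolveUpper(double lower, double upper,
                               int numAttributes, int trainFoldSize) {
      switch ((int) (lower - upper + 0.5)) {
        case 1:
          return numAttributes;
        case 2:
          return trainFoldSize;
        default:
          return upper;
      }
    }

    /** Evenly spaced values from lower to upper inclusive. */
    static double[] values(double lower, double upper, int steps) {
      if (steps < 2) {
        return new double[] { lower };
      }
      double[] result = new double[steps];
      double increment = (upper - lower) / (steps - 1);
      for (int i = 0; i < steps; i++) {
        result[i] = lower + i * increment;
      }
      return result;
    }
  }

  // Usage: resolveUpper(1, 0, 14, 90) returns 14.0 ('A' with 14 attributes),
  // and values(1, 5, 5) returns {1.0, 2.0, 3.0, 4.0, 5.0}.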

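  /**
   * Returns a one-line summary of the selected options.
   *
   * @return the selected classifier options as a string
   */
  public String toSummaryString() {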

    String result = "Selected values: "
      + Utils.joinOptions(m_BestClassifierOptions);
    return result + '\n';
  }
  
  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {

    try {
      System.out.println(Evaluation.evaluateModel(new CVParameterSelection(), 
						  argv));
    } catch (Exception e) {
      System.err.println(e.getMessage());
    }
  }
}
