⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 wrappersubseteval.java

📁 一个数据挖掘系统的源码
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
   * Set the seed to use for cross validation
   *
   * @param s the seed
   */
  /**
   * Sets the random number seed used when randomizing the data for
   * cross validation.
   *
   * @param s the seed
   */
  public void setSeed (int s) {
    m_seed = s;
  }


  /**
   * Returns the random number seed used when randomizing the data for
   * cross validation.
   *
   * @return the current seed value
   */
  public int getSeed () {
    return m_seed;
  }

  /**
   * Returns the tip text for the classifier property.
   *
   * @return tip text describing this property, suitable for display
   * in the explorer/experimenter gui
   */
  public String classifierTipText() {
    return "Classifier to use for estimating the accuracy of subsets";
  }

  /**
   * Sets the base classifier used for estimating the accuracy of
   * attribute subsets.
   *
   * @param newClassifier the Classifier to use.
   */
  public void setClassifier (Classifier newClassifier) {
    m_BaseClassifier = newClassifier;
  }


  /**
   * Returns the classifier used as the base learner for accuracy
   * estimation.
   *
   * @return the base classifier
   */
  public Classifier getClassifier () {
    return m_BaseClassifier;
  }


  /**
   * Gets the current settings of WrapperSubsetEval.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String[] getOptions () {
    // Collect the base classifier's own options (if it exposes any) so
    // they can be appended after the "--" separator.
    String[] classifierOptions;
    if ((m_BaseClassifier != null)
        && (m_BaseClassifier instanceof OptionHandler)) {
      classifierOptions = ((OptionHandler) m_BaseClassifier).getOptions();
    } else {
      classifierOptions = new String[0];
    }

    // 9 slots for this evaluator's own flags, plus room for the
    // classifier's options.
    String[] options = new String[9 + classifierOptions.length];
    int current = 0;

    // -B <classifier class name> (only when a classifier has been set)
    if (getClassifier() != null) {
      options[current++] = "-B";
      options[current++] = getClassifier().getClass().getName();
    }

    // Cross-validation folds, repeat threshold and random seed.
    options[current++] = "-F";
    options[current++] = Integer.toString(getFolds());
    options[current++] = "-T";
    options[current++] = Double.toString(getThreshold());
    options[current++] = "-S";
    options[current++] = Integer.toString(getSeed());

    // Scheme-specific options follow the "--" separator.
    options[current++] = "--";
    System.arraycopy(classifierOptions, 0, options, current,
                     classifierOptions.length);
    current += classifierOptions.length;

    // Pad any unused slots with empty strings.
    while (current < options.length) {
      options[current++] = "";
    }

    return options;
  }


  /**
   * Resets all options to their default values and clears the cached
   * training data and evaluation state.
   */
  protected void resetOptions () {
    m_trainInstances = null;
    m_Evaluation = null;
    m_BaseClassifier = new ZeroR();  // default base learner: majority/mean predictor
    m_folds = 5;                     // default number of cross-validation folds
    m_seed = 1;                      // default random seed for randomizing the data
    m_threshold = 0.01;              // default std-dev/mean threshold for repeating CV
  }


  /**
   * Generates a attribute evaluator. Has to initialize all fields of the
   * evaluator that are not being set via options.
   *
   * @param data set of instances serving as training data
   * @exception Exception if the evaluator has not been
   * generated successfully
   */
  public void buildEvaluator (Instances data) throws Exception {
    // String attributes are not supported by this evaluator.
    if (data.checkForStringAttributes()) {
      throw new Exception("Can't handle string attributes!");
    }

    // Cache the training data and some frequently used statistics.
    m_trainInstances = data;
    m_classIndex = m_trainInstances.classIndex();
    m_numAttribs = m_trainInstances.numAttributes();
    m_numInstances = m_trainInstances.numInstances();
  }


  /**
   * Evaluates a subset of attributes by cross validating the base
   * classifier on the data reduced to that subset. Up to 5 repetitions
   * of cross validation are performed, stopping early once the error
   * estimate is stable (see repeat()).
   *
   * @param subset a bitset representing the attribute subset to be
   * evaluated
   * @return the negated average error rate (higher is better)
   * @exception Exception if the subset could not be evaluated
   */
  public double evaluateSubset (BitSet subset)
    throws Exception
  {
    double errorRate = 0;
    double[] repError = new double[5];
    int numAttributes = 0;
    int i, j;
    Random Rnd = new Random(m_seed);
    AttributeFilter delTransform = new AttributeFilter();
    // invert: the indices we pass are the attributes to KEEP
    delTransform.setInvertSelection(true);
    // copy the instances
    Instances trainCopy = new Instances(m_trainInstances);

    // count attributes set in the BitSet
    for (i = 0; i < m_numAttribs; i++) {
      if (subset.get(i)) {
        numAttributes++;
      }
    }

    // set up an array of attribute indexes for the filter (+1 for the class)
    int[] featArray = new int[numAttributes + 1];

    for (i = 0, j = 0; i < m_numAttribs; i++) {
      if (subset.get(i)) {
        featArray[j++] = i;
      }
    }

    featArray[j] = m_classIndex;
    delTransform.setAttributeIndicesArray(featArray);
    delTransform.setInputFormat(trainCopy);
    trainCopy = Filter.useFilter(trainCopy, delTransform);

    // max of 5 repetitions of cross validation
    for (i = 0; i < 5; i++) {
      trainCopy.randomize(Rnd); // randomize instances
      m_Evaluation = new Evaluation(trainCopy);
      m_Evaluation.crossValidateModel(m_BaseClassifier, trainCopy, m_folds);
      repError[i] = m_Evaluation.errorRate();

      // check on the standard deviation
      if (!repeat(repError, i + 1)) {
        // BUGFIX: advance i before breaking so the repetition just
        // computed is included in the average below (otherwise the
        // final CV run is silently discarded).
        i++;
        break;
      }
    }

    // average the error over the repetitions actually performed
    for (j = 0; j < i; j++) {
      errorRate += repError[j];
    }

    errorRate /= (double)i;
    return -errorRate;
  }


  /**
   * Returns a string describing the wrapper
   *
   * @return the description as a string
   */
  public String toString () {
    StringBuffer text = new StringBuffer();

    // Guard: nothing to describe before buildEvaluator() has run.
    if (m_trainInstances == null) {
      text.append("\tWrapper subset evaluator has not been built yet\n");
      return text.toString();
    }

    text.append("\tWrapper Subset Evaluator\n");
    text.append("\tLearning scheme: "
                + getClassifier().getClass().getName() + "\n");
    text.append("\tScheme options: ");

    if (m_BaseClassifier instanceof OptionHandler) {
      String[] schemeOpts = ((OptionHandler) m_BaseClassifier).getOptions();
      for (int k = 0; k < schemeOpts.length; k++) {
        text.append(schemeOpts[k] + " ");
      }
    }

    text.append("\n");

    // A numeric class means error is estimated as RMSE rather than
    // classification error.
    if (m_trainInstances.attribute(m_classIndex).isNumeric()) {
      text.append("\tAccuracy estimation: RMSE\n");
    } else {
      text.append("\tAccuracy estimation: classification error\n");
    }

    text.append("\tNumber of folds for accuracy estimation: "
                + m_folds
                + "\n");

    return text.toString();
  }


  /**
   * decides whether to do another repeat of cross validation. If the
   * standard deviation of the cross validations
   * is greater than threshold% of the mean (default 1%) then another
   * repeat is done.
   *
   * @param repError an array of cross validation results
   * @param entries the number of cross validations done so far
   * @return true if another cv is to be done
   */
  private boolean repeat (double[] repError, int entries) {
    // A single estimate is never enough — always repeat at least once.
    if (entries == 1) {
      return true;
    }

    double sum = 0;
    for (int k = 0; k < entries; k++) {
      sum += repError[k];
    }
    double mean = sum / (double) entries;

    double variance = 0;
    for (int k = 0; k < entries; k++) {
      double diff = repError[k] - mean;
      variance += diff * diff;
    }
    variance /= (double) entries;

    // Convert the (population) variance to a standard deviation.
    if (variance > 0) {
      variance = Math.sqrt(variance);
    }

    // Repeat while the relative spread still exceeds the threshold.
    return (variance / mean) > m_threshold;
  }


  /**
   * Main method for testing this class.
   *
   * @param args the options
   */
  public static void main (String[] args) {
    try {
      System.out.println(AttributeSelection.
			 SelectAttributes(new WrapperSubsetEval(), args));
    }
    catch (Exception e) {
      // BUGFIX: e.getStackTrace().toString() only prints the array's
      // identity hash ("[Ljava.lang.StackTraceElement;@..."), not the
      // trace. Pass the throwable to the logger so the full stack
      // trace is recorded along with the message.
      log.error(e.getMessage(), e);
    }
  }

}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -