⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 thresholdselector.java

📁 一个数据挖掘软件ALPHAMINERR的整个过程的JAVA版源代码
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
      throw new UnsupportedClassTypeException("Only works for two-class datasets!");
    }
    if (!instances.classAttribute().isNominal()) {
      throw new UnsupportedClassTypeException("Class attribute must be nominal!");
    }
    AttributeStats stats = instances.attributeStats(instances.classIndex());
    m_BestThreshold = 0.5;
    m_BestValue = MIN_VALUE;
    m_HighThreshold = 1;
    m_LowThreshold = 0;
    // If data contains only one instance of positive data
    // optimize on training data
    if (stats.distinctCount != 2) {
      System.err.println("Couldn't find examples of both classes. No adjustment.");
      m_Classifier.buildClassifier(instances);
    } else {
      
      // Determine which class value to look for
      switch (m_ClassMode) {
      case OPTIMIZE_0:
        m_DesignatedClass = 0;
        break;
      case OPTIMIZE_1:
        m_DesignatedClass = 1;
        break;
      case OPTIMIZE_POS_NAME:
        Attribute cAtt = instances.classAttribute();
        boolean found = false;
        for (int i = 0; i < cAtt.numValues() && !found; i++) {
          String name = cAtt.value(i).toLowerCase();
          if (name.startsWith("yes") || name.equals("1") || 
              name.startsWith("pos")) {
            found = true;
            m_DesignatedClass = i;
          }
        }
        if (found) {
          break;
        }
        // No named class found, so fall through to default of least frequent
      case OPTIMIZE_LFREQ:
        m_DesignatedClass = (stats.nominalCounts[0] > stats.nominalCounts[1]) ? 1 : 0;
        break;
      case OPTIMIZE_MFREQ:
        m_DesignatedClass = (stats.nominalCounts[0] > stats.nominalCounts[1]) ? 0 : 1;
        break;
      default:
        throw new Exception("Unrecognized class value selection mode");
      }
      
      /*
        System.err.println("ThresholdSelector: Using mode=" 
        + TAGS_OPTIMIZE[m_ClassMode].getReadable());
        System.err.println("ThresholdSelector: Optimizing using class "
        + m_DesignatedClass + "/" 
        + instances.classAttribute().value(m_DesignatedClass));
      */
      
      
      if (stats.nominalCounts[m_DesignatedClass] == 1) {
        System.err.println("Only 1 positive found: optimizing on training data");
        findThreshold(getPredictions(instances, EVAL_TRAINING_SET, 0));
      } else {
        int numFolds = Math.min(m_NumXValFolds, stats.nominalCounts[m_DesignatedClass]);
        //System.err.println("Number of folds for threshold selector: " + numFolds);
        findThreshold(getPredictions(instances, m_EvalMode, numFolds));
        if (m_EvalMode != EVAL_TRAINING_SET) {
          m_Classifier.buildClassifier(instances);
        }
      }
    }
  }

  /**
   * Reports whether the given data set contains at least one instance whose
   * class value equals the designated class index ({@code m_DesignatedClass}).
   *
   * @param data the instances to scan
   * @return true as soon as a matching instance is found, false otherwise
   * @exception Exception declared by the original signature
   */
  private boolean checkForInstance(Instances data) throws Exception {

    int total = data.numInstances();
    int idx = 0;
    while (idx < total) {
      int cls = (int) data.instance(idx).classValue();
      if (cls == m_DesignatedClass) {
        return true;
      }
      idx++;
    }
    return false;
  }


  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * The base classifier's probability for the designated class is warped so
   * that {@code m_BestThreshold} maps to 0.5, {@code m_LowThreshold} maps to 0
   * and {@code m_HighThreshold} maps to 1. The result is clipped to [0, 1].
   * In the two-class case the other class receives the complement.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if instance could not be classified
   * successfully
   */
  public double [] distributionForInstance(Instance instance) 
    throws Exception {
    
    double [] pred = m_Classifier.distributionForInstance(instance);
    double prob = pred[m_DesignatedClass];

    // Warp probability. Nothing guarantees that the threshold lies strictly
    // inside [low, high]: a zero span would give 0/0 = NaN (which the clipping
    // below does not catch), and a negative span would push the value to the
    // wrong side of 0.5. Degenerate spans are therefore handled explicitly.
    if (prob > m_BestThreshold) {
      double upperSpan = m_HighThreshold - m_BestThreshold;
      if (upperSpan > 0) {
        prob = 0.5 + (prob - m_BestThreshold) / (upperSpan * 2);
      } else {
        // Above the threshold with no room left: fully the designated class.
        prob = 1.0;
      }
    } else {
      double lowerSpan = m_BestThreshold - m_LowThreshold;
      if (lowerSpan > 0) {
        prob = (prob - m_LowThreshold) / (lowerSpan * 2);
      } else {
        // Exactly on the threshold maps to 0.5 (as in the regular case);
        // strictly below it maps to 0.
        prob = (prob < m_BestThreshold) ? 0.0 : 0.5;
      }
    }
    if (prob < 0) {
      prob = 0.0;
    } else if (prob > 1) {
      prob = 1.0;
    }

    // Alter the distribution
    pred[m_DesignatedClass] = prob;
    if (pred.length == 2) { // Two classes: keep the distribution summing to 1
      pred[(m_DesignatedClass + 1) % 2] = 1.0 - prob;
    }
    return pred;
  }

  /**
   * Returns a short description of this meta-classifier for display in the
   * explorer/experimenter GUI.
   *
   * @return a description of the classifier suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {

    // Fixed grammar in the user-visible text ("that selecting" -> "that selects").
    return "A metaclassifier that selects a mid-point threshold on the "
      + "probability output by a Classifier. The midpoint "
      + "threshold is set so that a given performance measure is optimized. "
      + "Currently this is the F-measure. Performance is measured either on "
      + "the training data, a hold-out set or using cross-validation. In "
      + "addition, the probabilities returned by the base learner can "
      + "have their range expanded so that the output probabilities will "
      + "reside between 0 and 1 (this is useful if the scheme normally "
      + "produces probabilities in a very narrow range).";
  }
    
  /**
   * Returns the GUI tip text for the designated-class property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String designatedClassTipText() {

    // Removed a stray unmatched ")" and added the missing space after "yes",.
    return "Sets the class value for which the optimization is performed. "
      + "The options are: pick the first class value; pick the second "
      + "class value; pick whichever class is least frequent; pick whichever "
      + "class value is most frequent; pick the first class named any of "
      + "\"yes\", \"pos(itive)\", \"1\", or the least frequent if no matches.";
  }

  /**
   * Returns the strategy used to choose which class value is optimized, as a
   * tag from TAGS_OPTIMIZE (OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ,
   * OPTIMIZE_MFREQ or OPTIMIZE_POS_NAME).
   *
   * @return the current class selection mode wrapped in a SelectedTag
   */
  public SelectedTag getDesignatedClass() {
    SelectedTag current = new SelectedTag(m_ClassMode, TAGS_OPTIMIZE);
    return current;
  }
  
  /**
   * Changes the strategy used to choose which class value is optimized.
   * The tag must come from TAGS_OPTIMIZE (OPTIMIZE_0, OPTIMIZE_1,
   * OPTIMIZE_LFREQ, OPTIMIZE_MFREQ or OPTIMIZE_POS_NAME); tags from any
   * other tag set are silently ignored.
   *
   * @param newMethod the new class selection mode.
   */
  public void setDesignatedClass(SelectedTag newMethod) {
    if (newMethod.getTags() != TAGS_OPTIMIZE) {
      return;
    }
    m_ClassMode = newMethod.getSelectedTag().getID();
  }

  /**
   * Returns the GUI tip text for the evaluation-mode property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String evaluationModeTipText() {
    String tip = "Sets the method used to determine the threshold/performance "
      + "curve. The options are: perform optimization based on the entire "
      + "training set (may result in overfitting); perform an n-fold "
      + "cross-validation (may be time consuming); perform one fold of "
      + "an n-fold cross-validation (faster but likely less accurate).";
    return tip;
  }

  /**
   * Changes how the threshold/performance curve is evaluated. The tag must
   * come from TAGS_EVAL (EVAL_TRAINING_SET, EVAL_TUNED_SPLIT or
   * EVAL_CROSS_VALIDATION); tags from any other tag set are ignored.
   *
   * @param newMethod the new evaluation mode.
   */
  public void setEvaluationMode(SelectedTag newMethod) {
    if (newMethod.getTags() != TAGS_EVAL) {
      return;
    }
    m_EvalMode = newMethod.getSelectedTag().getID();
  }

  /**
   * Returns the current evaluation mode as a tag from TAGS_EVAL
   * (EVAL_TRAINING_SET, EVAL_TUNED_SPLIT or EVAL_CROSS_VALIDATION).
   *
   * @return the evaluation mode.
   */
  public SelectedTag getEvaluationMode() {
    SelectedTag mode = new SelectedTag(m_EvalMode, TAGS_EVAL);
    return mode;
  }

  /**
   * Returns the GUI tip text for the range-correction property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String rangeCorrectionTipText() {
    String tip = "Sets the type of prediction range correction performed. "
      + "The options are: do not do any range correction; "
      + "expand predicted probabilities so that the minimum probability "
      + "observed during the optimization maps to 0, and the maximum "
      + "maps to 1 (values outside this range are clipped to 0 and 1).";
    return tip;
  }

  /**
   * Changes the confidence range correction mode. The tag must come from
   * TAGS_RANGE (RANGE_NONE or RANGE_BOUNDS); tags from any other tag set
   * are ignored.
   *
   * @param newMethod the new correction mode.
   */
  public void setRangeCorrection(SelectedTag newMethod) {
    if (newMethod.getTags() != TAGS_RANGE) {
      return;
    }
    m_RangeMode = newMethod.getSelectedTag().getID();
  }

  /**
   * Returns the confidence range correction mode as a tag from TAGS_RANGE
   * (RANGE_NONE or RANGE_BOUNDS).
   *
   * @return the confidence correction mode.
   */
  public SelectedTag getRangeCorrection() {
    SelectedTag mode = new SelectedTag(m_RangeMode, TAGS_RANGE);
    return mode;
  }
  
  /**
   * Returns the GUI tip text for the seed property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String seedTipText() {
    String tip = "Sets the seed used for randomization. This is used when "
      + "randomizing the data during optimization.";
    return tip;
  }

  /**
   * Stores the seed to use for random number generation.
   *
   * @param seed the random number seed
   */
  public void setSeed(int seed) {
    this.m_Seed = seed;
  }

  /**
   * Returns the currently configured random number seed.
   *
   * @return the random number seed
   */
  public int getSeed() {
    return this.m_Seed;
  }

  /**
   * Returns the GUI tip text for the number-of-folds property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numXValFoldsTipText() {
    String tip = "Sets the number of folds used during full cross-validation "
      + "and tuned fold evaluation. This number will be automatically "
      + "reduced if there are insufficient positive examples.";
    return tip;
  }

  /**
   * Returns the configured number of cross-validation folds.
   *
   * @return the number of folds used for cross-validation.
   */
  public int getNumXValFolds() {
    return this.m_NumXValFolds;
  }
  
  /**
   * Sets the number of folds used for cross-validation.
   *
   * @param newNumFolds the number of folds used for cross-validation.
   * @throws IllegalArgumentException if newNumFolds is less than 2
   */
  public void setNumXValFolds(int newNumFolds) {
    boolean tooFew = newNumFolds <= 1;
    if (tooFew) {
      throw new IllegalArgumentException("Number of folds must be greater than 1");
    }
    this.m_NumXValFolds = newNumFolds;
  }

  /**
   * Returns the GUI tip text for the base-classifier property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String classifierTipText() {
    String tip = "Sets the base Classifier to which the optimization "
      + "will be made.";
    return tip;
  }

  /**
   * Replaces the base classifier whose output probabilities are thresholded.
   *
   * @param newClassifier the Classifier to use.
   */
  public void setClassifier(Classifier newClassifier) {
    this.m_Classifier = newClassifier;
  }

  /**
   * Returns the base classifier whose output probabilities are thresholded.
   *
   * @return the classifier used as the classifier
   */
  public Classifier getClassifier() {
    return this.m_Classifier;
  }
 
  /**
   * Builds a specification string for the base classifier: its fully
   * qualified class name, followed by its joined options when it implements
   * OptionHandler.
   *
   * @return the classifier string.
   */
  protected String getClassifierSpec() {
    Classifier base = getClassifier();
    String className = base.getClass().getName();
    if (!(base instanceof OptionHandler)) {
      return className;
    }
    String[] opts = ((OptionHandler) base).getOptions();
    return className + " " + Utils.joinOptions(opts);
  }

  /**
   * Returns the graph type of the base classifier when it is Drawable,
   * otherwise Drawable.NOT_DRAWABLE.
   *
   * @return the graph type constant
   */
  public int graphType() {
    return (m_Classifier instanceof Drawable)
      ? ((Drawable) m_Classifier).graphType()
      : Drawable.NOT_DRAWABLE;
  }

  /**
   * Delegates graph generation to the base classifier when it is Drawable.
   *
   * @return the graph of the classifier in dotty format
   * @exception Exception if the base classifier is not Drawable
   */
  public String graph() throws Exception {
    if (!(m_Classifier instanceof Drawable)) {
      throw new Exception("Classifier: " + getClassifierSpec()
                          + " cannot be graphed");
    }
    return ((Drawable) m_Classifier).graph();
  }
 
  /**
   * Returns a textual summary of the threshold selector: the base classifier,
   * the designated class index, the evaluation mode, the chosen threshold and
   * its value, the expanded range (when RANGE_BOUNDS is active), followed by
   * the base classifier's own description. If m_BestValue still holds its
   * -Double.MAX_VALUE sentinel, a "no model built" message is returned instead.
   *
   * @return description of the threshold selector as a string
   */
  public String toString() {

    if (m_BestValue == -Double.MAX_VALUE) {
      return "ThresholdSelector: No model built yet.";
    }

    StringBuffer text = new StringBuffer();
    text.append("Threshold Selector.\n");
    text.append("Classifier: ").append(m_Classifier.getClass().getName()).append("\n");
    text.append("Index of designated class: ").append(m_DesignatedClass).append("\n");

    text.append("Evaluation mode: ");
    if (m_EvalMode == EVAL_CROSS_VALIDATION) {
      text.append(m_NumXValFolds).append("-fold cross-validation");
    } else if (m_EvalMode == EVAL_TUNED_SPLIT) {
      text.append("tuning on 1/").append(m_NumXValFolds).append(" of the data");
    } else {
      // EVAL_TRAINING_SET and any unrecognized mode
      text.append("tuning on the training data");
    }
    text.append("\n");

    text.append("Threshold: ").append(m_BestThreshold).append("\n");
    text.append("Best value: ").append(m_BestValue).append("\n");
    if (m_RangeMode == RANGE_BOUNDS) {
      text.append("Expanding range [").append(m_LowThreshold).append(",")
          .append(m_HighThreshold).append("] to [0, 1]\n");
    }
    text.append(m_Classifier.toString());
    return text.toString();
  }
  
  /**
   * Command-line entry point: evaluates a fresh ThresholdSelector with the
   * given options and prints the result. Any failure is reported by printing
   * the exception message to standard error.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {
    try {
      String report = Evaluation.evaluateModel(new ThresholdSelector(), argv);
      System.out.println(report);
    } catch (Exception e) {
      System.err.println(e.getMessage());
    }
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -