⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 thresholdselector.java

📁 数据挖掘classifiers算法
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
    System.arraycopy(classifierOptions, 0, options, current, 		     classifierOptions.length);    current += classifierOptions.length;    while (current < options.length) {      options[current++] = "";    }    return options;  }  /**   * Generates the classifier.   *   * @param instances set of instances serving as training data    * @exception Exception if the classifier has not been generated successfully   */  public void buildClassifier(Instances instances)     throws Exception {    if (instances.numClasses() > 2) {      throw new UnsupportedClassTypeException("Only works for two-class datasets!");    }    if (!instances.classAttribute().isNominal()) {      throw new UnsupportedClassTypeException("Class attribute must be nominal!");    }    AttributeStats stats = instances.attributeStats(instances.classIndex());    m_BestThreshold = 0.5;    m_BestValue = MIN_VALUE;    m_HighThreshold = 1;    m_LowThreshold = 0;    // If data contains only one instance of positive data    // optimize on training data    if (stats.distinctCount != 2) {      System.err.println("Couldn't find examples of both classes. No adjustment.");      m_Classifier.buildClassifier(instances);    } else {            // Determine which class value to look for      switch (m_ClassMode) {      case OPTIMIZE_0:        m_DesignatedClass = 0;        break;      case OPTIMIZE_1:        m_DesignatedClass = 1;        break;      case OPTIMIZE_POS_NAME:        Attribute cAtt = instances.classAttribute();        boolean found = false;        for (int i = 0; i < cAtt.numValues() && !found; i++) {          String name = cAtt.value(i).toLowerCase();          if (name.startsWith("yes") || name.equals("1") ||               name.startsWith("pos")) {            found = true;            m_DesignatedClass = i;          }        }        if (found) {          break;        }        // No named class found, so fall through to default of least frequent      case OPTIMIZE_LFREQ:        m_DesignatedClass = (stats.nominalCounts[0] > stats.nominalCounts[1]) ? 1 : 0;        break;      case OPTIMIZE_MFREQ:        m_DesignatedClass = (stats.nominalCounts[0] > stats.nominalCounts[1]) ? 0 : 1;        break;      default:        throw new Exception("Unrecognized class value selection mode");      }            /*        System.err.println("ThresholdSelector: Using mode="         + TAGS_OPTIMIZE[m_ClassMode].getReadable());        System.err.println("ThresholdSelector: Optimizing using class "        + m_DesignatedClass + "/"         + instances.classAttribute().value(m_DesignatedClass));      */                  if (stats.nominalCounts[m_DesignatedClass] == 1) {        System.err.println("Only 1 positive found: optimizing on training data");        findThreshold(getPredictions(instances, EVAL_TRAINING_SET, 0));      } else {        int numFolds = Math.min(m_NumXValFolds, stats.nominalCounts[m_DesignatedClass]);        //System.err.println("Number of folds for threshold selector: " + numFolds);        findThreshold(getPredictions(instances, m_EvalMode, numFolds));        if (m_EvalMode != EVAL_TRAINING_SET) {          m_Classifier.buildClassifier(instances);        }      }    }  }  /**   * Checks whether instance of designated class is in subset.   */  private boolean checkForInstance(Instances data) throws Exception {    for (int i = 0; i < data.numInstances(); i++) {      if (((int)data.instance(i).classValue()) == m_DesignatedClass) {	return true;      }    }    return false;  }  /**   * Calculates the class membership probabilities for the given test instance.   *   * @param instance the instance to be classified   * @return predicted class probability distribution   * @exception Exception if instance could not be classified   * successfully   */  public double [] distributionForInstance(Instance instance)     throws Exception {        double [] pred = m_Classifier.distributionForInstance(instance);    double prob = pred[m_DesignatedClass];    // Warp probability    if (prob > m_BestThreshold) {      prob = 0.5 + (prob - m_BestThreshold) /         ((m_HighThreshold - m_BestThreshold) * 2);    } else {      prob = (prob - m_LowThreshold) /         ((m_BestThreshold - m_LowThreshold) * 2);    }    if (prob < 0) {      prob = 0.0;    } else if (prob > 1) {      prob = 1.0;    }    // Alter the distribution    pred[m_DesignatedClass] = prob;    if (pred.length == 2) { // Handle case when there's only one class      pred[(m_DesignatedClass + 1) % 2] = 1.0 - prob;    }    return pred;  }  /**   * @return a description of the classifier suitable for   * displaying in the explorer/experimenter gui   */  public String globalInfo() {    return "A metaclassifier that selecting a mid-point threshold on the "      + "probability output by a DistributionClassifier. The midpoint "      + "threshold is set so that a given performance measure is optimized. "      + "Currently this is the F-measure. Performance is measured either on "      + "the training data, a hold-out set or using cross-validation. In "      + "addition, the probabilities returned by the base learner can "      + "have their range expanded so that the output probabilities will "      + "reside between 0 and 1 (this is useful if the scheme normally "      + "produces probabilities in a very narrow range).";  }      /**   * @return tip text for this property suitable for   * displaying in the explorer/experimenter gui   */  public String designatedClassTipText() {    return "Sets the class value for which the optimization is performed. "      + "The options are: pick the first class value; pick the second "      + "class value; pick whichever class is least frequent; pick whichever "      + "class value is most frequent; pick the first class named any of "      + "\"yes\",\"pos(itive)\", \"1\", or the least frequent if no matches).";  }  /**   * Gets the method to determine which class value to optimize. Will   * be one of OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ, OPTIMIZE_MFREQ,   * OPTIMIZE_POS_NAME.   *   * @return the class selection mode.   */  public SelectedTag getDesignatedClass() {    return new SelectedTag(m_ClassMode, TAGS_OPTIMIZE);  }    /**   * Sets the method to determine which class value to optimize. Will   * be one of OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ, OPTIMIZE_MFREQ,   * OPTIMIZE_POS_NAME.   *   * @param newMethod the new class selection mode.   */  public void setDesignatedClass(SelectedTag newMethod) {        if (newMethod.getTags() == TAGS_OPTIMIZE) {      m_ClassMode = newMethod.getSelectedTag().getID();    }  }  /**   * @return tip text for this property suitable for   * displaying in the explorer/experimenter gui   */  public String evaluationModeTipText() {    return "Sets the method used to determine the threshold/performance "      + "curve. The options are: perform optimization based on the entire "      + "training set (may result in overfitting); perform an n-fold "      + "cross-validation (may be time consuming); perform one fold of "      + "an n-fold cross-validation (faster but likely less accurate).";  }  /**   * Sets the evaluation mode used. Will be one of   * EVAL_TRAINING, EVAL_TUNED_SPLIT, or EVAL_CROSS_VALIDATION   *   * @param newMethod the new evaluation mode.   */  public void setEvaluationMode(SelectedTag newMethod) {        if (newMethod.getTags() == TAGS_EVAL) {      m_EvalMode = newMethod.getSelectedTag().getID();    }  }  /**   * Gets the evaluation mode used. Will be one of   * EVAL_TRAINING, EVAL_TUNED_SPLIT, or EVAL_CROSS_VALIDATION   *   * @return the evaluation mode.   */  public SelectedTag getEvaluationMode() {    return new SelectedTag(m_EvalMode, TAGS_EVAL);  }  /**   * @return tip text for this property suitable for   * displaying in the explorer/experimenter gui   */  public String rangeCorrectionTipText() {    return "Sets the type of prediction range correction performed. "      + "The options are: do not do any range correction; "      + "expand predicted probabilities so that the minimum probability "      + "observed during the optimization maps to 0, and the maximum "      + "maps to 1 (values outside this range are clipped to 0 and 1).";  }  /**   * Sets the confidence range correction mode used. Will be one of   * RANGE_NONE, or RANGE_BOUNDS   *   * @param newMethod the new correciton mode.   */  public void setRangeCorrection(SelectedTag newMethod) {        if (newMethod.getTags() == TAGS_RANGE) {      m_RangeMode = newMethod.getSelectedTag().getID();    }  }  /**   * Gets the confidence range correction mode used. Will be one of   * RANGE_NONE, or RANGE_BOUNDS   *   * @return the confidence correction mode.   */  public SelectedTag getRangeCorrection() {    return new SelectedTag(m_RangeMode, TAGS_RANGE);  }    /**   * @return tip text for this property suitable for   * displaying in the explorer/experimenter gui   */  public String seedTipText() {    return "Sets the seed used for randomization. This is used when "      + "randomizing the data during optimization.";  }  /**   * Sets the seed for random number generation.   *   * @param seed the random number seed   */  public void setSeed(int seed) {        m_Seed = seed;  }  /**   * Gets the random number seed.   *    * @return the random number seed   */  public int getSeed() {    return m_Seed;  }  /**   * @return tip text for this property suitable for   * displaying in the explorer/experimenter gui   */  public String numXValFoldsTipText() {    return "Sets the number of folds used during full cross-validation "      + "and tuned fold evaluation. This number will be automatically "      + "reduced if there are insufficient positive examples.";  }  /**   * Get the number of folds used for cross-validation.   *   * @return the number of folds used for cross-validation.   */  public int getNumXValFolds() {        return m_NumXValFolds;  }    /**   * Set the number of folds used for cross-validation.   *   * @param newNumFolds the number of folds used for cross-validation.   */  public void setNumXValFolds(int newNumFolds) {        if (newNumFolds < 2) {      throw new IllegalArgumentException("Number of folds must be greater than 1");    }    m_NumXValFolds = newNumFolds;  }  /**   * @return tip text for this property suitable for   * displaying in the explorer/experimenter gui   */  public String distributionClassifierTipText() {    return "Sets the base DistributionClassifier to which the optimization "      + "will be made.";  }  /**   * Set the DistributionClassifier for which threshold is set.    *   * @param newClassifier the Classifier to use.   */  public void setDistributionClassifier(DistributionClassifier newClassifier) {    m_Classifier = newClassifier;  }  /**   * Get the DistributionClassifier used as the classifier.   *   * @return the classifier used as the classifier   */  public DistributionClassifier getDistributionClassifier() {    return m_Classifier;  }   /**   * Returns description of the cross-validated classifier.   *   * @return description of the cross-validated classifier as a string   */  public String toString() {    if (m_BestValue == -Double.MAX_VALUE)      return "ThresholdSelector: No model built yet.";    String result = "Threshold Selector.\n"    + "Classifier: " + m_Classifier.getClass().getName() + "\n";    result += "Index of designated class: " + m_DesignatedClass + "\n";    result += "Evaluation mode: ";    switch (m_EvalMode) {    case EVAL_CROSS_VALIDATION:      result += m_NumXValFolds + "-fold cross-validation";      break;    case EVAL_TUNED_SPLIT:      result += "tuning on 1/" + m_NumXValFolds + " of the data";      break;    case EVAL_TRAINING_SET:    default:      result += "tuning on the training data";    }    result += "\n";    result += "Threshold: " + m_BestThreshold + "\n";    result += "Best value: " + m_BestValue + "\n";    if (m_RangeMode == RANGE_BOUNDS) {      result += "Expanding range [" + m_LowThreshold + "," + m_HighThreshold        + "] to [0, 1]\n";    }    result += m_Classifier.toString();    return result;  }    /**   * Main method for testing this class.   *   * @param argv the options   */  public static void main(String [] argv) {    try {      System.out.println(Evaluation.evaluateModel(new ThresholdSelector(), 						  argv));    } catch (Exception e) {      System.err.println(e.getMessage());    }  }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -