📄 thresholdselector.java
字号:
break;
      default:
        // Any selection mode not handled by an earlier case is rejected.
        throw new Exception("Unrecognized class value selection mode");
      }
      /*
        System.err.println("ThresholdSelector: Using mode=" 
        + TAGS_OPTIMIZE[m_ClassMode].getReadable());
        System.err.println("ThresholdSelector: Optimizing using class "
        + m_DesignatedClass + "/" 
        + instances.classAttribute().value(m_DesignatedClass));
      */

      // NOTE(review): stats.nominalCounts is presumably the per-class instance
      // count (the message below speaks of "positives") -- confirm where stats
      // is computed. When that count is exactly 1 the threshold is tuned on
      // the training data, whatever m_EvalMode is set to.
      if (stats.nominalCounts[m_DesignatedClass] == 1) {
        System.err.println("Only 1 positive found: optimizing on training data");
        findThreshold(getPredictions(instances, EVAL_TRAINING_SET, 0));
      } else {
        // The fold count is capped by the designated-class count.
        int numFolds = Math.min(m_NumXValFolds, stats.nominalCounts[m_DesignatedClass]);
        //System.err.println("Number of folds for threshold selector: " + numFolds);
        findThreshold(getPredictions(instances, m_EvalMode, numFolds));
        // NOTE(review): presumably getPredictions leaves m_Classifier trained
        // on only part of the data in the non-training modes, hence the
        // rebuild on the full set here -- confirm against getPredictions.
        if (m_EvalMode != EVAL_TRAINING_SET) {
          m_Classifier.buildClassifier(instances);
        }
      }
    }
  }

  /**
   * Checks whether an instance of the designated class is in the subset.
   * Scans the instances in order and stops at the first one whose class
   * value, cast to int, equals m_DesignatedClass.
   *
   * @param data the data to check for instance
   * @return true if such an instance is in the subset, false otherwise
   *         (including when the subset is empty)
   * @throws Exception if checking fails
   */
  private boolean checkForInstance(Instances data) throws Exception {

    for (int i = 0; i < data.numInstances(); i++) {
      // classValue() is cast to int before being compared with the
      // designated class index.
      if (((int)data.instance(i).classValue()) == m_DesignatedClass) {
        return true;
      }
    }
    return false;
  }

  /**
   * Calculates the class membership probabilities for the given test instance.
* * @param instance the instance to be classified * @return predicted class probability distribution * @throws Exception if instance could not be classified * successfully */ public double [] distributionForInstance(Instance instance) throws Exception { double [] pred = m_Classifier.distributionForInstance(instance); double prob = pred[m_DesignatedClass]; // Warp probability if (prob > m_BestThreshold) { prob = 0.5 + (prob - m_BestThreshold) / ((m_HighThreshold - m_BestThreshold) * 2); } else { prob = (prob - m_LowThreshold) / ((m_BestThreshold - m_LowThreshold) * 2); } if (prob < 0) { prob = 0.0; } else if (prob > 1) { prob = 1.0; } // Alter the distribution pred[m_DesignatedClass] = prob; if (pred.length == 2) { // Handle case when there's only one class pred[(m_DesignatedClass + 1) % 2] = 1.0 - prob; } return pred; } /** * @return a description of the classifier suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "A metaclassifier that selecting a mid-point threshold on the " + "probability output by a Classifier. The midpoint " + "threshold is set so that a given performance measure is optimized. " + "Currently this is the F-measure. Performance is measured either on " + "the training data, a hold-out set or using cross-validation. In " + "addition, the probabilities returned by the base learner can " + "have their range expanded so that the output probabilities will " + "reside between 0 and 1 (this is useful if the scheme normally " + "produces probabilities in a very narrow range)."; } /** * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String designatedClassTipText() { return "Sets the class value for which the optimization is performed. 
" + "The options are: pick the first class value; pick the second " + "class value; pick whichever class is least frequent; pick whichever " + "class value is most frequent; pick the first class named any of " + "\"yes\",\"pos(itive)\", \"1\", or the least frequent if no matches)."; } /** * Gets the method to determine which class value to optimize. Will * be one of OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ, OPTIMIZE_MFREQ, * OPTIMIZE_POS_NAME. * * @return the class selection mode. */ public SelectedTag getDesignatedClass() { return new SelectedTag(m_ClassMode, TAGS_OPTIMIZE); } /** * Sets the method to determine which class value to optimize. Will * be one of OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ, OPTIMIZE_MFREQ, * OPTIMIZE_POS_NAME. * * @param newMethod the new class selection mode. */ public void setDesignatedClass(SelectedTag newMethod) { if (newMethod.getTags() == TAGS_OPTIMIZE) { m_ClassMode = newMethod.getSelectedTag().getID(); } } /** * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String evaluationModeTipText() { return "Sets the method used to determine the threshold/performance " + "curve. The options are: perform optimization based on the entire " + "training set (may result in overfitting); perform an n-fold " + "cross-validation (may be time consuming); perform one fold of " + "an n-fold cross-validation (faster but likely less accurate)."; } /** * Sets the evaluation mode used. Will be one of * EVAL_TRAINING, EVAL_TUNED_SPLIT, or EVAL_CROSS_VALIDATION * * @param newMethod the new evaluation mode. */ public void setEvaluationMode(SelectedTag newMethod) { if (newMethod.getTags() == TAGS_EVAL) { m_EvalMode = newMethod.getSelectedTag().getID(); } } /** * Gets the evaluation mode used. Will be one of * EVAL_TRAINING, EVAL_TUNED_SPLIT, or EVAL_CROSS_VALIDATION * * @return the evaluation mode. 
*/ public SelectedTag getEvaluationMode() { return new SelectedTag(m_EvalMode, TAGS_EVAL); } /** * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String rangeCorrectionTipText() { return "Sets the type of prediction range correction performed. " + "The options are: do not do any range correction; " + "expand predicted probabilities so that the minimum probability " + "observed during the optimization maps to 0, and the maximum " + "maps to 1 (values outside this range are clipped to 0 and 1)."; } /** * Sets the confidence range correction mode used. Will be one of * RANGE_NONE, or RANGE_BOUNDS * * @param newMethod the new correciton mode. */ public void setRangeCorrection(SelectedTag newMethod) { if (newMethod.getTags() == TAGS_RANGE) { m_RangeMode = newMethod.getSelectedTag().getID(); } } /** * Gets the confidence range correction mode used. Will be one of * RANGE_NONE, or RANGE_BOUNDS * * @return the confidence correction mode. */ public SelectedTag getRangeCorrection() { return new SelectedTag(m_RangeMode, TAGS_RANGE); } /** * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String numXValFoldsTipText() { return "Sets the number of folds used during full cross-validation " + "and tuned fold evaluation. This number will be automatically " + "reduced if there are insufficient positive examples."; } /** * Get the number of folds used for cross-validation. * * @return the number of folds used for cross-validation. */ public int getNumXValFolds() { return m_NumXValFolds; } /** * Set the number of folds used for cross-validation. * * @param newNumFolds the number of folds used for cross-validation. */ public void setNumXValFolds(int newNumFolds) { if (newNumFolds < 2) { throw new IllegalArgumentException("Number of folds must be greater than 1"); } m_NumXValFolds = newNumFolds; } /** * Returns the type of graph this classifier * represents. 
*
   * Delegates to the base classifier when it implements Drawable and
   * reports Drawable.NOT_DRAWABLE otherwise.
   *
   * @return the type of graph this classifier represents
   */
  public int graphType() {

    return (m_Classifier instanceof Drawable)
      ? ((Drawable) m_Classifier).graphType()
      : Drawable.NOT_DRAWABLE;
  }

  /**
   * Returns graph describing the classifier (if possible), by delegating
   * to the base classifier when it implements Drawable.
   *
   * @return the graph of the classifier in dotty format
   * @throws Exception if the classifier cannot be graphed
   */
  public String graph() throws Exception {

    if (!(m_Classifier instanceof Drawable)) {
      throw new Exception("Classifier: " + getClassifierSpec()
                          + " cannot be graphed");
    }
    return ((Drawable) m_Classifier).graph();
  }

  /**
   * Returns a textual description of this threshold selector: base
   * classifier class name, designated class index, evaluation mode, chosen
   * threshold, best value, the expanded range (only in RANGE_BOUNDS mode),
   * the measure, and finally the base classifier's own description.
   *
   * @return description of the classifier as a string, or a short
   *         "no model" message while m_BestValue still equals
   *         -Double.MAX_VALUE
   */
  public String toString() {

    // -Double.MAX_VALUE in m_BestValue is treated as "nothing built yet".
    if (m_BestValue == -Double.MAX_VALUE) {
      return "ThresholdSelector: No model built yet.";
    }

    // Any mode other than the two tested here is reported as training-data
    // tuning.
    final String modeText;
    if (m_EvalMode == EVAL_CROSS_VALIDATION) {
      modeText = m_NumXValFolds + "-fold cross-validation";
    } else if (m_EvalMode == EVAL_TUNED_SPLIT) {
      modeText = "tuning on 1/" + m_NumXValFolds + " of the data";
    } else {
      modeText = "tuning on the training data";
    }

    final StringBuilder text = new StringBuilder();
    text.append("Threshold Selector.\n");
    text.append("Classifier: ").append(m_Classifier.getClass().getName()).append("\n");
    text.append("Index of designated class: ").append(m_DesignatedClass).append("\n");
    text.append("Evaluation mode: ").append(modeText).append("\n");
    text.append("Threshold: ").append(m_BestThreshold).append("\n");
    text.append("Best value: ").append(m_BestValue).append("\n");
    if (m_RangeMode == RANGE_BOUNDS) {
      text.append("Expanding range [").append(m_LowThreshold)
          .append(",").append(m_HighThreshold).append("] to [0, 1]\n");
    }
    text.append("Measure: ")
        .append(getMeasure().getSelectedTag().getReadable())
        .append("\n");
    text.append(m_Classifier.toString());
    return text.toString();
  }

  /**
   * Main method for testing this class: hands a fresh ThresholdSelector and
   * the command-line options to runClassifier.
   *
   * @param argv the options
   */
  public static void main(String[] argv) {

    final ThresholdSelector selector = new ThresholdSelector();
    runClassifier(selector, argv);
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -