📄 thresholdselector.java
字号:
throw new UnsupportedClassTypeException("Only works for two-class datasets!");
}
if (!instances.classAttribute().isNominal()) {
throw new UnsupportedClassTypeException("Class attribute must be nominal!");
}
AttributeStats stats = instances.attributeStats(instances.classIndex());
m_BestThreshold = 0.5;
m_BestValue = MIN_VALUE;
m_HighThreshold = 1;
m_LowThreshold = 0;
// If data contains only one instance of positive data
// optimize on training data
if (stats.distinctCount != 2) {
System.err.println("Couldn't find examples of both classes. No adjustment.");
m_Classifier.buildClassifier(instances);
} else {
// Determine which class value to look for
switch (m_ClassMode) {
case OPTIMIZE_0:
m_DesignatedClass = 0;
break;
case OPTIMIZE_1:
m_DesignatedClass = 1;
break;
case OPTIMIZE_POS_NAME:
Attribute cAtt = instances.classAttribute();
boolean found = false;
for (int i = 0; i < cAtt.numValues() && !found; i++) {
String name = cAtt.value(i).toLowerCase();
if (name.startsWith("yes") || name.equals("1") ||
name.startsWith("pos")) {
found = true;
m_DesignatedClass = i;
}
}
if (found) {
break;
}
// No named class found, so fall through to default of least frequent
case OPTIMIZE_LFREQ:
m_DesignatedClass = (stats.nominalCounts[0] > stats.nominalCounts[1]) ? 1 : 0;
break;
case OPTIMIZE_MFREQ:
m_DesignatedClass = (stats.nominalCounts[0] > stats.nominalCounts[1]) ? 0 : 1;
break;
default:
throw new Exception("Unrecognized class value selection mode");
}
/*
System.err.println("ThresholdSelector: Using mode="
+ TAGS_OPTIMIZE[m_ClassMode].getReadable());
System.err.println("ThresholdSelector: Optimizing using class "
+ m_DesignatedClass + "/"
+ instances.classAttribute().value(m_DesignatedClass));
*/
if (stats.nominalCounts[m_DesignatedClass] == 1) {
System.err.println("Only 1 positive found: optimizing on training data");
findThreshold(getPredictions(instances, EVAL_TRAINING_SET, 0));
} else {
int numFolds = Math.min(m_NumXValFolds, stats.nominalCounts[m_DesignatedClass]);
//System.err.println("Number of folds for threshold selector: " + numFolds);
findThreshold(getPredictions(instances, m_EvalMode, numFolds));
if (m_EvalMode != EVAL_TRAINING_SET) {
m_Classifier.buildClassifier(instances);
}
}
}
}
/**
 * Tests whether the given data contains at least one instance whose
 * class value, truncated to an int, equals the designated class index.
 *
 * @param data the instances to scan
 * @return true as soon as a matching instance is found, false if none is
 * @exception Exception declared by the signature; nothing in this method
 * throws it explicitly
 */
private boolean checkForInstance(Instances data) throws Exception {

  int idx = 0;
  while (idx < data.numInstances()) {
    int classIdx = (int) data.instance(idx).classValue();
    if (classIdx == m_DesignatedClass) {
      return true;
    }
    idx++;
  }
  return false;
}
/**
 * Calculates the class membership probabilities for the given test instance.
 *
 * The distribution is obtained from the base classifier and the probability
 * of the designated class is then warped piecewise-linearly so that
 * m_LowThreshold maps to 0, m_BestThreshold maps to 0.5 and m_HighThreshold
 * maps to 1. The warped value is clipped to [0, 1]. The array returned by
 * the base classifier is modified in place and returned.
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution
 * @exception Exception if instance could not be classified
 * successfully
 */
public double [] distributionForInstance(Instance instance)
  throws Exception {

  double [] pred = m_Classifier.distributionForInstance(instance);
  double prob = pred[m_DesignatedClass];

  // Warp probability
  if (prob > m_BestThreshold) {
    // If the upper segment has zero width the division yields +Infinity,
    // which the clipping below turns into 1.0.
    prob = 0.5 + (prob - m_BestThreshold) /
      ((m_HighThreshold - m_BestThreshold) * 2);
  } else {
    double lowerWidth = m_BestThreshold - m_LowThreshold;
    if (lowerWidth == 0 && prob == m_LowThreshold) {
      // Degenerate lower segment: the regular formula would compute 0/0 = NaN,
      // which passes both clipping comparisons and would end up in the
      // returned distribution. The regular formula gives 0.5 whenever
      // prob == m_BestThreshold, so use the same value here.
      prob = 0.5;
    } else {
      prob = (prob - m_LowThreshold) / (lowerWidth * 2);
    }
  }

  // Clip to a valid probability
  if (prob < 0) {
    prob = 0.0;
  } else if (prob > 1) {
    prob = 1.0;
  }

  // Alter the distribution
  pred[m_DesignatedClass] = prob;
  if (pred.length == 2) { // Handle case when there's only one class
    pred[(m_DesignatedClass + 1) % 2] = 1.0 - prob;
  }
  return pred;
}
/**
 * Returns the general description of this scheme.
 *
 * @return a description of the classifier suitable for
 * displaying in the explorer/experimenter gui
 */
public String globalInfo() {

  final String description =
    "A metaclassifier that selecting a mid-point threshold on the probability "
    + "output by a Classifier. "
    + "The midpoint threshold is set so that a given performance measure "
    + "is optimized. "
    + "Currently this is the F-measure. "
    + "Performance is measured either on the training data, a hold-out set "
    + "or using cross-validation. "
    + "In addition, the probabilities returned by the base learner can have "
    + "their range expanded so that the output probabilities will reside "
    + "between 0 and 1 (this is useful if the scheme normally produces "
    + "probabilities in a very narrow range).";
  return description;
}
/**
 * Returns the tip text for the designated-class property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String designatedClassTipText() {

  final String tip =
    "Sets the class value for which the optimization is performed. "
    + "The options are: "
    + "pick the first class value; "
    + "pick the second class value; "
    + "pick whichever class is least frequent; "
    + "pick whichever class value is most frequent; "
    + "pick the first class named any of \"yes\",\"pos(itive)\", \"1\", "
    + "or the least frequent if no matches).";
  return tip;
}
/**
 * Gets the currently configured way of choosing the class value to
 * optimize, wrapped in a tag over TAGS_OPTIMIZE. The doc of the original
 * lists the modes as OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ, OPTIMIZE_MFREQ
 * and OPTIMIZE_POS_NAME.
 *
 * @return the class selection mode.
 */
public SelectedTag getDesignatedClass() {

  return new SelectedTag(this.m_ClassMode, TAGS_OPTIMIZE);
}
/**
 * Sets the way the class value to optimize is chosen. Expected to be one
 * of OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ, OPTIMIZE_MFREQ,
 * OPTIMIZE_POS_NAME. A tag that was not built over the TAGS_OPTIMIZE
 * array is ignored and the current mode is kept.
 *
 * @param newMethod the new class selection mode.
 */
public void setDesignatedClass(SelectedTag newMethod) {

  // Identity comparison: only tags built from TAGS_OPTIMIZE itself qualify.
  if (newMethod.getTags() != TAGS_OPTIMIZE) {
    return;
  }
  m_ClassMode = newMethod.getSelectedTag().getID();
}
/**
 * Returns the tip text for the evaluation-mode property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String evaluationModeTipText() {

  final String tip =
    "Sets the method used to determine the threshold/performance curve. "
    + "The options are: "
    + "perform optimization based on the entire training set "
    + "(may result in overfitting); "
    + "perform an n-fold cross-validation (may be time consuming); "
    + "perform one fold of an n-fold cross-validation "
    + "(faster but likely less accurate).";
  return tip;
}
/**
 * Sets the evaluation mode used. Expected to be one of
 * EVAL_TRAINING_SET, EVAL_TUNED_SPLIT, or EVAL_CROSS_VALIDATION.
 * A tag that was not built over the TAGS_EVAL array is ignored and the
 * current mode is kept.
 *
 * @param newMethod the new evaluation mode.
 */
public void setEvaluationMode(SelectedTag newMethod) {

  // Identity comparison: only tags built from TAGS_EVAL itself qualify.
  if (newMethod.getTags() != TAGS_EVAL) {
    return;
  }
  m_EvalMode = newMethod.getSelectedTag().getID();
}
/**
 * Gets the evaluation mode used, wrapped in a tag over TAGS_EVAL.
 * Expected to be one of EVAL_TRAINING_SET, EVAL_TUNED_SPLIT, or
 * EVAL_CROSS_VALIDATION.
 *
 * @return the evaluation mode.
 */
public SelectedTag getEvaluationMode() {

  return new SelectedTag(this.m_EvalMode, TAGS_EVAL);
}
/**
 * Returns the tip text for the range-correction property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String rangeCorrectionTipText() {

  final String tip =
    "Sets the type of prediction range correction performed. The options are: "
    + "do not do any range correction; "
    + "expand predicted probabilities so that the minimum probability observed "
    + "during the optimization maps to 0, and the maximum maps to 1 "
    + "(values outside this range are clipped to 0 and 1).";
  return tip;
}
/**
 * Sets the confidence range correction mode used. Expected to be one of
 * RANGE_NONE, or RANGE_BOUNDS. A tag that was not built over the
 * TAGS_RANGE array is ignored and the current mode is kept.
 *
 * @param newMethod the new correction mode.
 */
public void setRangeCorrection(SelectedTag newMethod) {

  // Identity comparison: only tags built from TAGS_RANGE itself qualify.
  if (newMethod.getTags() != TAGS_RANGE) {
    return;
  }
  m_RangeMode = newMethod.getSelectedTag().getID();
}
/**
 * Gets the confidence range correction mode used, wrapped in a tag over
 * TAGS_RANGE. Expected to be one of RANGE_NONE, or RANGE_BOUNDS.
 *
 * @return the confidence correction mode.
 */
public SelectedTag getRangeCorrection() {

  return new SelectedTag(this.m_RangeMode, TAGS_RANGE);
}
/**
 * Returns the tip text for the seed property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String seedTipText() {

  final String tip =
    "Sets the seed used for randomization. "
    + "This is used when randomizing the data during optimization.";
  return tip;
}
/**
 * Stores the seed for random number generation. The value is not
 * validated.
 *
 * @param seed the random number seed
 */
public void setSeed(int seed) {

  this.m_Seed = seed;
}
/**
 * Returns the currently stored random number seed.
 *
 * @return the random number seed
 */
public int getSeed() {

  return this.m_Seed;
}
/**
 * Returns the tip text for the number-of-folds property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String numXValFoldsTipText() {

  final String tip =
    "Sets the number of folds used during full cross-validation and tuned "
    + "fold evaluation. "
    + "This number will be automatically reduced if there are insufficient "
    + "positive examples.";
  return tip;
}
/**
 * Returns the configured number of folds used for cross-validation.
 *
 * @return the number of folds used for cross-validation.
 */
public int getNumXValFolds() {

  return this.m_NumXValFolds;
}
/**
 * Stores the number of folds used for cross-validation after checking
 * that it is at least 2.
 *
 * @param newNumFolds the number of folds used for cross-validation.
 * @exception IllegalArgumentException if newNumFolds is less than 2; the
 * stored value is left unchanged in that case
 */
public void setNumXValFolds(int newNumFolds) {

  boolean tooFewFolds = newNumFolds <= 1;
  if (tooFewFolds) {
    throw new IllegalArgumentException("Number of folds must be greater than 1");
  }
  this.m_NumXValFolds = newNumFolds;
}
/**
 * Returns the tip text for the classifier property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String classifierTipText() {

  final String tip =
    "Sets the base Classifier to which the optimization will be made.";
  return tip;
}
/**
 * Stores the base Classifier whose threshold is to be selected. The
 * argument is stored as given, without a null check or a copy.
 *
 * @param newClassifier the Classifier to use.
 */
public void setClassifier(Classifier newClassifier) {

  this.m_Classifier = newClassifier;
}
/**
 * Returns the base Classifier currently stored in this object.
 *
 * @return the classifier used as the base classifier
 */
public Classifier getClassifier() {

  return this.m_Classifier;
}
/**
 * Builds the classifier specification string: the class name of the base
 * classifier, followed by a space and its joined options when the
 * classifier is an OptionHandler.
 *
 * @return the classifier string.
 */
protected String getClassifierSpec() {

  Classifier base = getClassifier();
  String className = base.getClass().getName();
  if (!(base instanceof OptionHandler)) {
    return className;
  }
  String [] opts = ((OptionHandler) base).getOptions();
  return className + " " + Utils.joinOptions(opts);
}
/**
 * Returns the type of graph this classifier represents, delegating to
 * the base classifier when it is Drawable.
 *
 * @return the base classifier's graph type, or Drawable.NOT_DRAWABLE if
 * the base classifier is not Drawable
 */
public int graphType() {

  return (m_Classifier instanceof Drawable)
    ? ((Drawable) m_Classifier).graphType()
    : Drawable.NOT_DRAWABLE;
}
/**
 * Returns graph describing the classifier (if possible), delegating to
 * the base classifier.
 *
 * @return the graph produced by the base classifier (dotty format
 * according to the original documentation)
 * @exception Exception if the base classifier is not Drawable, or if its
 * own graph() call fails
 */
public String graph() throws Exception {

  if (!(m_Classifier instanceof Drawable)) {
    throw new Exception("Classifier: " + getClassifierSpec()
                        + " cannot be graphed");
  }
  return ((Drawable) m_Classifier).graph();
}
/**
 * Returns a textual description of this threshold selector: base
 * classifier class, designated class index, evaluation mode, selected
 * threshold, best value, the expanded range (only in RANGE_BOUNDS mode)
 * and finally the base classifier's own description.
 *
 * @return description of the threshold selector as a string, or a
 * "no model built yet" message while m_BestValue still equals
 * -Double.MAX_VALUE (presumably its initial value; the field
 * declaration is not visible here)
 */
public String toString() {

  if (m_BestValue == -Double.MAX_VALUE) {
    return "ThresholdSelector: No model built yet.";
  }

  StringBuffer text = new StringBuffer();
  text.append("Threshold Selector.\n");
  text.append("Classifier: ").append(m_Classifier.getClass().getName())
    .append("\n");
  text.append("Index of designated class: ").append(m_DesignatedClass)
    .append("\n");

  text.append("Evaluation mode: ");
  if (m_EvalMode == EVAL_CROSS_VALIDATION) {
    text.append(m_NumXValFolds).append("-fold cross-validation");
  } else if (m_EvalMode == EVAL_TUNED_SPLIT) {
    text.append("tuning on 1/").append(m_NumXValFolds).append(" of the data");
  } else {
    // EVAL_TRAINING_SET and any unrecognised mode share this description
    text.append("tuning on the training data");
  }
  text.append("\n");

  text.append("Threshold: ").append(m_BestThreshold).append("\n");
  text.append("Best value: ").append(m_BestValue).append("\n");
  if (m_RangeMode == RANGE_BOUNDS) {
    text.append("Expanding range [").append(m_LowThreshold).append(",")
      .append(m_HighThreshold).append("] to [0, 1]\n");
  }
  text.append(m_Classifier.toString());
  return text.toString();
}
/**
 * Main method for testing this class. Evaluates a fresh ThresholdSelector
 * with the given options and prints the result to standard output. If the
 * evaluation throws, only the exception's message is printed to standard
 * error.
 *
 * @param argv the options
 */
public static void main(String [] argv) {

  try {
    String report = Evaluation.evaluateModel(new ThresholdSelector(), argv);
    System.out.println(report);
  } catch (Exception ex) {
    System.err.println(ex.getMessage());
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -