📄 thresholdselector.java
字号:
break; case TRUE_POS: index1 = curve.attribute(ThresholdCurve.TRUE_POS_NAME).index(); maxValue = maxInst.value(index1); break; case TRUE_NEG: index1 = curve.attribute(ThresholdCurve.TRUE_NEG_NAME).index(); maxValue = maxInst.value(index1); break; case TP_RATE: index1 = curve.attribute(ThresholdCurve.TP_RATE_NAME).index(); maxValue = maxInst.value(index1); break; case PRECISION: index1 = curve.attribute(ThresholdCurve.PRECISION_NAME).index(); maxValue = maxInst.value(index1); break; case RECALL: index1 = curve.attribute(ThresholdCurve.RECALL_NAME).index(); maxValue = maxInst.value(index1); break; case ACCURACY: index1 = curve.attribute(ThresholdCurve.TRUE_POS_NAME).index(); index2 = curve.attribute(ThresholdCurve.TRUE_NEG_NAME).index(); maxValue = maxInst.value(index1) + maxInst.value(index2); break; } int indexThreshold = curve.attribute(ThresholdCurve.THRESHOLD_NAME).index(); for (int i = 1; i < curve.numInstances(); i++) { Instance current = curve.instance(i); double currentValue = 0; if (m_nMeasure == ACCURACY) { currentValue= current.value(index1) + current.value(index2); } else { currentValue= current.value(index1); } if (currentValue> maxValue) { maxInst = current; maxValue = currentValue; } if (m_RangeMode == RANGE_BOUNDS) { double thresh = current.value(indexThreshold); if (thresh < low) { low = thresh; } if (thresh > high) { high = thresh; } } } if (maxValue > MIN_VALUE) { m_BestThreshold = maxInst.value(indexThreshold); m_BestValue = maxValue; //System.err.println("maxFM: " + maxFM); } if (m_RangeMode == RANGE_BOUNDS) { m_LowThreshold = low; m_HighThreshold = high; //System.err.println("Threshold range: " + low + " - " + high); } } } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ public Enumeration listOptions() { Vector newVector = new Vector(5); newVector.addElement(new Option( "\tThe class for which threshold is determined. Valid values are:\n" + "\t1, 2 (for first and second classes, respectively), 3 (for whichever\n" + "\tclass is least frequent), and 4 (for whichever class value is most\n" + "\tfrequent), and 5 (for the first class named any of \"yes\",\"pos(itive)\"\n" + "\t\"1\", or method 3 if no matches). (default 5).", "C", 1, "-C <integer>")); newVector.addElement(new Option( "\tNumber of folds used for cross validation. If just a\n" + "\thold-out set is used, this determines the size of the hold-out set\n" + "\t(default 3).", "X", 1, "-X <number of folds>")); newVector.addElement(new Option( "\tSets whether confidence range correction is applied. This\n" + "\tcan be used to ensure the confidences range from 0 to 1.\n" + "\tUse 0 for no range correction, 1 for correction based on\n" + "\tthe min/max values seen during threshold selection\n"+ "\t(default 0).", "R", 1, "-R <integer>")); newVector.addElement(new Option( "\tSets the evaluation mode. Use 0 for\n" + "\tevaluation using cross-validation,\n" + "\t1 for evaluation using hold-out set,\n" + "\tand 2 for evaluation on the\n" + "\ttraining data (default 1).", "E", 1, "-E <integer>")); newVector.addElement(new Option( "\tMeasure used for evaluation (default is FMEASURE).\n", "M", 1, "-M [FMEASURE|ACCURACY|TRUE_POS|TRUE_NEG|TP_RATE|PRECISION|RECALL]")); Enumeration enu = super.listOptions(); while (enu.hasMoreElements()) { newVector.addElement(enu.nextElement()); } return newVector.elements(); } /** * Parses a given list of options. <p/> * <!-- options-start --> * Valid options are: <p/> * * <pre> -C <integer> * The class for which threshold is determined. Valid values are: * 1, 2 (for first and second classes, respectively), 3 (for whichever * class is least frequent), and 4 (for whichever class value is most * frequent), and 5 (for the first class named any of "yes","pos(itive)" * "1", or method 3 if no matches). (default 5).</pre> * * <pre> -X <number of folds> * Number of folds used for cross validation. If just a * hold-out set is used, this determines the size of the hold-out set * (default 3).</pre> * * <pre> -R <integer> * Sets whether confidence range correction is applied. This * can be used to ensure the confidences range from 0 to 1. * Use 0 for no range correction, 1 for correction based on * the min/max values seen during threshold selection * (default 0).</pre> * * <pre> -E <integer> * Sets the evaluation mode. Use 0 for * evaluation using cross-validation, * 1 for evaluation using hold-out set, * and 2 for evaluation on the * training data (default 1).</pre> * * <pre> -M [FMEASURE|ACCURACY|TRUE_POS|TRUE_NEG|TP_RATE|PRECISION|RECALL] * Measure used for evaluation (default is FMEASURE). * </pre> * * <pre> -S <num> * Random number seed. * (default 1)</pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * * <pre> -W * Full name of base classifier. * (default: weka.classifiers.functions.Logistic)</pre> * * <pre> * Options specific to classifier weka.classifiers.functions.Logistic: * </pre> * * <pre> -D * Turn on debugging output.</pre> * * <pre> -R <ridge> * Set the ridge in the log-likelihood.</pre> * * <pre> -M <number> * Set the maximum number of iterations (default -1, until convergence).</pre> * <!-- options-end --> * * Options after -- are passed to the designated sub-classifier. <p> * * @param options the list of options as an array of strings * @throws Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { String classString = Utils.getOption('C', options); if (classString.length() != 0) { setDesignatedClass(new SelectedTag(Integer.parseInt(classString) - 1, TAGS_OPTIMIZE)); } else { setDesignatedClass(new SelectedTag(OPTIMIZE_LFREQ, TAGS_OPTIMIZE)); } String modeString = Utils.getOption('E', options); if (modeString.length() != 0) { setEvaluationMode(new SelectedTag(Integer.parseInt(modeString), TAGS_EVAL)); } else { setEvaluationMode(new SelectedTag(EVAL_TUNED_SPLIT, TAGS_EVAL)); } String rangeString = Utils.getOption('R', options); if (rangeString.length() != 0) { setRangeCorrection(new SelectedTag(Integer.parseInt(rangeString), TAGS_RANGE)); } else { setRangeCorrection(new SelectedTag(RANGE_NONE, TAGS_RANGE)); } String measureString = Utils.getOption('M', options); if (measureString.length() != 0) { setMeasure(new SelectedTag(measureString, TAGS_MEASURE)); } else { setMeasure(new SelectedTag(FMEASURE, TAGS_MEASURE)); } String foldsString = Utils.getOption('X', options); if (foldsString.length() != 0) { setNumXValFolds(Integer.parseInt(foldsString)); } else { setNumXValFolds(3); } super.setOptions(options); } /** * Gets the current settings of the Classifier. * * @return an array of strings suitable for passing to setOptions */ public String [] getOptions() { String [] superOptions = super.getOptions(); String [] options = new String [superOptions.length + 10]; int current = 0; options[current++] = "-C"; options[current++] = "" + (m_DesignatedClass + 1); options[current++] = "-X"; options[current++] = "" + getNumXValFolds(); options[current++] = "-E"; options[current++] = "" + m_EvalMode; options[current++] = "-R"; options[current++] = "" + m_RangeMode; options[current++] = "-M"; options[current++] = "" + getMeasure().getSelectedTag().getReadable(); System.arraycopy(superOptions, 0, options, current, superOptions.length); current += superOptions.length; while (current < options.length) { options[current++] = ""; } return options; } /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); // class result.disableAllClasses(); result.disableAllClassDependencies(); result.enable(Capability.BINARY_CLASS); return result; } /** * Generates the classifier. * * @param instances set of instances serving as training data * @throws Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances instances) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(instances); // remove instances with missing class instances = new Instances(instances); instances.deleteWithMissingClass(); AttributeStats stats = instances.attributeStats(instances.classIndex()); m_BestThreshold = 0.5; m_BestValue = MIN_VALUE; m_HighThreshold = 1; m_LowThreshold = 0; // If data contains only one instance of positive data // optimize on training data if (stats.distinctCount != 2) { System.err.println("Couldn't find examples of both classes. No adjustment."); m_Classifier.buildClassifier(instances); } else { // Determine which class value to look for switch (m_ClassMode) { case OPTIMIZE_0: m_DesignatedClass = 0; break; case OPTIMIZE_1: m_DesignatedClass = 1; break; case OPTIMIZE_POS_NAME: Attribute cAtt = instances.classAttribute(); boolean found = false; for (int i = 0; i < cAtt.numValues() && !found; i++) { String name = cAtt.value(i).toLowerCase(); if (name.startsWith("yes") || name.equals("1") || name.startsWith("pos")) { found = true; m_DesignatedClass = i; } } if (found) { break; } // No named class found, so fall through to default of least frequent case OPTIMIZE_LFREQ: m_DesignatedClass = (stats.nominalCounts[0] > stats.nominalCounts[1]) ? 1 : 0; break; case OPTIMIZE_MFREQ: m_DesignatedClass = (stats.nominalCounts[0] > stats.nominalCounts[1]) ? 0 : 1;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -