// WrapperSubsetEval.java — Weka wrapper attribute-subset evaluator
// (the original two lines here were web-viewer chrome, not source code)
* @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String foldsTipText() { return "Number of xval folds to use when estimating subset accuracy."; } /** * Set the number of folds to use for accuracy estimation * * @param f the number of folds */ public void setFolds (int f) { m_folds = f; } /** * Get the number of folds used for accuracy estimation * * @return the number of folds */ public int getFolds () { return m_folds; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String seedTipText() { return "Seed to use for randomly generating xval splits."; } /** * Set the seed to use for cross validation * * @param s the seed */ public void setSeed (int s) { m_seed = s; } /** * Get the random number seed used for cross validation * * @return the seed */ public int getSeed () { return m_seed; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String classifierTipText() { return "Classifier to use for estimating the accuracy of subsets"; } /** * Set the classifier to use for accuracy estimation * * @param newClassifier the Classifier to use. */ public void setClassifier (Classifier newClassifier) { m_BaseClassifier = newClassifier; } /** * Get the classifier used as the base learner. * * @return the classifier used as the classifier */ public Classifier getClassifier () { return m_BaseClassifier; } /** * Gets the current settings of WrapperSubsetEval. 
* * @return an array of strings suitable for passing to setOptions() */ public String[] getOptions () { String[] classifierOptions = new String[0]; if ((m_BaseClassifier != null) && (m_BaseClassifier instanceof OptionHandler)) { classifierOptions = ((OptionHandler)m_BaseClassifier).getOptions(); } String[] options = new String[9 + classifierOptions.length]; int current = 0; if (getClassifier() != null) { options[current++] = "-B"; options[current++] = getClassifier().getClass().getName(); } options[current++] = "-F"; options[current++] = "" + getFolds(); options[current++] = "-T"; options[current++] = "" + getThreshold(); options[current++] = "-R"; options[current++] = "" + getSeed(); options[current++] = "--"; System.arraycopy(classifierOptions, 0, options, current, classifierOptions.length); current += classifierOptions.length; while (current < options.length) { options[current++] = ""; } return options; } protected void resetOptions () { m_trainInstances = null; m_Evaluation = null; m_BaseClassifier = new ZeroR(); m_folds = 5; m_seed = 1; m_threshold = 0.01; } /** * Returns the capabilities of this evaluator. * * @return the capabilities of this evaluator * @see Capabilities */ public Capabilities getCapabilities() { Capabilities result; if (getClassifier() == null) result = super.getCapabilities(); else result = getClassifier().getCapabilities(); // set dependencies for (Capability cap: Capability.values()) result.enableDependency(cap); result.setMinimumNumberInstances(getFolds()); return result; } /** * Generates a attribute evaluator. Has to initialize all fields of the * evaluator that are not being set via options. * * @param data set of instances serving as training data * @throws Exception if the evaluator has not been * generated successfully */ public void buildEvaluator (Instances data) throws Exception { // can evaluator handle data? 
    getCapabilities().testWithFail(data);

    // cache the training data and its dimensions for evaluateSubset()
    m_trainInstances = data;
    m_classIndex = m_trainInstances.classIndex();
    m_numAttribs = m_trainInstances.numAttributes();
    m_numInstances = m_trainInstances.numInstances();
  }

  /**
   * Evaluates a subset of attributes
   *
   * @param subset a bitset representing the attribute subset to be
   * evaluated
   * @return the error rate, negated so that larger values are better
   * @throws Exception if the subset could not be evaluated
   */
  public double evaluateSubset (BitSet subset) throws Exception {
    double errorRate = 0;
    // holds one error rate per repetition (at most 5 repetitions)
    double[] repError = new double[5];
    int numAttributes = 0;
    int i, j;
    Random Rnd = new Random(m_seed);
    Remove delTransform = new Remove();
    // invert selection so Remove KEEPS the listed attributes
    delTransform.setInvertSelection(true);
    // copy the instances
    Instances trainCopy = new Instances(m_trainInstances);

    // count attributes set in the BitSet
    for (i = 0; i < m_numAttribs; i++) {
      if (subset.get(i)) {
        numAttributes++;
      }
    }

    // set up an array of attribute indexes for the filter (+1 for the class)
    int[] featArray = new int[numAttributes + 1];

    for (i = 0, j = 0; i < m_numAttribs; i++) {
      if (subset.get(i)) {
        featArray[j++] = i;
      }
    }

    featArray[j] = m_classIndex;
    delTransform.setAttributeIndicesArray(featArray);
    delTransform.setInputFormat(trainCopy);
    trainCopy = Filter.useFilter(trainCopy, delTransform);

    // maximum of 5 repetitions of cross validation
    for (i = 0; i < 5; i++) {
      m_Evaluation = new Evaluation(trainCopy);
      m_Evaluation.crossValidateModel(m_BaseClassifier, trainCopy, m_folds, Rnd);
      repError[i] = m_Evaluation.errorRate();

      // check on the standard deviation; stop early once estimates are stable
      if (!repeat(repError, i + 1)) {
        // bump i so the averaging below includes this final repetition
        i++;
        break;
      }
    }

    // average the error over however many repetitions were actually run
    for (j = 0; j < i; j++) {
      errorRate += repError[j];
    }

    errorRate /= (double)i;
    m_Evaluation = null;
    // negate: search methods maximize merit, so lower error => higher merit
    return -errorRate;
  }

  /**
   * Returns a string describing the wrapper
   *
   * @return the description as a string
   */
  public String toString () {
    StringBuffer text = new StringBuffer();

    if (m_trainInstances == null) {
      text.append("\tWrapper subset evaluator has not been built yet\n");
    }
    else {
      text.append("\tWrapper Subset Evaluator\n");
      text.append("\tLearning scheme: "
                  + getClassifier().getClass().getName() + "\n");
      text.append("\tScheme options: ");
      String[] classifierOptions = new String[0];

      // echo the base classifier's options, space-separated
      if (m_BaseClassifier instanceof OptionHandler) {
        classifierOptions = ((OptionHandler)m_BaseClassifier).getOptions();

        for (int i = 0; i < classifierOptions.length; i++) {
          text.append(classifierOptions[i] + " ");
        }
      }

      text.append("\n");

      // numeric class => error measured as RMSE; nominal => classification error
      if (m_trainInstances.attribute(m_classIndex).isNumeric()) {
        text.append("\tAccuracy estimation: RMSE\n");
      }
      else {
        text.append("\tAccuracy estimation: classification error\n");
      }

      text.append("\tNumber of folds for accuracy estimation: "
                  + m_folds + "\n");
    }

    return text.toString();
  }

  /**
   * decides whether to do another repeat of cross validation. If the
   * standard deviation of the cross validations
   * is greater than threshold% of the mean (default 1%) then another
   * repeat is done.
   *
   * @param repError an array of cross validation results
   * @param entries the number of cross validations done so far
   * @return true if another cv is to be done
   */
  private boolean repeat (double[] repError, int entries) {
    int i;
    double mean = 0;
    double variance = 0;

    // a single estimate carries no spread information, so always repeat once more
    if (entries == 1) {
      return true;
    }

    for (i = 0; i < entries; i++) {
      mean += repError[i];
    }

    mean /= (double)entries;

    for (i = 0; i < entries; i++) {
      variance += ((repError[i] - mean)*(repError[i] - mean));
    }

    // population variance, then converted to a standard deviation
    variance /= (double)entries;

    if (variance > 0) {
      variance = Math.sqrt(variance);
    }

    // NOTE(review): when mean == 0 (all repetitions had zero error) this is
    // 0/0 == NaN, and NaN > threshold is false, so repetition stops — appears
    // benign but worth confirming as intentional
    if ((variance/mean) > m_threshold) {
      return true;
    }

    return false;
  }

  /**
   * Main method for testing this class.
   *
   * @param args the options
   */
  public static void main (String[] args) {
    runEvaluator(new WrapperSubsetEval(), args);
  }
}