decisiontable.java
来自「Weka」· Java 代码 · 共 1,398 行 · 第 1/3 页
JAVA
1,398 行
for (i=0;i<numFold;i++) { inst = fold.instance(i); System.arraycopy(class_distribs[i],0,normDist,0,normDist.length); if (m_classIsNominal) { boolean ok = false; for (int j=0;j<normDist.length;j++) { if (Utils.gr(normDist[j],1.0)) { ok = true; break; } } if (!ok) { // majority class normDist = classPriors.clone(); }// if (ok) { Utils.normalize(normDist); if (m_evaluationMeasure == EVAL_AUC) { m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst); } else { m_evaluation.evaluateModelOnce(normDist, inst); } /* } else { normDist[(int)m_majority] = 1.0; if (m_evaluationMeasure == EVAL_AUC) { m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst); } else { m_evaluation.evaluateModelOnce(normDist, inst); } } */ } else { if (Utils.eq(normDist[1],0.0)) { double [] temp = new double[1]; temp[0] = m_majority; m_evaluation.evaluateModelOnce(temp, inst); } else { double [] temp = new double[1]; temp[0] = normDist[0] / normDist[1]; m_evaluation.evaluateModelOnce(temp, inst); } } } // now re-insert instances for (i=0;i<numFold;i++) { inst = fold.instance(i); m_classPriorCounts[(int)inst.classValue()] += inst.weight(); if (m_classIsNominal) { class_distribs[i][(int)inst.classValue()] += inst.weight(); } else { class_distribs[i][0] += (inst.classValue() * inst.weight()); class_distribs[i][1] += inst.weight(); } } return acc; } /** * Evaluates a feature subset by cross validation * * @param feature_set the subset to be evaluated * @param num_atts the number of attributes in the subset * @return the estimated accuracy * @throws Exception if subset can't be evaluated */ protected double estimatePerformance(BitSet feature_set, int num_atts) throws Exception { m_evaluation = new Evaluation(m_theInstances); int i; int [] fs = new int [num_atts]; double [] instA = new double [num_atts]; int classI = m_theInstances.classIndex(); int index = 0; for (i=0;i<m_numAttributes;i++) { if (feature_set.get(i)) { fs[index++] = i; } } // create new hash table m_entries = new 
Hashtable((int)(m_theInstances.numInstances() * 1.5)); // insert instances into the hash table for (i=0;i<m_numInstances;i++) { Instance inst = m_theInstances.instance(i); for (int j=0;j<fs.length;j++) { if (fs[j] == classI) { instA[j] = Double.MAX_VALUE; // missing for the class } else if (inst.isMissing(fs[j])) { instA[j] = Double.MAX_VALUE; } else { instA[j] = inst.value(fs[j]); } } insertIntoTable(inst, instA); } if (m_CVFolds == 1) { // calculate leave one out error for (i=0;i<m_numInstances;i++) { Instance inst = m_theInstances.instance(i); for (int j=0;j<fs.length;j++) { if (fs[j] == classI) { instA[j] = Double.MAX_VALUE; // missing for the class } else if (inst.isMissing(fs[j])) { instA[j] = Double.MAX_VALUE; } else { instA[j] = inst.value(fs[j]); } } evaluateInstanceLeaveOneOut(inst, instA); } } else { m_theInstances.randomize(m_rr); m_theInstances.stratify(m_CVFolds); // calculate 10 fold cross validation error for (i=0;i<m_CVFolds;i++) { Instances insts = m_theInstances.testCV(m_CVFolds,i); evaluateFoldCV(insts, fs); } } switch (m_evaluationMeasure) { case EVAL_DEFAULT: if (m_classIsNominal) { return m_evaluation.pctCorrect(); } return -m_evaluation.rootMeanSquaredError(); case EVAL_ACCURACY: return m_evaluation.pctCorrect(); case EVAL_RMSE: return -m_evaluation.rootMeanSquaredError(); case EVAL_MAE: return -m_evaluation.meanAbsoluteError(); case EVAL_AUC: double [] classPriors = m_evaluation.getClassPriors(); Utils.normalize(classPriors); double weightedAUC = 0; for (i = 0; i < m_theInstances.classAttribute().numValues(); i++) { double tempAUC = m_evaluation.areaUnderROC(i); if (tempAUC != Instance.missingValue()) { weightedAUC += (classPriors[i] * tempAUC); } else { System.err.println("Undefined AUC!!"); } } return weightedAUC; } // shouldn't get here return 0.0; } /** * Returns a String representation of a feature subset * * @param sub BitSet representation of a subset * @return String containing subset */ private String printSub(BitSet sub) { String 
s=""; for (int jj=0;jj<m_numAttributes;jj++) { if (sub.get(jj)) { s += " "+(jj+1); } } return s; } /** * Resets the options. */ protected void resetOptions() { m_entries = null; m_decisionFeatures = null; m_useIBk = false; m_CVFolds = 1; m_displayRules = false; m_evaluationMeasure = EVAL_DEFAULT; } /** * Constructor for a DecisionTable */ public DecisionTable() { resetOptions(); } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ public Enumeration listOptions() { Vector newVector = new Vector(7); newVector.addElement(new Option( "\tFull class name of search method, followed\n" + "\tby its options.\n" + "\teg: \"weka.attributeSelection.BestFirst -D 1\"\n" + "\t(default weka.attributeSelection.BestFirst)", "S", 1, "-S <search method specification>")); newVector.addElement(new Option( "\tUse cross validation to evaluate features.\n" + "\tUse number of folds = 1 for leave one out CV.\n" + "\t(Default = leave one out CV)", "X", 1, "-X <number of folds>")); newVector.addElement(new Option( "\tPerformance evaluation measure to use for selecting attributes.\n" + "\t(Default = accuracy for discrete class and rmse for numeric class)", "E", 1, "-E <acc | rmse | mae | auc>")); newVector.addElement(new Option( "\tUse nearest neighbour instead of global table majority.", "I", 0, "-I")); newVector.addElement(new Option( "\tDisplay decision table rules.\n", "R", 0, "-R")); newVector.addElement(new Option( "", "", 0, "\nOptions specific to search method " + m_search.getClass().getName() + ":")); Enumeration enu = ((OptionHandler)m_search).listOptions(); while (enu.hasMoreElements()) { newVector.addElement(enu.nextElement()); } return newVector.elements(); } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String crossValTipText() { return "Sets the number of folds for cross validation (1 = leave one out)."; } 
// --- property accessors (setters, getters and GUI tip texts) ---

  /**
   * Sets the number of folds for cross validation (1 = leave one out).
   * The value is stored as given; no range check is performed here.
   *
   * @param folds the number of folds
   */
  public void setCrossVal(int folds) {
    m_CVFolds = folds;
  }

  /**
   * Gets the number of folds for cross validation
   *
   * @return the number of cross validation folds
   */
  public int getCrossVal() {
    return m_CVFolds;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String useIBkTipText() {
    return "Sets whether IBk should be used instead of the majority class.";
  }

  /**
   * Sets whether IBk should be used instead of the majority class
   *
   * @param ibk true if IBk is to be used
   */
  public void setUseIBk(boolean ibk) {
    m_useIBk = ibk;
  }

  /**
   * Gets whether IBk is being used instead of the majority class
   *
   * @return true if IBk is being used
   */
  public boolean getUseIBk() {
    return m_useIBk;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String displayRulesTipText() {
    return "Sets whether rules are to be printed.";
  }

  /**
   * Sets whether rules are to be printed
   *
   * @param rules true if rules are to be printed
   */
  public void setDisplayRules(boolean rules) {
    m_displayRules = rules;
  }

  /**
   * Gets whether rules are being printed
   *
   * @return true if rules are being printed
   */
  public boolean getDisplayRules() {
    return m_displayRules;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String searchTipText() {
    return "The search method used to find good attribute combinations for the "
      + "decision table.";
  }

  /**
   * Sets the search method to use. The reference is stored as given,
   * without a null or type check.
   *
   * @param search the search method to use
   */
  public void setSearch(ASSearch search) {
    m_search = search;
  }

  /**
   * Gets the current search method
   *
   * @return the search method used
   */
  public ASSearch getSearch() {
    return m_search;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String evaluationMeasureTipText() {
    return "The measure used to evaluate the performance of attribute combinations "
      + "used in the decision table.";
  }

  /**
   * Gets the currently set performance evaluation measure used for selecting
   * attributes for the decision table. A new SelectedTag is created on every
   * call.
   *
   * @return the performance evaluation measure
   */
  public SelectedTag getEvaluationMeasure() {
    return new SelectedTag(m_evaluationMeasure, TAGS_EVALUATION);
  }

  /**
   * Sets the performance evaluation measure to use for selecting attributes
   * for the decision table. The request is silently ignored unless the tag
   * was built from this class's own TAGS_EVALUATION array (the check is a
   * reference comparison, not an element-wise one).
   *
   * @param newMethod the new performance evaluation metric to use
   */
  public void setEvaluationMeasure(SelectedTag newMethod) {
    // identity check: only tags drawn from TAGS_EVALUATION are accepted
    if (newMethod.getTags() == TAGS_EVALUATION) {
      m_evaluationMeasure = newMethod.getSelectedTag().getID();
    }
  }

  /**
   * Parses the options for this object. <p/>
   *
   * <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -S &lt;search method specification&gt;
   *  Full class name of search method, followed
   *  by its options.
   *  eg: "weka.attributeSelection.BestFirst -D 1"
   *  (default weka.attributeSelection.BestFirst)</pre>
   *
   * <pre> -X &lt;number of folds&gt;
   *  Use cross validation to evaluate features.
   *  Use number of folds = 1 for leave one out CV.
   *  (Default = leave one out CV)</pre>
   *
   * <pre> -E &lt;acc | rmse | mae | auc&gt;
   *  Performance evaluation measure to use for selecting attributes.
   *  (Default = accuracy for discrete class and rmse for numeric class)</pre>
   *
   * <pre> -I
   *  Use nearest neighbour instead of global table majority.</pre>
   *
   * <pre> -R
   *  Display decision table rules.
   * </pre>
   *
   * <pre>
   * Options specific to search method weka.attributeSelection.BestFirst:
   * </pre>
   *
   * <pre> -P &lt;start set&gt;
   *  Specify a starting set of attributes.
   *  Eg. 1,3,5-7.</pre>
   *
   * <pre> -D &lt;0 = backward | 1 = forward | 2 = bi-directional&gt;
   *  Direction of search. (default = 1).</pre>
   *
   * <pre> -N &lt;num&gt;
   *  Number of non-improving nodes to
   *  consider before terminating search.</pre>
   *
   * <pre> -S &lt;num&gt;
   *  Size of lookup cache for evaluated subsets.
   *  Expressed as a multiple of the number of
   *  attributes in the data set. (default = 1)</pre>
   * <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String optionString;

    // start from the defaults so that options absent from the array
    // do not keep values from an earlier configuration
    resetOptions();

    // -X: number of folds; Integer.parseInt rejects non-numeric text
    optionString = Utils.getOption('X',options);
    if (optionString.length() != 0) {
      m_CVFolds = Integer.parseInt(optionString);
    }

    // -I and -R are plain flags
    m_useIBk = Utils.getFlag('I',options);

    m_displayRules = Utils.getFlag('R',options);

    // -E: evaluation measure, matched case-sensitively against fixed names
    optionString = Utils.getOption('E', options);
    if (optionString.length() != 0) {
      if (optionString.equals("acc")) {
        setEvaluationMeasure(new SelectedTag(EVAL_ACCURACY, TAGS_EVALUATION));
      } else if (optionString.equals("rmse")) {
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?