📄 decisiontable.java
字号:
/*} else { normDist = new double [normDist.length]; normDist[(int)m_majority] = 1.0; if (m_evaluationMeasure == EVAL_AUC) { m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, instance); } else { m_evaluation.evaluateModelOnce(normDist, instance); } return m_majority; } */ } // return Utils.maxIndex(tempDist); } else { // see if this one is already in the table if ((tempDist = (double[])m_entries.get(thekey)) != null) { normDist = new double [tempDist.length]; System.arraycopy(tempDist,0,normDist,0,tempDist.length); normDist[0] -= (instance.classValue() * instance.weight()); normDist[1] -= instance.weight(); if (Utils.eq(normDist[1],0.0)) { double [] temp = new double[1]; temp[0] = m_majority; m_evaluation.evaluateModelOnce(temp, instance); return m_majority; } else { double [] temp = new double[1]; temp[0] = normDist[0] / normDist[1]; m_evaluation.evaluateModelOnce(temp, instance); return temp[0]; } } else { throw new Error("This should never happen!"); } } // shouldn't get here // return 0.0; } /** * Calculates the accuracy on a test fold for internal cross validation * of feature sets * * @param fold set of instances to be "left out" and classified * @param fs currently selected feature set * @return the accuracy for the fold * @throws Exception if something goes wrong */ double evaluateFoldCV(Instances fold, int [] fs) throws Exception { int i; int ruleCount = 0; int numFold = fold.numInstances(); int numCl = m_theInstances.classAttribute().numValues(); double [][] class_distribs = new double [numFold][numCl]; double [] instA = new double [fs.length]; double [] normDist; hashKey thekey; double acc = 0.0; int classI = m_theInstances.classIndex(); Instance inst; if (m_classIsNominal) { normDist = new double [numCl]; } else { normDist = new double [2]; } // first *remove* instances for (i=0;i<numFold;i++) { inst = fold.instance(i); for (int j=0;j<fs.length;j++) { if (fs[j] == classI) { instA[j] = Double.MAX_VALUE; // missing for the class } else if 
(inst.isMissing(fs[j])) { instA[j] = Double.MAX_VALUE; } else{ instA[j] = inst.value(fs[j]); } } thekey = new hashKey(instA); if ((class_distribs[i] = (double [])m_entries.get(thekey)) == null) { throw new Error("This should never happen!"); } else { if (m_classIsNominal) { class_distribs[i][(int)inst.classValue()] -= inst.weight(); } else { class_distribs[i][0] -= (inst.classValue() * inst.weight()); class_distribs[i][1] -= inst.weight(); } ruleCount++; } m_classPriorCounts[(int)inst.classValue()] -= inst.weight(); } double [] classPriors = m_classPriorCounts.clone(); Utils.normalize(classPriors); // now classify instances for (i=0;i<numFold;i++) { inst = fold.instance(i); System.arraycopy(class_distribs[i],0,normDist,0,normDist.length); if (m_classIsNominal) { boolean ok = false; for (int j=0;j<normDist.length;j++) { if (!Utils.eq(normDist[j],0.0)) { ok = true; break; } } if (!ok) { // majority class normDist = classPriors.clone(); }// if (ok) { Utils.normalize(normDist); if (m_evaluationMeasure == EVAL_AUC) { m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst); } else { m_evaluation.evaluateModelOnce(normDist, inst); } /* } else { normDist[(int)m_majority] = 1.0; if (m_evaluationMeasure == EVAL_AUC) { m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst); } else { m_evaluation.evaluateModelOnce(normDist, inst); } } */ } else { if (Utils.eq(normDist[1],0.0)) { double [] temp = new double[1]; temp[0] = m_majority; m_evaluation.evaluateModelOnce(temp, inst); } else { double [] temp = new double[1]; temp[0] = normDist[0] / normDist[1]; m_evaluation.evaluateModelOnce(temp, inst); } } } // now re-insert instances for (i=0;i<numFold;i++) { inst = fold.instance(i); m_classPriorCounts[(int)inst.classValue()] += inst.weight(); if (m_classIsNominal) { class_distribs[i][(int)inst.classValue()] += inst.weight(); } else { class_distribs[i][0] += (inst.classValue() * inst.weight()); class_distribs[i][1] += inst.weight(); } } return acc; } /** * 
Evaluates a feature subset by cross validation * * @param feature_set the subset to be evaluated * @param num_atts the number of attributes in the subset * @return the estimated accuracy * @throws Exception if subset can't be evaluated */ protected double estimatePerformance(BitSet feature_set, int num_atts) throws Exception { m_evaluation = new Evaluation(m_theInstances); int i; int [] fs = new int [num_atts]; double [] instA = new double [num_atts]; int classI = m_theInstances.classIndex(); int index = 0; for (i=0;i<m_numAttributes;i++) { if (feature_set.get(i)) { fs[index++] = i; } } // create new hash table m_entries = new Hashtable((int)(m_theInstances.numInstances() * 1.5)); // insert instances into the hash table for (i=0;i<m_numInstances;i++) { Instance inst = m_theInstances.instance(i); for (int j=0;j<fs.length;j++) { if (fs[j] == classI) { instA[j] = Double.MAX_VALUE; // missing for the class } else if (inst.isMissing(fs[j])) { instA[j] = Double.MAX_VALUE; } else { instA[j] = inst.value(fs[j]); } } insertIntoTable(inst, instA); } if (m_CVFolds == 1) { // calculate leave one out error for (i=0;i<m_numInstances;i++) { Instance inst = m_theInstances.instance(i); for (int j=0;j<fs.length;j++) { if (fs[j] == classI) { instA[j] = Double.MAX_VALUE; // missing for the class } else if (inst.isMissing(fs[j])) { instA[j] = Double.MAX_VALUE; } else { instA[j] = inst.value(fs[j]); } } evaluateInstanceLeaveOneOut(inst, instA); } } else { m_theInstances.randomize(m_rr); m_theInstances.stratify(m_CVFolds); // calculate 10 fold cross validation error for (i=0;i<m_CVFolds;i++) { Instances insts = m_theInstances.testCV(m_CVFolds,i); evaluateFoldCV(insts, fs); } } switch (m_evaluationMeasure) { case EVAL_DEFAULT: if (m_classIsNominal) { return m_evaluation.pctCorrect(); } return -m_evaluation.rootMeanSquaredError(); case EVAL_ACCURACY: return m_evaluation.pctCorrect(); case EVAL_RMSE: return -m_evaluation.rootMeanSquaredError(); case EVAL_MAE: return 
-m_evaluation.meanAbsoluteError(); case EVAL_AUC: double [] classPriors = m_evaluation.getClassPriors(); Utils.normalize(classPriors); double weightedAUC = 0; for (i = 0; i < m_theInstances.classAttribute().numValues(); i++) { double tempAUC = m_evaluation.areaUnderROC(i); if (tempAUC != Instance.missingValue()) { weightedAUC += (classPriors[i] * tempAUC); } else { System.err.println("Undefined AUC!!"); } } return weightedAUC; } // shouldn't get here return 0.0; } /** * Returns a String representation of a feature subset * * @param sub BitSet representation of a subset * @return String containing subset */ private String printSub(BitSet sub) { String s=""; for (int jj=0;jj<m_numAttributes;jj++) { if (sub.get(jj)) { s += " "+(jj+1); } } return s; } /** * Resets the options. */ protected void resetOptions() { m_entries = null; m_decisionFeatures = null; m_useIBk = false; m_CVFolds = 1; m_displayRules = false; m_evaluationMeasure = EVAL_DEFAULT; } /** * Constructor for a DecisionTable */ public DecisionTable() { resetOptions(); } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. 
*/
  public Enumeration listOptions() {

    Vector newVector = new Vector(7);

    newVector.addElement(new Option(
        "\tFull class name of search method, followed\n"
        + "\tby its options.\n"
        + "\teg: \"weka.attributeSelection.BestFirst -D 1\"\n"
        + "\t(default weka.attributeSelection.BestFirst)",
        "S", 1, "-S <search method specification>"));

    newVector.addElement(new Option(
        "\tUse cross validation to evaluate features.\n"
        + "\tUse number of folds = 1 for leave one out CV.\n"
        + "\t(Default = leave one out CV)",
        "X", 1, "-X <number of folds>"));

    newVector.addElement(new Option(
        "\tPerformance evaluation measure to use for selecting attributes.\n"
        + "\t(Default = accuracy for discrete class and rmse for numeric class)",
        "E", 1, "-E <acc | rmse | mae | auc>"));

    newVector.addElement(new Option(
        "\tUse nearest neighbour instead of global table majority.",
        "I", 0, "-I"));

    newVector.addElement(new Option(
        "\tDisplay decision table rules.\n",
        "R", 0, "-R"));

    // The search method's own options can only be listed when a search is
    // set and it exposes options; instanceof is also false for null, so an
    // unset or option-less search no longer raises an exception here.
    if (m_search instanceof OptionHandler) {
      newVector.addElement(new Option(
          "",
          "", 0, "\nOptions specific to search method "
          + m_search.getClass().getName() + ":"));
      Enumeration enu = ((OptionHandler)m_search).listOptions();
      while (enu.hasMoreElements()) {
        newVector.addElement(enu.nextElement());
      }
    }
    return newVector.elements();
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String crossValTipText() {
    return "Sets the number of folds for cross validation (1 = leave one out).";
  }

  /**
   * Sets the number of folds for cross validation (1 = leave one out)
   *
   * @param folds the number of folds
   */
  public void setCrossVal(int folds) {

    m_CVFolds = folds;
  }

  /**
   * Gets the number of folds for cross validation
   *
   * @return the number of cross validation folds
   */
  public int getCrossVal() {

    return m_CVFolds;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String useIBkTipText() {
    return "Sets whether IBk should be used instead of the majority class.";
  }

  /**
   * Sets whether IBk should be used instead of the majority class
   *
   * @param ibk true if IBk is to be used
   */
  public void setUseIBk(boolean ibk) {

    m_useIBk = ibk;
  }

  /**
   * Gets whether IBk is being used instead of the majority class
   *
   * @return true if IBk is being used
   */
  public boolean getUseIBk() {

    return m_useIBk;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String displayRulesTipText() {
    return "Sets whether rules are to be printed.";
  }

  /**
   * Sets whether rules are to be printed
   *
   * @param rules true if rules are to be printed
   */
  public void setDisplayRules(boolean rules) {

    m_displayRules = rules;
  }

  /**
   * Gets whether rules are being printed
   *
   * @return true if rules are being printed
   */
  public boolean getDisplayRules() {

    return m_displayRules;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String searchTipText() {
    return "The search method used to find good attribute combinations for the "
      + "decision table.";
  }

  /**
   * Sets the search method to use
   *
   * @param search the search method to use; stored as given, without
   * validation
   */
  public void setSearch(ASSearch search) {
    m_search = search;
  }

  /**
   * Gets the current search method
   *
   * @return the search method used
   */
  public ASSearch getSearch() {
    return m_search;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String evaluationMeasureTipText() {
    return "The measure used to evaluate the performance of attribute combinations "
      + "used in the decision table.";
  }

  /**
   * Gets the currently set performance evaluation measure used for selecting
   * attributes for the decision table
   *
   * @return the performance evaluation measure
   */
  public SelectedTag getEvaluationMeasure() {
    return new SelectedTag(m_evaluationMeasure, TAGS_EVALUATION);
  }

  /**
   * Sets the performance evaluation measure to use for selecting attributes
   * for the decision table. A tag that does not belong to
   * <code>TAGS_EVALUATION</code> is silently ignored and the current measure
   * is kept.
   *
   * @param newMethod the new performance evaluation metric to use
   */
  public void setEvaluationMeasure(SelectedTag newMethod) {
    // identity comparison: the tag must come from this class's own tag array
    if (newMethod.getTags() == TAGS_EVALUATION) {
      m_evaluationMeasure = newMethod.getSelectedTag().getID();
    }
  }

  /**
   * Parses the options for this object. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -S <search method specification>
   *  Full class name of search method, followed
   *  by its options.
   *  eg: "weka.attributeSelection.BestFirst -D 1"
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -