decisiontable.java

来自「Weka」· Java 代码 · 共 1,398 行 · 第 1/3 页

JAVA
1,398
字号
    for (i=0;i<numFold;i++) {      inst = fold.instance(i);      System.arraycopy(class_distribs[i],0,normDist,0,normDist.length);      if (m_classIsNominal) {	boolean ok = false;	for (int j=0;j<normDist.length;j++) {	  if (Utils.gr(normDist[j],1.0)) {	    ok = true;	    break;	  }	}	if (!ok) { // majority class	  normDist = classPriors.clone();	}//	if (ok) {	Utils.normalize(normDist);	if (m_evaluationMeasure == EVAL_AUC) {	  m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst);							} else {	  m_evaluation.evaluateModelOnce(normDist, inst);	}	/*	} else {						  normDist[(int)m_majority] = 1.0;	  if (m_evaluationMeasure == EVAL_AUC) {	    m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst);							  } else {	    m_evaluation.evaluateModelOnce(normDist, inst);						  }	} */      } else {	if (Utils.eq(normDist[1],0.0)) {	  double [] temp = new double[1];	  temp[0] = m_majority;	  m_evaluation.evaluateModelOnce(temp, inst);	} else {	  double [] temp = new double[1];	  temp[0] = normDist[0] / normDist[1];	  m_evaluation.evaluateModelOnce(temp, inst);	}      }    }    // now re-insert instances    for (i=0;i<numFold;i++) {      inst = fold.instance(i);      m_classPriorCounts[(int)inst.classValue()] += 	inst.weight();      if (m_classIsNominal) {	class_distribs[i][(int)inst.classValue()] += inst.weight();      } else {	class_distribs[i][0] += (inst.classValue() * inst.weight());	class_distribs[i][1] += inst.weight();      }    }    return acc;  }  /**   * Evaluates a feature subset by cross validation   *   * @param feature_set the subset to be evaluated   * @param num_atts the number of attributes in the subset   * @return the estimated accuracy   * @throws Exception if subset can't be evaluated   */  protected double estimatePerformance(BitSet feature_set, int num_atts)  throws Exception {    m_evaluation = new Evaluation(m_theInstances);    int i;    int [] fs = new int [num_atts];    double [] instA = new double [num_atts];    int classI = m_theInstances.classIndex();    int index = 0;    for (i=0;i<m_numAttributes;i++) {      if (feature_set.get(i)) {	fs[index++] = i;      }    }    // create new hash table    m_entries = new Hashtable((int)(m_theInstances.numInstances() * 1.5));    // insert instances into the hash table    for (i=0;i<m_numInstances;i++) {      Instance inst = m_theInstances.instance(i);      for (int j=0;j<fs.length;j++) {	if (fs[j] == classI) {	  instA[j] = Double.MAX_VALUE; // missing for the class	} else if (inst.isMissing(fs[j])) {	  instA[j] = Double.MAX_VALUE;	} else {	  instA[j] = inst.value(fs[j]);	}      }      insertIntoTable(inst, instA);    }    if (m_CVFolds == 1) {      // calculate leave one out error      for (i=0;i<m_numInstances;i++) {	Instance inst = m_theInstances.instance(i);	for (int j=0;j<fs.length;j++) {	  if (fs[j] == classI) {	    instA[j] = Double.MAX_VALUE; // missing for the class	  } else if (inst.isMissing(fs[j])) {	    instA[j] = Double.MAX_VALUE;	  } else {	    instA[j] = inst.value(fs[j]);	  }	}	evaluateInstanceLeaveOneOut(inst, instA);				      }    } else {      m_theInstances.randomize(m_rr);      m_theInstances.stratify(m_CVFolds);      // calculate 10 fold cross validation error      for (i=0;i<m_CVFolds;i++) {	Instances insts = m_theInstances.testCV(m_CVFolds,i);	evaluateFoldCV(insts, fs);      }    }    switch (m_evaluationMeasure) {    case EVAL_DEFAULT:      if (m_classIsNominal) {	return m_evaluation.pctCorrect();      }      return -m_evaluation.rootMeanSquaredError();    case EVAL_ACCURACY:      return m_evaluation.pctCorrect();    case EVAL_RMSE:      return -m_evaluation.rootMeanSquaredError();    case EVAL_MAE:      return -m_evaluation.meanAbsoluteError();    case EVAL_AUC:      double [] classPriors = m_evaluation.getClassPriors();      Utils.normalize(classPriors);      double weightedAUC = 0;      for (i = 0; i < m_theInstances.classAttribute().numValues(); i++) {	double tempAUC = m_evaluation.areaUnderROC(i);	if (tempAUC != Instance.missingValue()) {	  weightedAUC += (classPriors[i] * tempAUC);	} else {	  System.err.println("Undefined AUC!!");	}      }      return weightedAUC;    }    // shouldn't get here    return 0.0;  }  /**   * Returns a String representation of a feature subset   *   * @param sub BitSet representation of a subset   * @return String containing subset   */  private String printSub(BitSet sub) {    String s="";    for (int jj=0;jj<m_numAttributes;jj++) {      if (sub.get(jj)) {	s += " "+(jj+1);      }    }    return s;  }  /**   * Resets the options.   */  protected void resetOptions()  {    m_entries = null;    m_decisionFeatures = null;    m_useIBk = false;    m_CVFolds = 1;    m_displayRules = false;    m_evaluationMeasure = EVAL_DEFAULT;  }  /**   * Constructor for a DecisionTable   */  public DecisionTable() {    resetOptions();  }  /**   * Returns an enumeration describing the available options.   *   * @return an enumeration of all the available options.   */  public Enumeration listOptions() {    Vector newVector = new Vector(7);    newVector.addElement(new Option(	"\tFull class name of search method, followed\n"	+ "\tby its options.\n"	+ "\teg: \"weka.attributeSelection.BestFirst -D 1\"\n"	+ "\t(default weka.attributeSelection.BestFirst)",	"S", 1, "-S <search method specification>"));    newVector.addElement(new Option(	"\tUse cross validation to evaluate features.\n" +	"\tUse number of folds = 1 for leave one out CV.\n" +	"\t(Default = leave one out CV)",	"X", 1, "-X <number of folds>"));    newVector.addElement(new Option(	"\tPerformance evaluation measure to use for selecting attributes.\n" +	"\t(Default = accuracy for discrete class and rmse for numeric class)",	"E", 1, "-E <acc | rmse | mae | auc>"));    newVector.addElement(new Option(	"\tUse nearest neighbour instead of global table majority.",	"I", 0, "-I"));    newVector.addElement(new Option(	"\tDisplay decision table rules.\n",	"R", 0, "-R"));     newVector.addElement(new Option(	"",	"", 0, "\nOptions specific to search method "	+ m_search.getClass().getName() + ":"));    Enumeration enu = ((OptionHandler)m_search).listOptions();    while (enu.hasMoreElements()) {      newVector.addElement(enu.nextElement());    }    return newVector.elements();  }  /**   * Returns the tip text for this property   * @return tip text for this property suitable for   * displaying in the explorer/experimenter gui   */  public String crossValTipText() {    return "Sets the number of folds for cross validation (1 = leave one out).";  }  /**   * Sets the number of folds for cross validation (1 = leave one out)   *   * @param folds the number of folds   */  public void setCrossVal(int folds) {    m_CVFolds = folds;  }  /**   * Gets the number of folds for cross validation   *   * @return the number of cross validation folds   */  public int getCrossVal() {    return m_CVFolds;  }  /**   * Returns the tip text for this property   * @return tip text for this property suitable for   * displaying in the explorer/experimenter gui   */  public String useIBkTipText() {    return "Sets whether IBk should be used instead of the majority class.";  }  /**   * Sets whether IBk should be used instead of the majority class   *   * @param ibk true if IBk is to be used   */  public void setUseIBk(boolean ibk) {    m_useIBk = ibk;  }  /**   * Gets whether IBk is being used instead of the majority class   *   * @return true if IBk is being used   */  public boolean getUseIBk() {    return m_useIBk;  }  /**   * Returns the tip text for this property   * @return tip text for this property suitable for   * displaying in the explorer/experimenter gui   */  public String displayRulesTipText() {    return "Sets whether rules are to be printed.";  }  /**   * Sets whether rules are to be printed   *   * @param rules true if rules are to be printed   */  public void setDisplayRules(boolean rules) {    m_displayRules = rules;  }  /**   * Gets whether rules are being printed   *   * @return true if rules are being printed   */  public boolean getDisplayRules() {    return m_displayRules;  }  /**   * Returns the tip text for this property   * @return tip text for this property suitable for   * displaying in the explorer/experimenter gui   */  public String searchTipText() {    return "The search method used to find good attribute combinations for the "    + "decision table.";  }  /**   * Sets the search method to use   *    * @param search   */  public void setSearch(ASSearch search) {    m_search = search;  }  /**   * Gets the current search method   *    * @return the search method used   */  public ASSearch getSearch() {    return m_search;  }  /**   * Returns the tip text for this property   * @return tip text for this property suitable for   * displaying in the explorer/experimenter gui   */  public String evaluationMeasureTipText() {    return "The measure used to evaluate the performance of attribute combinations "    + "used in the decision table.";  }  /**   * Gets the currently set performance evaluation measure used for selecting   * attributes for the decision table   *    * @return the performance evaluation measure   */  public SelectedTag getEvaluationMeasure() {    return new SelectedTag(m_evaluationMeasure, TAGS_EVALUATION);  }  /**   * Sets the performance evaluation measure to use for selecting attributes   * for the decision table   *    * @param newMethod the new performance evaluation metric to use   */  public void setEvaluationMeasure(SelectedTag newMethod) {    if (newMethod.getTags() == TAGS_EVALUATION) {      m_evaluationMeasure = newMethod.getSelectedTag().getID();    }  }  /**   * Parses the options for this object. <p/>   *   <!-- options-start -->   * Valid options are: <p/>   *    * <pre> -S &lt;search method specification&gt;   *  Full class name of search method, followed   *  by its options.   *  eg: "weka.attributeSelection.BestFirst -D 1"   *  (default weka.attributeSelection.BestFirst)</pre>   *    * <pre> -X &lt;number of folds&gt;   *  Use cross validation to evaluate features.   *  Use number of folds = 1 for leave one out CV.   *  (Default = leave one out CV)</pre>   *    * <pre> -E &lt;acc | rmse | mae | auc&gt;   *  Performance evaluation measure to use for selecting attributes.   *  (Default = accuracy for discrete class and rmse for numeric class)</pre>   *    * <pre> -I   *  Use nearest neighbour instead of global table majority.</pre>   *    * <pre> -R   *  Display decision table rules.   * </pre>   *    * <pre>    * Options specific to search method weka.attributeSelection.BestFirst:   * </pre>   *    * <pre> -P &lt;start set&gt;   *  Specify a starting set of attributes.   *  Eg. 1,3,5-7.</pre>   *    * <pre> -D &lt;0 = backward | 1 = forward | 2 = bi-directional&gt;   *  Direction of search. (default = 1).</pre>   *    * <pre> -N &lt;num&gt;   *  Number of non-improving nodes to   *  consider before terminating search.</pre>   *    * <pre> -S &lt;num&gt;   *  Size of lookup cache for evaluated subsets.   *  Expressed as a multiple of the number of   *  attributes in the data set. (default = 1)</pre>   *    <!-- options-end -->   *   * @param options the list of options as an array of strings   * @throws Exception if an option is not supported   */  public void setOptions(String[] options) throws Exception {    String optionString;    resetOptions();    optionString = Utils.getOption('X',options);    if (optionString.length() != 0) {      m_CVFolds = Integer.parseInt(optionString);    }    m_useIBk = Utils.getFlag('I',options);    m_displayRules = Utils.getFlag('R',options);    optionString = Utils.getOption('E', options);    if (optionString.length() != 0) {      if (optionString.equals("acc")) {	setEvaluationMeasure(new SelectedTag(EVAL_ACCURACY, TAGS_EVALUATION));      } else if (optionString.equals("rmse")) {

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?