DecisionTable.java

From "Weka" · Java code · 1,398 lines total · page 1 of 3

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    DecisionTable.java
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.rules;

import weka.attributeSelection.ASSearch;
import weka.attributeSelection.BestFirst;
import weka.attributeSelection.SubsetEvaluator;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.lazy.IBk;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

import java.util.Arrays;
import java.util.BitSet;
import java.util.Enumeration;
import java.util.Hashtable;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Class for building and using a simple decision table majority classifier.<br/>
 * <br/>
 * For more information see: <br/>
 * <br/>
 * Ron Kohavi: The Power of Decision Tables. In: 8th European Conference on Machine Learning, 174-189, 1995.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;inproceedings{Kohavi1995,
 *    author = {Ron Kohavi},
 *    booktitle = {8th European Conference on Machine Learning},
 *    pages = {174-189},
 *    publisher = {Springer},
 *    title = {The Power of Decision Tables},
 *    year = {1995}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -S &lt;search method specification&gt;
 *  Full class name of search method, followed
 *  by its options.
 *  eg: "weka.attributeSelection.BestFirst -D 1"
 *  (default weka.attributeSelection.BestFirst)</pre>
 *
 * <pre> -X &lt;number of folds&gt;
 *  Use cross validation to evaluate features.
 *  Use number of folds = 1 for leave one out CV.
 *  (Default = leave one out CV)</pre>
 *
 * <pre> -E &lt;acc | rmse | mae | auc&gt;
 *  Performance evaluation measure to use for selecting attributes.
 *  (Default = accuracy for discrete class and rmse for numeric class)</pre>
 *
 * <pre> -I
 *  Use nearest neighbour instead of global table majority.</pre>
 *
 * <pre> -R
 *  Display decision table rules.
 * </pre>
 *
 * <pre>
 * Options specific to search method weka.attributeSelection.BestFirst:
 * </pre>
 *
 * <pre> -P &lt;start set&gt;
 *  Specify a starting set of attributes.
 *  Eg. 1,3,5-7.</pre>
 *
 * <pre> -D &lt;0 = backward | 1 = forward | 2 = bi-directional&gt;
 *  Direction of search. (default = 1).</pre>
 *
 * <pre> -N &lt;num&gt;
 *  Number of non-improving nodes to
 *  consider before terminating search.</pre>
 *
 * <pre> -S &lt;num&gt;
 *  Size of lookup cache for evaluated subsets.
 *  Expressed as a multiple of the number of
 *  attributes in the data set. (default = 1)</pre>
 *
 <!-- options-end -->
 *
 * @author Mark Hall (mhall@cs.waikato.ac.nz)
 * @version $Revision: 1.44 $
 */
public class DecisionTable
  extends Classifier
  implements OptionHandler, WeightedInstancesHandler,
             AdditionalMeasureProducer, TechnicalInformationHandler {

  /** for serialization */
  static final long serialVersionUID = 2888557078165701326L;

  /** The hashtable used to hold training instances */
  protected Hashtable m_entries;

  /** The class priors to use when there is no match in the table */
  protected double [] m_classPriorCounts;
  protected double [] m_classPriors;

  /** Holds the final feature set */
  protected int [] m_decisionFeatures;

  /** Discretization filter */
  protected Filter m_disTransform;

  /** Filter used to remove columns discarded by feature selection */
  protected Remove m_delTransform;

  /** IB1 used to classify non matching instances rather than majority class */
  protected IBk m_ibk;

  /** Holds the original training instances */
  protected Instances m_theInstances;

  /** Holds the final feature selected set of instances */
  protected Instances m_dtInstances;

  /** The number of attributes in the dataset */
  protected int m_numAttributes;

  /** The number of instances in the dataset */
  private int m_numInstances;

  /** Class is nominal */
  protected boolean m_classIsNominal;

  /** Use the IBk classifier rather than majority class */
  protected boolean m_useIBk;

  /** Display Rules */
  protected boolean m_displayRules;

  /** Number of folds for cross validating feature sets */
  private int m_CVFolds;

  /** Random numbers for use in cross validation */
  private Random m_rr;

  /** Holds the majority class */
  protected double m_majority;

  /** The search method to use */
  protected ASSearch m_search = new BestFirst();

  /** Our own internal evaluator */
  protected SubsetEvaluator m_evaluator;

  /** The evaluation object used to evaluate subsets */
  protected Evaluation m_evaluation;

  /** default is accuracy for discrete class and RMSE for numeric class */
  public static final int EVAL_DEFAULT = 1;
  public static final int EVAL_ACCURACY = 2;
  public static final int EVAL_RMSE = 3;
  public static final int EVAL_MAE = 4;
  public static final int EVAL_AUC = 5;

  public static final Tag [] TAGS_EVALUATION = {
    new Tag(EVAL_DEFAULT, "Default: accuracy (discrete class); RMSE (numeric class)"),
    new Tag(EVAL_ACCURACY, "Accuracy (discrete class only)"),
    new Tag(EVAL_RMSE, "RMSE (of the class probabilities for discrete class)"),
    new Tag(EVAL_MAE, "MAE (of the class probabilities for discrete class)"),
    new Tag(EVAL_AUC, "AUC (area under the ROC curve - discrete class only)")
  };

  protected int m_evaluationMeasure = EVAL_DEFAULT;
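
  // Core data structure: m_entries maps a DecisionTableHashKey (the values of
  // the currently selected features for an instance) to a double array that
  // holds either per-class weight counts (nominal class) or the weighted sum
  // of class values plus the total weight (numeric class); see
  // insertIntoTable() below.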

  /**
   * Returns a string describing classifier
   *
   * @return a description suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {

    return
      "Class for building and using a simple decision table majority "
      + "classifier.\n\n"
      + "For more information see: \n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;

    result = new TechnicalInformation(Type.INPROCEEDINGS);
    result.setValue(Field.AUTHOR, "Ron Kohavi");
    result.setValue(Field.TITLE, "The Power of Decision Tables");
    result.setValue(Field.BOOKTITLE, "8th European Conference on Machine Learning");
    result.setValue(Field.YEAR, "1995");
    result.setValue(Field.PAGES, "174-189");
    result.setValue(Field.PUBLISHER, "Springer");

    return result;
  }

  /**
   * Inserts an instance into the hash table
   *
   * @param inst instance to be inserted
   * @param instA the values used to create the hash key from
   * @throws Exception if the instance can't be inserted
   */
  private void insertIntoTable(Instance inst, double [] instA)
    throws Exception {

    double [] tempClassDist2;
    double [] newDist;
    DecisionTableHashKey thekey;

    if (instA != null) {
      thekey = new DecisionTableHashKey(instA);
    } else {
      thekey = new DecisionTableHashKey(inst, inst.numAttributes(), false);
    }

    // see if this one is already in the table
    tempClassDist2 = (double []) m_entries.get(thekey);
    if (tempClassDist2 == null) {
      if (m_classIsNominal) {
        newDist = new double [m_theInstances.classAttribute().numValues()];

        // Laplace estimation
        for (int i = 0; i < m_theInstances.classAttribute().numValues(); i++) {
          newDist[i] = 1.0;
        }

        newDist[(int) inst.classValue()] = inst.weight();

        // add to the table
        m_entries.put(thekey, newDist);
      } else {
        newDist = new double [2];
        newDist[0] = inst.classValue() * inst.weight();
        newDist[1] = inst.weight();

        // add to the table
        m_entries.put(thekey, newDist);
      }
    } else {
      // update the distribution for this instance
      if (m_classIsNominal) {
        tempClassDist2[(int) inst.classValue()] += inst.weight();

        // update the table
        m_entries.put(thekey, tempClassDist2);
      } else {
        tempClassDist2[0] += (inst.classValue() * inst.weight());
        tempClassDist2[1] += inst.weight();

        // update the table
        m_entries.put(thekey, tempClassDist2);
      }
    }
  }
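
  // Note on the leave-one-out evaluation below: rather than rebuilding the
  // table for every held-out instance, the stored class distribution for the
  // instance's table cell is copied, the instance's own weight is subtracted
  // from its class count, and if that leaves the cell with no real evidence
  // (no count above the Laplace base of 1.0) the similarly "downdated" class
  // priors are used as the prediction instead.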

  /**
   * Classifies an instance for internal leave one out cross validation
   * of feature sets
   *
   * @param instance instance to be "left out" and classified
   * @param instA feature values of the selected features for the instance
   * @return the classification of the instance
   * @throws Exception if something goes wrong
   */
  double evaluateInstanceLeaveOneOut(Instance instance, double [] instA)
    throws Exception {

    DecisionTableHashKey thekey;
    double [] tempDist;
    double [] normDist;

    thekey = new DecisionTableHashKey(instA);
    if (m_classIsNominal) {

      // if this one is not in the table
      if ((tempDist = (double []) m_entries.get(thekey)) == null) {
        throw new Error("This should never happen!");
      } else {
        normDist = new double [tempDist.length];
        System.arraycopy(tempDist, 0, normDist, 0, tempDist.length);
        normDist[(int) instance.classValue()] -= instance.weight();

        // update the table
        // first check to see if the class counts are all zero now
        boolean ok = false;
        for (int i = 0; i < normDist.length; i++) {
          if (Utils.gr(normDist[i], 1.0)) {
            ok = true;
            break;
          }
        }

        // downdate the class prior counts
        m_classPriorCounts[(int) instance.classValue()] -= instance.weight();
        double [] classPriors = m_classPriorCounts.clone();
        Utils.normalize(classPriors);
        if (!ok) { // majority class
          normDist = classPriors;
        }
        m_classPriorCounts[(int) instance.classValue()] += instance.weight();

        //if (ok) {
        Utils.normalize(normDist);
        if (m_evaluationMeasure == EVAL_AUC) {
          m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, instance);
        } else {
          m_evaluation.evaluateModelOnce(normDist, instance);
        }
        return Utils.maxIndex(normDist);
        /*} else {
          normDist = new double [normDist.length];
          normDist[(int)m_majority] = 1.0;
          if (m_evaluationMeasure == EVAL_AUC) {
            m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, instance);
          } else {
            m_evaluation.evaluateModelOnce(normDist, instance);
          }
          return m_majority;
        } */
      }
      // return Utils.maxIndex(tempDist);
    } else {

      // see if this one is already in the table
      if ((tempDist = (double []) m_entries.get(thekey)) != null) {
        normDist = new double [tempDist.length];
        System.arraycopy(tempDist, 0, normDist, 0, tempDist.length);
        normDist[0] -= (instance.classValue() * instance.weight());
        normDist[1] -= instance.weight();
        if (Utils.eq(normDist[1], 0.0)) {
          double [] temp = new double [1];
          temp[0] = m_majority;
          m_evaluation.evaluateModelOnce(temp, instance);
          return m_majority;
        } else {
          double [] temp = new double [1];
          temp[0] = normDist[0] / normDist[1];
          m_evaluation.evaluateModelOnce(temp, instance);
          return temp[0];
        }
      } else {
        throw new Error("This should never happen!");
      }
    }

    // shouldn't get here
    // return 0.0;
  }

  /**
   * Calculates the accuracy on a test fold for internal cross validation
   * of feature sets
   *
   * @param fold set of instances to be "left out" and classified
   * @param fs currently selected feature set
   * @return the accuracy for the fold
   * @throws Exception if something goes wrong
   */
  double evaluateFoldCV(Instances fold, int [] fs) throws Exception {

    int i;
    int ruleCount = 0;
    int numFold = fold.numInstances();
    int numCl = m_theInstances.classAttribute().numValues();
    double [][] class_distribs = new double [numFold][numCl];
    double [] instA = new double [fs.length];
    double [] normDist;
    DecisionTableHashKey thekey;
    double acc = 0.0;
    int classI = m_theInstances.classIndex();
    Instance inst;

    if (m_classIsNominal) {
      normDist = new double [numCl];
    } else {
      normDist = new double [2];
    }

    // first *remove* instances
    for (i = 0; i < numFold; i++) {
      inst = fold.instance(i);
      for (int j = 0; j < fs.length; j++) {
        if (fs[j] == classI) {
          instA[j] = Double.MAX_VALUE; // missing for the class
        } else if (inst.isMissing(fs[j])) {
          instA[j] = Double.MAX_VALUE;
        } else {
          instA[j] = inst.value(fs[j]);
        }
      }
      thekey = new DecisionTableHashKey(instA);
      if ((class_distribs[i] = (double []) m_entries.get(thekey)) == null) {
        throw new Error("This should never happen!");
      } else {
        if (m_classIsNominal) {
          class_distribs[i][(int) inst.classValue()] -= inst.weight();
        } else {
          class_distribs[i][0] -= (inst.classValue() * inst.weight());
          class_distribs[i][1] -= inst.weight();
        }
        ruleCount++;
      }
      m_classPriorCounts[(int) inst.classValue()] -= inst.weight();
    }
    double [] classPriors = m_classPriorCounts.clone();
    Utils.normalize(classPriors);

    // now classify instances
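
Since this page cuts off before the option-handling and classification methods (pages 2-3 of the file), here is a minimal, self-contained sketch of how the classifier is typically driven from the Weka API. The setters shown (setSearch, setCrossVal, setEvaluationMeasure) correspond to the -S, -X and -E options documented in the class javadoc and should be defined in the portion of the file not shown here; the ARFF path is a placeholder.

import java.io.BufferedReader;
import java.io.FileReader;
import java.util.Random;

import weka.attributeSelection.BestFirst;
import weka.classifiers.Evaluation;
import weka.classifiers.rules.DecisionTable;
import weka.core.Instances;
import weka.core.SelectedTag;

public class DecisionTableExample {

  public static void main(String[] args) throws Exception {
    // Load a dataset; "iris.arff" is a placeholder path.
    Instances data = new Instances(
        new BufferedReader(new FileReader("iris.arff")));
    data.setClassIndex(data.numAttributes() - 1);

    DecisionTable dt = new DecisionTable();

    // -S: search feature subsets with best-first search (the default).
    dt.setSearch(new BestFirst());

    // -X 1: evaluate feature subsets with leave-one-out CV (the default).
    dt.setCrossVal(1);

    // -E acc: select feature subsets by accuracy (nominal class).
    dt.setEvaluationMeasure(new SelectedTag(DecisionTable.EVAL_ACCURACY,
                                            DecisionTable.TAGS_EVALUATION));

    // Estimate generalisation performance with 10-fold cross-validation.
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(dt, data, 10, new Random(1));
    System.out.println(eval.toSummaryString());

    // Build on all the data and classify the first instance.
    dt.buildClassifier(data);
    System.out.println("first prediction: "
        + dt.classifyInstance(data.instance(0)));
  }
}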
