decisiontable.java
来自「Weka」· Java 代码 · 共 1,398 行 · 第 1/3 页
JAVA
1,398 行
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * DecisionTable.java * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand * */package weka.classifiers.rules;import weka.attributeSelection.ASSearch;import weka.attributeSelection.BestFirst;import weka.attributeSelection.SubsetEvaluator;import weka.classifiers.Classifier;import weka.classifiers.Evaluation;import weka.classifiers.lazy.IBk;import weka.core.AdditionalMeasureProducer;import weka.core.Capabilities;import weka.core.Instance;import weka.core.Instances;import weka.core.Option;import weka.core.OptionHandler;import weka.core.SelectedTag;import weka.core.Tag;import weka.core.TechnicalInformation;import weka.core.TechnicalInformationHandler;import weka.core.Utils;import weka.core.WeightedInstancesHandler;import weka.core.Capabilities.Capability;import weka.core.TechnicalInformation.Field;import weka.core.TechnicalInformation.Type;import weka.filters.Filter;import weka.filters.unsupervised.attribute.Remove;import java.util.Arrays;import java.util.BitSet;import java.util.Enumeration;import java.util.Hashtable;import java.util.Random;import java.util.Vector;/** <!-- globalinfo-start --> * Class for building and using a simple decision table majority classifier.<br/> * <br/> * For more information see: <br/> * <br/> * Ron Kohavi: The Power of Decision Tables. 
In: 8th European Conference on Machine Learning, 174-189, 1995. * <p/> <!-- globalinfo-end --> * <!-- technical-bibtex-start --> * BibTeX: * <pre> * @inproceedings{Kohavi1995, * author = {Ron Kohavi}, * booktitle = {8th European Conference on Machine Learning}, * pages = {174-189}, * publisher = {Springer}, * title = {The Power of Decision Tables}, * year = {1995} * } * </pre> * <p/> <!-- technical-bibtex-end --> * <!-- options-start --> * Valid options are: <p/> * * <pre> -S <search method specification> * Full class name of search method, followed * by its options. * eg: "weka.attributeSelection.BestFirst -D 1" * (default weka.attributeSelection.BestFirst)</pre> * * <pre> -X <number of folds> * Use cross validation to evaluate features. * Use number of folds = 1 for leave one out CV. * (Default = leave one out CV)</pre> * * <pre> -E <acc | rmse | mae | auc> * Performance evaluation measure to use for selecting attributes. * (Default = accuracy for discrete class and rmse for numeric class)</pre> * * <pre> -I * Use nearest neighbour instead of global table majority.</pre> * * <pre> -R * Display decision table rules. * </pre> * * <pre> * Options specific to search method weka.attributeSelection.BestFirst: * </pre> * * <pre> -P <start set> * Specify a starting set of attributes. * Eg. 1,3,5-7.</pre> * * <pre> -D <0 = backward | 1 = forward | 2 = bi-directional> * Direction of search. (default = 1).</pre> * * <pre> -N <num> * Number of non-improving nodes to * consider before terminating search.</pre> * * <pre> -S <num> * Size of lookup cache for evaluated subsets. * Expressed as a multiple of the number of * attributes in the data set. 
(default = 1)</pre> * <!-- options-end --> * * @author Mark Hall (mhall@cs.waikato.ac.nz) * @version $Revision: 1.44 $ */public class DecisionTable extends Classifier implements OptionHandler, WeightedInstancesHandler, AdditionalMeasureProducer, TechnicalInformationHandler { /** for serialization */ static final long serialVersionUID = 2888557078165701326L; /** The hashtable used to hold training instances */ protected Hashtable m_entries; /** The class priors to use when there is no match in the table */ protected double [] m_classPriorCounts; protected double [] m_classPriors; /** Holds the final feature set */ protected int [] m_decisionFeatures; /** Discretization filter */ protected Filter m_disTransform; /** Filter used to remove columns discarded by feature selection */ protected Remove m_delTransform; /** IB1 used to classify non matching instances rather than majority class */ protected IBk m_ibk; /** Holds the original training instances */ protected Instances m_theInstances; /** Holds the final feature selected set of instances */ protected Instances m_dtInstances; /** The number of attributes in the dataset */ protected int m_numAttributes; /** The number of instances in the dataset */ private int m_numInstances; /** Class is nominal */ protected boolean m_classIsNominal; /** Use the IBk classifier rather than majority class */ protected boolean m_useIBk; /** Display Rules */ protected boolean m_displayRules; /** Number of folds for cross validating feature sets */ private int m_CVFolds; /** Random numbers for use in cross validation */ private Random m_rr; /** Holds the majority class */ protected double m_majority; /** The search method to use */ protected ASSearch m_search = new BestFirst(); /** Our own internal evaluator */ protected SubsetEvaluator m_evaluator; /** The evaluation object used to evaluate subsets */ protected Evaluation m_evaluation; /** default is accuracy for discrete class and RMSE for numeric class */ public static final int 
EVAL_DEFAULT = 1; public static final int EVAL_ACCURACY = 2; public static final int EVAL_RMSE = 3; public static final int EVAL_MAE = 4; public static final int EVAL_AUC = 5; public static final Tag [] TAGS_EVALUATION = { new Tag(EVAL_DEFAULT, "Default: accuracy (discrete class); RMSE (numeric class)"), new Tag(EVAL_ACCURACY, "Accuracy (discrete class only"), new Tag(EVAL_RMSE, "RMSE (of the class probabilities for discrete class)"), new Tag(EVAL_MAE, "MAE (of the class probabilities for discrete class)"), new Tag(EVAL_AUC, "AUC (area under the ROC curve - discrete class only)") }; protected int m_evaluationMeasure = EVAL_DEFAULT; /** * Returns a string describing classifier * @return a description suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Class for building and using a simple decision table majority " + "classifier.\n\n" + "For more information see: \n\n" + getTechnicalInformation().toString(); } /** * Returns an instance of a TechnicalInformation object, containing * detailed information about the technical background of this class, * e.g., paper reference or book this class is based on. 
* * @return the technical information about this class */ public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; result = new TechnicalInformation(Type.INPROCEEDINGS); result.setValue(Field.AUTHOR, "Ron Kohavi"); result.setValue(Field.TITLE, "The Power of Decision Tables"); result.setValue(Field.BOOKTITLE, "8th European Conference on Machine Learning"); result.setValue(Field.YEAR, "1995"); result.setValue(Field.PAGES, "174-189"); result.setValue(Field.PUBLISHER, "Springer"); return result; } /** * Inserts an instance into the hash table * * @param inst instance to be inserted * @param instA to create the hash key from * @throws Exception if the instance can't be inserted */ private void insertIntoTable(Instance inst, double [] instA) throws Exception { double [] tempClassDist2; double [] newDist; DecisionTableHashKey thekey; if (instA != null) { thekey = new DecisionTableHashKey(instA); } else { thekey = new DecisionTableHashKey(inst, inst.numAttributes(), false); } // see if this one is already in the table tempClassDist2 = (double []) m_entries.get(thekey); if (tempClassDist2 == null) { if (m_classIsNominal) { newDist = new double [m_theInstances.classAttribute().numValues()]; //Leplace estimation for (int i = 0; i < m_theInstances.classAttribute().numValues(); i++) { newDist[i] = 1.0; } newDist[(int)inst.classValue()] = inst.weight(); // add to the table m_entries.put(thekey, newDist); } else { newDist = new double [2]; newDist[0] = inst.classValue() * inst.weight(); newDist[1] = inst.weight(); // add to the table m_entries.put(thekey, newDist); } } else { // update the distribution for this instance if (m_classIsNominal) { tempClassDist2[(int)inst.classValue()]+=inst.weight(); // update the table m_entries.put(thekey, tempClassDist2); } else { tempClassDist2[0] += (inst.classValue() * inst.weight()); tempClassDist2[1] += inst.weight(); // update the table m_entries.put(thekey, tempClassDist2); } } } /** * Classifies an instance for 
internal leave one out cross validation
   * of feature sets.
   *
   * <p>The instance's own weight is subtracted from a copy of its table
   * entry (the stored entry itself is not modified), so the prediction is
   * made as if the instance had been left out of training. The resulting
   * distribution (nominal class) or value (numeric class) is recorded in
   * {@code m_evaluation}.
   *
   * @param instance instance to be "left out" and classified
   * @param instA feature values of the selected features for the instance
   * @return the classification of the instance: the index of the most
   * probable class for a nominal class, the predicted value for a numeric
   * class
   * @throws Exception if something goes wrong
   */
  double evaluateInstanceLeaveOneOut(Instance instance, double [] instA)
    throws Exception {

    DecisionTableHashKey thekey;
    double [] tempDist;
    double [] normDist;

    thekey = new DecisionTableHashKey(instA);
    if (m_classIsNominal) {

      // if this one is not in the table
      if ((tempDist = (double [])m_entries.get(thekey)) == null) {
        // NOTE(review): presumably every instance evaluated here was inserted
        // into m_entries beforehand, so a missing key indicates bad state
        throw new Error("This should never happen!");
      } else {
        // work on a copy so the stored entry is left untouched
        normDist = new double [tempDist.length];
        System.arraycopy(tempDist,0,normDist,0,tempDist.length);
        normDist[(int)instance.classValue()] -= instance.weight();

        // the table itself is not updated (normDist is a copy).
        // Check whether any class count still exceeds the Laplace prior of
        // 1.0 that insertIntoTable seeds each class with, i.e. whether any
        // training weight besides this instance remains in the entry
        boolean ok = false;
        for (int i=0;i<normDist.length;i++) {
          if (Utils.gr(normDist[i],1.0)) {
            ok = true;
            break;
          }
        }

        // downdate the class prior counts so the priors exclude this
        // instance; the counts are restored right after the priors are taken
        m_classPriorCounts[(int)instance.classValue()] -= instance.weight();
        double [] classPriors = m_classPriorCounts.clone();
        Utils.normalize(classPriors);
        if (!ok) { // no remaining evidence in this entry: use the class priors
          normDist = classPriors;
        }

        // restore the class prior counts
        m_classPriorCounts[(int)instance.classValue()] += instance.weight();

        //if (ok) {
        Utils.normalize(normDist);
        if (m_evaluationMeasure == EVAL_AUC) {
          // for AUC, use the variant that also records the prediction
          m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, instance);
        } else {
          m_evaluation.evaluateModelOnce(normDist, instance);
        }
        return Utils.maxIndex(normDist);
        // NOTE: the commented-out branch below is an older fallback that
        // predicted m_majority outright; it is superseded by the classPriors
        // fallback above
        /*} else {
          normDist = new double [normDist.length];
          normDist[(int)m_majority] = 1.0;
          if (m_evaluationMeasure == EVAL_AUC) {
            m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, instance);
          } else {
            m_evaluation.evaluateModelOnce(normDist, instance);
          }
          return m_majority;
        } */
      }
      // return Utils.maxIndex(tempDist);
    } else {

      // see if this one is already in the table
      if ((tempDist = (double[])m_entries.get(thekey)) != null) {
        // entry is [sum of weight * classValue, sum of weight]; remove this
        // instance's contribution from a copy
        normDist = new double [tempDist.length];
        System.arraycopy(tempDist,0,normDist,0,tempDist.length);
        normDist[0] -= (instance.classValue() * instance.weight());
        normDist[1] -= instance.weight();
        if (Utils.eq(normDist[1],0.0)) {
          // no other training weight in this entry: predict m_majority
          double [] temp = new double[1];
          temp[0] = m_majority;
          m_evaluation.evaluateModelOnce(temp, instance);
          return m_majority;
        } else {
          // weighted mean of the remaining class values in this entry
          double [] temp = new double[1];
          temp[0] = normDist[0] / normDist[1];
          m_evaluation.evaluateModelOnce(temp, instance);
          return temp[0];
        }
      } else {
        throw new Error("This should never happen!");
      }
    }

    // shouldn't get here
    // return 0.0;
  }

  /**
   * Calculates the accuracy on a test fold for internal cross validation
   * of feature sets
   *
   * @param fold set of instances to be "left out" and classified
   * @param fs currently selected feature set
   * @return the accuracy for the fold
   * @throws Exception if something goes wrong
   */
  double evaluateFoldCV(Instances fold, int [] fs) throws Exception {

    int i;
    int ruleCount = 0;
    int numFold = fold.numInstances();
    int numCl = m_theInstances.classAttribute().numValues();
    double [][] class_distribs = new double [numFold][numCl];
    double [] instA = new double [fs.length];
    double [] normDist;
    DecisionTableHashKey thekey;
    double acc = 0.0;
    int classI = m_theInstances.classIndex();
    Instance inst;

    if (m_classIsNominal) {
      normDist = new double [numCl];
    } else {
      normDist = new double [2];
    }

    // first *remove* instances
    for (i=0;i<numFold;i++) {
      inst = fold.instance(i);
      for (int j=0;j<fs.length;j++) {
        if (fs[j] == classI) {
          instA[j] = Double.MAX_VALUE; // missing for the class
        } else if (inst.isMissing(fs[j])) {
          instA[j] = Double.MAX_VALUE;
        } else{
          instA[j] = inst.value(fs[j]);
        }
      }
      thekey = new DecisionTableHashKey(instA);
      if ((class_distribs[i] = (double [])m_entries.get(thekey)) == null) {
        throw new Error("This should never happen!");
      } else {
        if (m_classIsNominal) {
          class_distribs[i][(int)inst.classValue()] -= inst.weight();
        } else {
          class_distribs[i][0] -= (inst.classValue() * inst.weight());
class_distribs[i][1] -= inst.weight(); } ruleCount++; } m_classPriorCounts[(int)inst.classValue()] -= inst.weight(); } double [] classPriors = m_classPriorCounts.clone(); Utils.normalize(classPriors); // now classify instances
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?