📄 fat.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * FAT.java * Copyright (C) 1999 Eibe Frank * Modified by Prem Melville * */package weka.classifiers.trees.j48;import java.util.*;import weka.core.*;import weka.classifiers.*;/** * Feature Aquisition Trees based on C4.5 * * Class for generating an unpruned or a pruned C4.5 decision tree. * For more information, see<p> * * Ross Quinlan (1993). <i>C4.5: Programs for Machine Learning</i>, * Morgan Kaufmann Publishers, San Mateo, CA. </p> * * Valid options are: <p> * * -U <br> * Use unpruned tree.<p> * * -C confidence <br> * Set confidence threshold for pruning. (Default: 0.25) <p> * * -M number <br> * Set minimum number of instances per leaf. (Default: 2) <p> * * -R <br> * Use reduced error pruning. No subtree raising is performed. <p> * * -N number <br> * Set number of folds for reduced error pruning. One fold is * used as the pruning set. (Default: 3) <p> * * -B <br> * Use binary splits for nominal attributes. <p> * * -S <br> * Don't perform subtree raising. <p> * * -L <br> * Do not clean up after the tree has been built. <p> * * -A <br> * If set, Laplace smoothing is used for predicted probabilites. <p> * * * @author Prem Melville (melville@cs.utexas.edu) * @version $Revision: 1.7 $ */public class FAT extends EnsembleClassifier implements OptionHandler, Drawable, Matchable, Sourcable, WeightedInstancesHandler, Summarizable, AdditionalMeasureProducer, ActiveFeatureAcquirer { //=============== BEGIN EDIT melville =============== /** To use weighted sampling instead of deterministic selection of instance */ protected boolean m_UseSampling = false; /** Random number generator - for sampling */ protected Random m_Random = new Random(0); protected boolean m_Debug = true; /** Enumeration of selective sampling schemes */ static final int LABEL_MARGIN = 0, RANDOM = 1, ABS_LABEL_MARGIN = 2, UNLABELED_MARGIN = 3, ENTROPY = 4, MAX_CERTAINTY = 5, HYBRID = 6, RANDOM_HYBRID = 7; /* LABEL_MARGIN = INCORRECT_CERTAINTY UNLABELED_MARGIN = UNCERTAINITY HYBRID = ASYMMETRIC CERTAINTY RANDOM_HYBRID = ERROR SAMPLING */ /** The selective sampling scheme to use. */ protected int m_SelectionScheme = LABEL_MARGIN; //=============== END EDIT melville =============== // To maintain the same version number after adding m_ClassAttribute static final long serialVersionUID = -217733168393644444L; /** The decision tree */ private ClassifierTree m_root; /** Unpruned tree? */ private boolean m_unpruned = false; /** Confidence level */ private float m_CF = 0.25f; /** Minimum number of instances */ private int m_minNumObj = 2; /** Determines whether probabilities are smoothed using Laplace correction when predictions are generated */ private boolean m_useLaplace = false; /** Use reduced error pruning? */ private boolean m_reducedErrorPruning = false; /** Number of folds for reduced error pruning. */ private int m_numFolds = 3; /** Binary splits on nominal attributes? */ private boolean m_binarySplits = false; /** Subtree raising to be performed? */ private boolean m_subtreeRaising = true; /** Cleanup after the tree has been built. */ boolean m_noCleanup = false; /** * Given a set of unlabeled examples, select a specified number of examples to be labeled. * @param activePool pool of unlabeled examples * @param num number of examples to selcted for labeling * @exception Exception if selective sampling fails */ public int [] selectInstancesForFeatures(Instances activePool,int num) throws Exception{ //Make a list of pairs of indices and the corresponding measure of informativenes of examples //Sort this in the order of informativeness and return the list of num indices int poolSize = activePool.numInstances(); Pair []pairs = new Pair[poolSize]; for(int i=0; i<poolSize; i++){ pairs[i] = new Pair(i,calculateScore(activePool.instance(i))); } //sort in AScending order Arrays.sort(pairs, new Comparator() { public int compare(Object o1, Object o2) { double diff = ((Pair)o2).second - ((Pair)o1).second; return(diff < 0 ? 1 : diff > 0 ? -1 : 0); } }); int []selected = new int[num]; if(m_UseSampling){ if(pairs[0].first < 0) throw new Exception("Sampling with negative scores has not been defined."); //If sampling is being used then convert scores to //probabilities (inversely proportional) Vector v = new Vector(poolSize); double sum = 0; for(int i=0; i<poolSize; i++){ if(pairs[i].second == 0){ pairs[i].second = Double.MAX_VALUE/poolSize; //Account for probability values of 0 - to avoid divide-by-zero errors //Divide by probs.length to make sure normalizing works properly }else{ pairs[i].second = 1.0 / pairs[i].second; } //Convert pairs from array to vector v.add(pairs[i]); } /* For j=1 to n * Create a cdf * Randomly pick instance based on cdf * Note index and remove element */ for(int j=0; j<num; j++){ sum = 0; for(int i=0; i<v.size(); i++) sum += ((Pair) v.get(i)).second; //normalize if (Double.isNaN(sum)) { throw new IllegalArgumentException("Can't normalize array. Sum is NaN."); } if (sum == 0) { throw new IllegalArgumentException("Can't normalize array. Sum is zero."); } for(int i=0; i<v.size(); i++) ((Pair) v.get(i)).second /= sum; //create a cdf double []cdf = new double[v.size()]; cdf[0] = ((Pair) v.get(0)).second; for(int i=1; i<v.size(); i++) cdf[i] = ((Pair) v.get(i)).second + cdf[i-1]; double rnd = m_Random.nextDouble(); int index = 0; while(index < cdf.length && rnd > cdf[index]){ index++; } selected[j] = (int) (((Pair) v.get(index)).first); v.remove(index); } assert v.size()+num==poolSize : v.size()+" + "+num+" != "+poolSize+"\n"; if(m_Debug) { System.out.println("Sampled list:"); for(int j=0; j<num; j++){ System.out.println("\t"+selected[j]); } } }else{//take the first num if(m_Debug) System.out.println("Sorted list:"); for(int j=0; j<num; j++){ if(m_Debug) System.out.println("\t"+pairs[j].second+"\t"+pairs[j].first); selected[j] = (int) pairs[j].first; } } return selected; } /** * Calculate the feature acquisition score for * given examples depending on the chosen selection scheme. * @param instance local instances from the current pool * @return score * @exception Exception if score could not be calculated properly */ protected double calculateScore(Instance instance) throws Exception{ double score; switch(m_SelectionScheme){ case RANDOM_HYBRID: score = calculateRandomHybridScore(instance); break; case HYBRID: score = calculateHybridScore(instance); break; case ENTROPY: score = -1.0*calculateEntropy(instance); break; case UNLABELED_MARGIN: score = calculateMargin(instance); break; case MAX_CERTAINTY: score = -1.0*calculateMargin(instance); break; case ABS_LABEL_MARGIN: score = Math.abs(calculateLabeledInstanceMargin(instance)); break; case LABEL_MARGIN: score = calculateLabeledInstanceMargin(instance); break; case RANDOM: score = 0;//all instances are equal break; default: score = 0;//all instances are equal break; } return score; } /** * Generates the classifier. * * @exception Exception if classifier can't be built successfully */ public void buildClassifier(Instances instances) throws Exception { ModelSelection modSelection; if (m_binarySplits) modSelection = new BinC45ModelSelection(m_minNumObj, instances); else modSelection = new C45ModelSelection(m_minNumObj, instances); if (!m_reducedErrorPruning) m_root = new C45PruneableClassifierTree(modSelection, !m_unpruned, m_CF, m_subtreeRaising, !m_noCleanup); else m_root = new PruneableClassifierTree(modSelection, !m_unpruned, m_numFolds, !m_noCleanup); m_root.buildClassifier(instances); if (m_binarySplits) { ((BinC45ModelSelection)modSelection).cleanup(); } else { ((C45ModelSelection)modSelection).cleanup(); } } //=============== BEGIN EDIT melville =============== /** Returns class predictions of each ensemble member */ public double []getEnsemblePredictions(Instance instance) throws Exception{ double preds[] = new double[1]; preds[0] = classifyInstance(instance); return preds; } /** * Returns vote weights of ensemble members. * * @return vote weights of ensemble members */ public double []getEnsembleWts(){ double []ensembleWts = {1.0}; return ensembleWts; } /** Returns size of ensemble */ public double getEnsembleSize(){ return 1.0; } //=============== END EDIT melville =============== /** * Classifies an instance. * * @exception Exception if instance can't be classified successfully */ public double classifyInstance(Instance instance) throws Exception { return m_root.classifyInstance(instance); } /** * Returns class probabilities for an instance. * * @exception Exception if distribution can't be computed successfully */ public final double [] distributionForInstance(Instance instance) throws Exception { return m_root.distributionForInstance(instance, m_useLaplace); } /** * Returns the type of graph this classifier * represents. * @return Drawable.TREE */ // public int graphType() { // return Drawable.TREE; // } /** * Returns graph describing the tree. * * @exception Exception if graph can't be computed */ public String graph() throws Exception { return m_root.graph(); } /** * Returns tree in prefix order. * * @exception Exception if something goes wrong */ public String prefix() throws Exception { return m_root.prefix(); } /** * Returns tree as an if-then statement. * * @return the tree as a Java if-then type statement * @exception Exception if something goes wrong */ public String toSource(String className) throws Exception { StringBuffer [] source = m_root.toSource(className); return "class " + className + " {\n\n" +" public static double classify(Object [] i)\n" +" throws Exception {\n\n" +" double p = Double.NaN;\n" + source[0] // Assignment code +" return p;\n" +" }\n" + source[1] // Support code +"}\n"; } /** * Returns an enumeration describing the available options. * * Valid options are: <p> * * -U <br> * Use unpruned tree.<p> * * -C confidence <br> * Set confidence threshold for pruning. (Default: 0.25) <p> * * -M number <br> * Set minimum number of instances per leaf. (Default: 2) <p> * * -R <br> * Use reduced error pruning. No subtree raising is performed. <p> * * -N number <br> * Set number of folds for reduced error pruning. One fold is * used as the pruning set. (Default: 3) <p> * * -B <br> * Use binary splits for nominal attributes. <p> * * -S <br> * Don't perform subtree raising. <p> * * -L <br> * Do not clean up after the tree has been built. * * -A <br> * If set, Laplace smoothing is used for predicted probabilites. <p> * * @return an enumeration of all the available options. */ public Enumeration listOptions() { Vector newVector = new Vector(9); newVector. addElement(new Option("\tUse unpruned tree.", "U", 0, "-U")); newVector. addElement(new Option("\tSet confidence threshold for pruning.\n" + "\t(default 0.25)", "C", 1, "-C <pruning confidence>")); newVector. addElement(new Option("\tSet minimum number of instances per leaf.\n" + "\t(default 2)", "M", 1, "-M <minimum number of instances>")); newVector. addElement(new Option("\tUse reduced error pruning.", "R", 0, "-R")); newVector. addElement(new Option("\tSet number of folds for reduced error\n" + "\tpruning. One fold is used as pruning set.\n" + "\t(default 3)", "N", 1, "-N <number of folds>"));
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -