
simplecart.java

This code implements a classifier, reusing part of the Weka source. The project can be imported into Eclipse and run.
Language: Java
Page 1 of 4
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * SimpleCart.java
 * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.trees;

import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableClassifier;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.matrix.Matrix;

import java.util.Arrays;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Class implementing minimal cost-complexity pruning.<br/>
 * Note when dealing with missing values, use "fractional instances" method instead of surrogate split method.<br/>
 * <br/>
 * For more information, see:<br/>
 * <br/>
 * Leo Breiman, Jerome H. Friedman, Richard A. Olshen, Charles J. Stone (1984). Classification and Regression Trees. Wadsworth International Group, Belmont, California.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;book{Breiman1984,
 *    address = {Belmont, California},
 *    author = {Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone},
 *    publisher = {Wadsworth International Group},
 *    title = {Classification and Regression Trees},
 *    year = {1984}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 * <pre> -M &lt;min no&gt;
 *  The minimal number of instances at the terminal nodes.
 *  (default 2)</pre>
 *
 * <pre> -N &lt;num folds&gt;
 *  The number of folds used in the minimal cost-complexity pruning.
 *  (default 5)</pre>
 *
 * <pre> -U
 *  Don't use the minimal cost-complexity pruning.
 *  (default yes).</pre>
 *
 * <pre> -H
 *  Don't use the heuristic method for binary split.
 *  (default true).</pre>
 *
 * <pre> -A
 *  Use 1 SE rule to make pruning decision.
 *  (default no).</pre>
 *
 * <pre> -C
 *  Percentage of training data size (0-1].
 *  (default 1).</pre>
 *
 <!-- options-end -->
 *
 * @author Haijian Shi (hs69@cs.waikato.ac.nz)
 * @version $Revision: 1.2 $
 */
public class SimpleCart
  extends RandomizableClassifier
  implements AdditionalMeasureProducer, TechnicalInformationHandler {

  /** For serialization. */
  private static final long serialVersionUID = 4154189200352566053L;

  /** Training data. */
  protected Instances m_train;

  /** Successor nodes. */
  protected SimpleCart[] m_Successors;

  /** Attribute used to split data. */
  protected Attribute m_Attribute;

  /** Split point for a numeric attribute. */
  protected double m_SplitValue;

  /** Split subset used to split data for nominal attributes. */
  protected String m_SplitString;

  /** Class value if the node is a leaf. */
  protected double m_ClassValue;

  /** Class attribute of data. */
  protected Attribute m_ClassAttribute;

  /** Minimum number of instances at the terminal nodes. */
  protected double m_minNumObj = 2;

  /** Number of folds for minimal cost-complexity pruning. */
  protected int m_numFoldsPruning = 5;

  /** Alpha-value (for pruning) at the node. */
  protected double m_Alpha;

  /** Number of training examples misclassified by the model (subtree rooted). */
  protected double m_numIncorrectModel;

  /** Number of training examples misclassified by the model (subtree not rooted). */
  protected double m_numIncorrectTree;

  /** Indicates whether the node is a leaf node. */
  protected boolean m_isLeaf;

  /** Whether to use minimal cost-complexity pruning. */
  protected boolean m_Prune = true;

  /** Total number of instances used to build the classifier. */
  protected int m_totalTrainInstances;

  /** Proportion for each branch. */
  protected double[] m_Props;

  /** Class probabilities. */
  protected double[] m_ClassProbs = null;

  /** Distributions of leaf node (or temporary leaf node in minimal cost-complexity pruning). */
  protected double[] m_Distribution;

  /** Whether to use heuristic search for nominal attributes in multi-class problems (default true). */
  protected boolean m_Heuristic = true;

  /** Whether to use the 1SE rule to make the final decision tree. */
  protected boolean m_UseOneSE = false;

  /** Training data size. */
  protected double m_SizePer = 1;

  /**
   * Return a description suitable for displaying in the explorer/experimenter.
   *
   * @return 		a description suitable for displaying in the
   * 			explorer/experimenter
   */
  public String globalInfo() {
    return
        "Class implementing minimal cost-complexity pruning.\n"
      + "Note when dealing with missing values, use \"fractional "
      + "instances\" method instead of surrogate split method.\n\n"
      + "For more information, see:\n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return 		the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;

    result = new TechnicalInformation(Type.BOOK);
    result.setValue(Field.AUTHOR, "Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone");
    result.setValue(Field.YEAR, "1984");
    result.setValue(Field.TITLE, "Classification and Regression Trees");
    result.setValue(Field.PUBLISHER, "Wadsworth International Group");
    result.setValue(Field.ADDRESS, "Belmont, California");

    return result;
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return 		the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NOMINAL_CLASS);

    return result;
  }

  /**
   * Build the classifier.
   *
   * @param data 	the training instances
   * @throws Exception 	if something goes wrong
   */
  public void buildClassifier(Instances data) throws Exception {

    getCapabilities().testWithFail(data);

    data = new Instances(data);
    data.deleteWithMissingClass();

    // unpruned CART decision tree
    if (!m_Prune) {

      // calculate sorted indices and weights, and compute initial class counts.
      int[][] sortedIndices = new int[data.numAttributes()][0];
      double[][] weights = new double[data.numAttributes()][0];
      double[] classProbs = new double[data.numClasses()];
      double totalWeight = computeSortedInfo(data, sortedIndices, weights, classProbs);

      makeTree(data, data.numInstances(), sortedIndices, weights, classProbs,
          totalWeight, m_minNumObj, m_Heuristic);
      return;
    }

    Random random = new Random(m_Seed);
    Instances cvData = new Instances(data);
    cvData.randomize(random);
    cvData = new Instances(cvData, 0, (int)(cvData.numInstances() * m_SizePer) - 1);
    cvData.stratify(m_numFoldsPruning);

    double[][] alphas = new double[m_numFoldsPruning][];
    double[][] errors = new double[m_numFoldsPruning][];

    // calculate errors and alphas for each fold
    for (int i = 0; i < m_numFoldsPruning; i++) {

      // for every fold, grow the tree on the training set and measure the error on the test set.
      Instances train = cvData.trainCV(m_numFoldsPruning, i);
      Instances test = cvData.testCV(m_numFoldsPruning, i);

      // calculate sorted indices and weights, and compute initial class counts for each fold
      int[][] sortedIndices = new int[train.numAttributes()][0];
      double[][] weights = new double[train.numAttributes()][0];
      double[] classProbs = new double[train.numClasses()];
      double totalWeight = computeSortedInfo(train, sortedIndices, weights, classProbs);

      makeTree(train, train.numInstances(), sortedIndices, weights, classProbs,
          totalWeight, m_minNumObj, m_Heuristic);

      int numNodes = numInnerNodes();
      alphas[i] = new double[numNodes + 2];
      errors[i] = new double[numNodes + 2];

      // prune back and log alpha-values and errors on the test set
      prune(alphas[i], errors[i], test);
    }

    // calculate sorted indices and weights, and compute initial class counts on all training instances
    int[][] sortedIndices = new int[data.numAttributes()][0];
    double[][] weights = new double[data.numAttributes()][0];
    double[] classProbs = new double[data.numClasses()];
    double totalWeight = computeSortedInfo(data, sortedIndices, weights, classProbs);

    // build tree using all the data
    makeTree(data, data.numInstances(), sortedIndices, weights, classProbs,
        totalWeight, m_minNumObj, m_Heuristic);
    int numNodes = numInnerNodes();

    double[] treeAlphas = new double[numNodes + 2];

    // prune back and log alpha-values
    int iterations = prune(treeAlphas, null, null);

    double[] treeErrors = new double[numNodes + 2];

    // for each pruned subtree, find the cross-validated error
    for (int i = 0; i <= iterations; i++) {
      // compute midpoint alphas
      double alpha = Math.sqrt(treeAlphas[i] * treeAlphas[i + 1]);
      double error = 0;
      for (int k = 0; k < m_numFoldsPruning; k++) {
        int l = 0;
        while (alphas[k][l] <= alpha) l++;
        error += errors[k][l - 1];
      }
      treeErrors[i] = error / m_numFoldsPruning;
    }

    // find best alpha
    int best = -1;
    double bestError = Double.MAX_VALUE;
    for (int i = iterations; i >= 0; i--) {
      if (treeErrors[i] < bestError) {
        bestError = treeErrors[i];
        best = i;
      }
    }

    // 1 SE rule to choose expansion
    if (m_UseOneSE) {
      double oneSE = Math.sqrt(bestError * (1 - bestError) / (data.numInstances()));
      for (int i = iterations; i >= 0; i--) {
        if (treeErrors[i] <= bestError + oneSE) {
          best = i;
          break;
        }
      }
    }

    double bestAlpha = Math.sqrt(treeAlphas[best] * treeAlphas[best + 1]);

    // "unprune" the final tree (faster than regrowing it)
    unprune();
    prune(bestAlpha);
  }

  /**
   * Make binary decision tree recursively.
   *
   * @param data 		the training instances
   * @param totalInstances 	total number of instances
   * @param sortedIndices 	sorted indices of the instances
   * @param weights 		weights of the instances
   * @param classProbs 		class probabilities
   * @param totalWeight 	total weight of instances
   * @param minNumObj 		minimal number of instances at leaf nodes
   * @param useHeuristic 	whether to use heuristic search for nominal attributes in multi-class problems
   * @throws Exception 		if something goes wrong
   */
  protected void makeTree(Instances data, int totalInstances, int[][] sortedIndices,
      double[][] weights, double[] classProbs, double totalWeight, double minNumObj,
      boolean useHeuristic) throws Exception {

    // if no instances have reached this node (normally won't happen)
    if (totalWeight == 0) {
      m_Attribute = null;
      m_ClassValue = Instance.missingValue();
      m_Distribution = new double[data.numClasses()];
      return;
    }

    m_totalTrainInstances = totalInstances;
    m_isLeaf = true;

    m_ClassProbs = new double[classProbs.length];
    m_Distribution = new double[classProbs.length];
    System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
    System.arraycopy(classProbs, 0, m_Distribution, 0, classProbs.length);
    if (Utils.sum(m_ClassProbs) != 0) Utils.normalize(m_ClassProbs);

    // Compute class distributions and value of splitting
    // criterion for each attribute
    double[][][] dists = new double[data.numAttributes()][0][0];
    double[][] props = new double[data.numAttributes()][0];
    double[][] totalSubsetWeights = new double[data.numAttributes()][2];
    double[] splits = new double[data.numAttributes()];
    String[] splitString = new String[data.numAttributes()];
    double[] giniGains = new double[data.numAttributes()];

    // for each attribute find split information
    for (int i = 0; i < data.numAttributes(); i++) {
      Attribute att = data.attribute(i);
      if (i == data.classIndex()) continue;
      if (att.isNumeric()) {
        // numeric attribute
        splits[i] = numericDistribution(props, dists, att, sortedIndices[i],
            weights[i], totalSubsetWeights, giniGains, data);
      } else {
        // nominal attribute
        splitString[i] = nominalDistribution(props, dists, att, sortedIndices[i],
            weights[i], totalSubsetWeights, giniGains, data, useHeuristic);
      }
    }

    // Find best attribute (split with maximum Gini gain)
    int attIndex = Utils.maxIndex(giniGains);
    m_Attribute = data.attribute(attIndex);

    m_train = new Instances(data, sortedIndices[attIndex].length);
    for (int i = 0; i < sortedIndices[attIndex].length; i++) {
      Instance inst = data.instance(sortedIndices[attIndex][i]);
      Instance instCopy = (Instance)inst.copy();
      instCopy.setWeight(weights[attIndex][i]);
      m_train.add(instCopy);
    }

    // Check if the node does not contain enough instances, or if it cannot be split,
    // or if it is pure. If so, make it a leaf.
    if (totalWeight < 2 * minNumObj || giniGains[attIndex] == 0 ||
        props[attIndex][0] == 0 || props[attIndex][1] == 0) {
      makeLeaf(data);
    }
    else {
      m_Props = props[attIndex];
      int[][][] subsetIndices = new int[2][data.numAttributes()][0];
      double[][][] subsetWeights = new double[2][data.numAttributes()][0];

      // numeric split
      if (m_Attribute.isNumeric()) m_SplitValue = splits[attIndex];
      // nominal split
      else m_SplitString = splitString[attIndex];

      splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitValue,
          m_SplitString, sortedIndices, weights, data);

      // If the split results in a node with fewer than the minimal number of
      // instances, make this node a leaf node.
      if (subsetIndices[0][attIndex].length < minNumObj ||
          subsetIndices[1][attIndex].length < minNumObj) {
        makeLeaf(data);
        return;
      }

      // Otherwise, split the node.
      m_isLeaf = false;
      m_Successors = new SimpleCart[2];
      for (int i = 0; i < 2; i++) {
        m_Successors[i] = new SimpleCart();
        m_Successors[i].makeTree(data, m_totalTrainInstances, subsetIndices[i],
            subsetWeights[i], dists[attIndex][i], totalSubsetWeights[attIndex][i],
            minNumObj, useHeuristic);
      }
    }
  }

  /**
   * Prunes the original tree using the CART pruning scheme, given a
   * cost-complexity parameter alpha.
   *
   * @param alpha 	the cost-complexity parameter
   * @throws Exception 	if something goes wrong
   */
  public void prune(double alpha) throws Exception {

    Vector nodeList;

    // determine training error of pruned subtrees (both with and without replacing a subtree),
    // and calculate alpha-values from them
    modelErrors();
    treeErrors();
    calculateAlphas();

    // get list of all inner nodes in the tree
    nodeList = getInnerNodes();

    boolean prune = (nodeList.size() > 0);
    double preAlpha = Double.MAX_VALUE;
    while (prune) {

      // select the node with minimum alpha
      SimpleCart nodeToPrune = nodeToPrune(nodeList);

      // prune only if the node's alpha is not larger than the given alpha
      if (nodeToPrune.m_Alpha > alpha) {
        break;
      }

      nodeToPrune.makeLeaf(nodeToPrune.m_train);

      // normally would not happen
      if (nodeToPrune.m_Alpha == preAlpha) {
        nodeToPrune.makeLeaf(nodeToPrune.m_train);
        treeErrors();
        calculateAlphas();
