
📄 REPTree.java

📁 wekaUT is a semi-supervised learning classifier package built on Weka, developed at the University of Texas at Austin
💻 JAVA
📖 Page 1 / 3
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    REPTree.java
 *    Copyright (C) 1999 Eibe Frank
 *
 */

package weka.classifiers.trees;

import weka.classifiers.*;
import weka.core.*;
import java.util.*;
import java.io.*;

/**
 * Fast decision tree learner. Builds a decision/regression tree using
 * information gain/variance and prunes it using reduced-error pruning.
 * Only sorts values for numeric attributes once. Missing values are
 * dealt with by splitting the corresponding instances into pieces
 * (i.e. as in C4.5).
 *
 * Valid options are: <p>
 *
 * -M number <br>
 * Set minimum number of instances per leaf (default 2). <p>
 *
 * -V number <br>
 * Set minimum numeric class variance proportion of train variance for
 * split (default 1e-3). <p>
 *
 * -N number <br>
 * Number of folds for reduced error pruning (default 3). <p>
 *
 * -S number <br>
 * Seed for random data shuffling (default 1). <p>
 *
 * -P <br>
 * No pruning. <p>
 *
 * -D <br>
 * Maximum tree depth (default -1, no maximum). <p>
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 1.1.1.1 $
 */
public class REPTree extends DistributionClassifier
  implements OptionHandler, WeightedInstancesHandler, Drawable,
             AdditionalMeasureProducer {

  /** An inner class for building and storing the tree structure. */
  protected class Tree implements Serializable {

    /** The header information (for printing the tree). */
    protected Instances m_Info = null;

    /** The subtrees of this tree. */
    protected Tree[] m_Successors;

    /** The attribute to split on. */
    protected int m_Attribute = -1;

    /** The split point. */
    protected double m_SplitPoint = Double.NaN;

    /** The proportions of training instances going down each branch. */
    protected double[] m_Prop = null;

    /** Class probabilities from the training data in the nominal case.
        Holds the mean in the numeric case. */
    protected double[] m_ClassProbs = null;

    /** The (unnormalized) class distribution in the nominal case.
        Holds the sum of squared errors and the weight in the
        numeric case. */
    protected double[] m_Distribution = null;

    /** Class distribution of hold-out set at node in the nominal case.
        Straight sum of weights in the numeric case (i.e. the array has
        only one element). */
    protected double[] m_HoldOutDist = null;

    /** The hold-out error of the node. The number of misclassified
        instances in the nominal case, the sum of squared errors in the
        numeric case. */
    protected double m_HoldOutError = 0;
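    // Note on the node encoding: m_Attribute == -1 marks a leaf.
    // Internal nodes store the split attribute and split point, plus
    // m_Prop, the share of training weight that went down each branch.
    // m_Prop is reused in distributionForInstance() below to route an
    // instance with a missing split value fractionally down every
    // branch, as in C4.5.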
    /**
     * Computes class distribution of an instance using the tree.
     */
    protected double[] distributionForInstance(Instance instance)
      throws Exception {

      double[] returnedDist = null;

      if (m_Attribute > -1) {

        // Node is not a leaf
        if (instance.isMissing(m_Attribute)) {

          // Value is missing
          returnedDist = new double[m_Info.numClasses()];

          // Split instance up
          for (int i = 0; i < m_Successors.length; i++) {
            double[] help =
              m_Successors[i].distributionForInstance(instance);
            if (help != null) {
              for (int j = 0; j < help.length; j++) {
                returnedDist[j] += m_Prop[i] * help[j];
              }
            }
          }
        } else if (m_Info.attribute(m_Attribute).isNominal()) {

          // For nominal attributes
          returnedDist = m_Successors[(int)instance.value(m_Attribute)].
            distributionForInstance(instance);
        } else {

          // For numeric attributes
          if (instance.value(m_Attribute) < m_SplitPoint) {
            returnedDist =
              m_Successors[0].distributionForInstance(instance);
          } else {
            returnedDist =
              m_Successors[1].distributionForInstance(instance);
          }
        }
      }
      if ((m_Attribute == -1) || (returnedDist == null)) {

        // Node is a leaf or successor is empty
        return m_ClassProbs;
      } else {
        return returnedDist;
      }
    }

    /**
     * Outputs one node for graph.
     */
    protected int toGraph(StringBuffer text, int num,
                          Tree parent) throws Exception {

      num++;
      if (m_Attribute == -1) {
        text.append("N" + Integer.toHexString(Tree.this.hashCode()) +
                    " [label=\"" + num + leafString(parent) + "\"" +
                    "shape=box]\n");
      } else {
        text.append("N" + Integer.toHexString(Tree.this.hashCode()) +
                    " [label=\"" + num + ": " +
                    m_Info.attribute(m_Attribute).name() +
                    "\"]\n");
        for (int i = 0; i < m_Successors.length; i++) {
          text.append("N" + Integer.toHexString(Tree.this.hashCode())
                      + "->" +
                      "N" +
                      Integer.toHexString(m_Successors[i].hashCode()) +
                      " [label=\"");
          if (m_Info.attribute(m_Attribute).isNumeric()) {
            if (i == 0) {
              text.append(" < " +
                          Utils.doubleToString(m_SplitPoint, 2));
            } else {
              text.append(" >= " +
                          Utils.doubleToString(m_SplitPoint, 2));
            }
          } else {
            text.append(" = " + m_Info.attribute(m_Attribute).value(i));
          }
          text.append("\"]\n");
          num = m_Successors[i].toGraph(text, num, this);
        }
      }

      return num;
    }
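    // The strings emitted by toGraph() above are GraphViz "dot" syntax:
    // one "N<hashcode> [label=...]" line per node and one "->" line per
    // branch, so the assembled graph (via the Drawable interface) can be
    // rendered with the dot tool.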
    /**
     * Outputs description of a leaf node.
     */
    protected String leafString(Tree parent) throws Exception {

      if (m_Info.classAttribute().isNumeric()) {
        double classMean;
        if (m_ClassProbs == null) {
          classMean = parent.m_ClassProbs[0];
        } else {
          classMean = m_ClassProbs[0];
        }
        StringBuffer buffer = new StringBuffer();
        buffer.append(" : " + Utils.doubleToString(classMean, 2));
        double avgError = 0;
        if (m_Distribution[1] > 0) {
          avgError = m_Distribution[0] / m_Distribution[1];
        }
        buffer.append(" (" +
                      Utils.doubleToString(m_Distribution[1], 2) + "/" +
                      Utils.doubleToString(avgError, 2)
                      + ")");
        avgError = 0;
        if (m_HoldOutDist[0] > 0) {
          avgError = m_HoldOutError / m_HoldOutDist[0];
        }
        buffer.append(" [" +
                      Utils.doubleToString(m_HoldOutDist[0], 2) + "/" +
                      Utils.doubleToString(avgError, 2)
                      + "]");
        return buffer.toString();
      } else {
        int maxIndex;
        if (m_ClassProbs == null) {
          maxIndex = Utils.maxIndex(parent.m_ClassProbs);
        } else {
          maxIndex = Utils.maxIndex(m_ClassProbs);
        }
        return " : " + m_Info.classAttribute().value(maxIndex) +
          " (" + Utils.doubleToString(Utils.sum(m_Distribution), 2) +
          "/" +
          Utils.doubleToString((Utils.sum(m_Distribution) -
                                m_Distribution[maxIndex]), 2) + ")" +
          " [" + Utils.doubleToString(Utils.sum(m_HoldOutDist), 2) + "/" +
          Utils.doubleToString((Utils.sum(m_HoldOutDist) -
                                m_HoldOutDist[maxIndex]), 2) + "]";
      }
    }

    /**
     * Recursively outputs the tree.
     */
    protected String toString(int level, Tree parent) {

      try {
        StringBuffer text = new StringBuffer();

        if (m_Attribute == -1) {

          // Output leaf info
          return leafString(parent);
        } else if (m_Info.attribute(m_Attribute).isNominal()) {

          // For nominal attributes
          for (int i = 0; i < m_Successors.length; i++) {
            text.append("\n");
            for (int j = 0; j < level; j++) {
              text.append("|   ");
            }
            text.append(m_Info.attribute(m_Attribute).name() + " = " +
                        m_Info.attribute(m_Attribute).value(i));
            text.append(m_Successors[i].toString(level + 1, this));
          }
        } else {

          // For numeric attributes
          text.append("\n");
          for (int j = 0; j < level; j++) {
            text.append("|   ");
          }
          text.append(m_Info.attribute(m_Attribute).name() + " < " +
                      Utils.doubleToString(m_SplitPoint, 2));
          text.append(m_Successors[0].toString(level + 1, this));
          text.append("\n");
          for (int j = 0; j < level; j++) {
            text.append("|   ");
          }
          text.append(m_Info.attribute(m_Attribute).name() + " >= " +
                      Utils.doubleToString(m_SplitPoint, 2));
          text.append(m_Successors[1].toString(level + 1, this));
        }

        return text.toString();
      } catch (Exception e) {
        e.printStackTrace();
        return "Decision tree: tree can't be printed";
      }
    }
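    // Split selection in buildTree() below: for a nominal class the
    // score of a candidate split is its information gain (gain(...)
    // over the per-branch class distributions); for a numeric class it
    // is the reduction in weighted squared error. The helper
    // singleVariance(S, SS, W) is defined later in this file (not on
    // this page); from its call sites it evidently returns the weighted
    // sum of squared deviations from the mean, SS - S*S/W.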
    /**
     * Recursively generates a tree.
     */
    protected void buildTree(int[][] sortedIndices, double[][] weights,
                             Instances data, double totalWeight,
                             double[] classProbs, Instances header,
                             double minNum, double minVariance,
                             int depth, int maxDepth)
      throws Exception {

      // Store structure of dataset, set minimum number of instances
      // and make space for potential info from pruning data
      m_Info = header;
      m_HoldOutDist = new double[data.numClasses()];

      // Make leaf if there are no training instances
      if (sortedIndices[0].length == 0) {
        if (data.classAttribute().isNumeric()) {
          m_Distribution = new double[2];
        } else {
          m_Distribution = new double[data.numClasses()];
        }
        m_ClassProbs = null;
        return;
      }

      double priorVar = 0;
      if (data.classAttribute().isNumeric()) {

        // Compute prior variance
        double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0;
        for (int i = 0; i < sortedIndices[0].length; i++) {
          Instance inst = data.instance(sortedIndices[0][i]);
          totalSum += inst.classValue() * weights[0][i];
          totalSumSquared +=
            inst.classValue() * inst.classValue() * weights[0][i];
          totalSumOfWeights += weights[0][i];
        }
        priorVar = singleVariance(totalSum, totalSumSquared,
                                  totalSumOfWeights);
      }

      // Check if node doesn't contain enough instances, is pure
      // or the maximum tree depth is reached
      m_ClassProbs = new double[classProbs.length];
      System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
      if ((totalWeight < (2 * minNum)) ||

          // Nominal case
          (data.classAttribute().isNominal() &&
           Utils.eq(m_ClassProbs[Utils.maxIndex(m_ClassProbs)],
                    Utils.sum(m_ClassProbs))) ||

          // Numeric case
          (data.classAttribute().isNumeric() &&
           ((priorVar / totalWeight) < minVariance)) ||

          // Check tree depth
          ((m_MaxDepth >= 0) && (depth >= maxDepth))) {

        // Make leaf
        m_Attribute = -1;
        if (data.classAttribute().isNominal()) {

          // Nominal case
          m_Distribution = new double[m_ClassProbs.length];
          for (int i = 0; i < m_ClassProbs.length; i++) {
            m_Distribution[i] = m_ClassProbs[i];
          }
          Utils.normalize(m_ClassProbs);
        } else {

          // Numeric case
          m_Distribution = new double[2];
          m_Distribution[0] = priorVar;
          m_Distribution[1] = totalWeight;
        }
        return;
      }

      // Compute class distributions and value of splitting
      // criterion for each attribute
      double[] vals = new double[data.numAttributes()];
      double[][][] dists = new double[data.numAttributes()][0][0];
      double[][] props = new double[data.numAttributes()][0];
      double[][] totalSubsetWeights = new double[data.numAttributes()][0];
      double[] splits = new double[data.numAttributes()];
      if (data.classAttribute().isNominal()) {

        // Nominal case
        for (int i = 0; i < data.numAttributes(); i++) {
          if (i != data.classIndex()) {
            splits[i] = distribution(props, dists, i, sortedIndices[i],
                                     weights[i], totalSubsetWeights, data);
            vals[i] = gain(dists[i], priorVal(dists[i]));
          }
        }
      } else {

        // Numeric case
        for (int i = 0; i < data.numAttributes(); i++) {
          if (i != data.classIndex()) {
            splits[i] =
              numericDistribution(props, dists, i, sortedIndices[i],
                                  weights[i], totalSubsetWeights, data,
                                  vals);
          }
        }
      }

      // Find best attribute
      m_Attribute = Utils.maxIndex(vals);
      int numAttVals = dists[m_Attribute].length;

      // Check if there are at least two subsets with
      // required minimum number of instances
      int count = 0;
      for (int i = 0; i < numAttVals; i++) {
        if (totalSubsetWeights[m_Attribute][i] >= minNum) {
          count++;
        }
        if (count > 1) {
          break;
        }
      }

      // Any useful split found?
      if ((vals[m_Attribute] > 0) && (count > 1)) {

        // Build subtrees
        m_SplitPoint = splits[m_Attribute];
        m_Prop = props[m_Attribute];
        int[][][] subsetIndices =
          new int[numAttVals][data.numAttributes()][0];
        double[][][] subsetWeights =
          new double[numAttVals][data.numAttributes()][0];
        splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitPoint,
                  sortedIndices, weights, data);
        m_Successors = new Tree[numAttVals];
        for (int i = 0; i < numAttVals; i++) {
          m_Successors[i] = new Tree();
          m_Successors[i].
            buildTree(subsetIndices[i], subsetWeights[i],
                      data, totalSubsetWeights[m_Attribute][i],
                      dists[m_Attribute][i], header, minNum,
                      minVariance, depth + 1, maxDepth);
        }
      } else {

        // Make leaf
        m_Attribute = -1;
      }

      // Normalize class counts
      if (data.classAttribute().isNominal()) {
        m_Distribution = new double[m_ClassProbs.length];
        for (int i = 0; i < m_ClassProbs.length; i++) {
          m_Distribution[i] = m_ClassProbs[i];
        }
        Utils.normalize(m_ClassProbs);
      } else {
        m_Distribution = new double[2];
        m_Distribution[0] = priorVar;
        m_Distribution[1] = totalWeight;
      }
    }

    /**
     * Computes size of the tree.
     */
    protected int numNodes() {

      if (m_Attribute == -1) {
        return 1;
      } else {
        int size = 1;
        for (int i = 0; i < m_Successors.length; i++) {
          size += m_Successors[i].numNodes();
        }
        return size;
      }
    }
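The page above ends partway through the Tree inner class (the remainder of the file, including singleVariance(), splitData() and the pruning code, is on pages 2 and 3), but the listing already shows everything needed to use the classifier through the standard Weka API. Below is a minimal usage sketch, assuming the Weka 3.2-era API that wekaUT builds on (note REPTree here extends DistributionClassifier); the class name REPTreeDemo and the file name iris.arff are placeholders.

import java.io.BufferedReader;
import java.io.FileReader;
import weka.classifiers.trees.REPTree;
import weka.core.Instance;
import weka.core.Instances;

public class REPTreeDemo {

  public static void main(String[] args) throws Exception {

    // Load a dataset in ARFF format and mark the last attribute
    // as the class (the file name is a placeholder).
    Instances data = new Instances(
        new BufferedReader(new FileReader("iris.arff")));
    data.setClassIndex(data.numAttributes() - 1);

    // Configure the learner via the options documented in the
    // class Javadoc: minimum 2 instances per leaf, 3 folds for
    // reduced-error pruning, fixed random seed.
    REPTree tree = new REPTree();
    tree.setOptions(new String[] {"-M", "2", "-N", "3", "-S", "1"});

    // Grow the tree and prune it on the internal hold-out folds.
    tree.buildClassifier(data);

    // Class distribution for the first instance, computed by
    // distributionForInstance() shown in the listing above.
    Instance inst = data.instance(0);
    double[] dist = tree.distributionForInstance(inst);
    for (int i = 0; i < dist.length; i++) {
      System.out.println(data.classAttribute().value(i) + ": " + dist[i]);
    }

    // Textual form of the tree, built by the recursive toString().
    System.out.println(tree);
  }
}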
