⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 randomtree.java

📁 wekaUT 是 University of Texas at Austin 开发的基于 weka 的半监督学习（semi-supervised learning）分类器
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    RandomTree.java
 *    Copyright (C) 2001 Richard Kirkby, Eibe Frank
 *
 */

package weka.classifiers.trees;

import weka.classifiers.*;
import weka.core.*;

import java.util.*;

/**
 * Class for constructing a tree that considers K random features at each node.
 * Performs no pruning.
 *
 * <p>Each node of the tree either is a leaf (m_Attribute == -1, predictions
 * come from m_ClassProbs) or splits on m_Attribute, delegating to the
 * subtrees in m_Successors.
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 * @version $Revision: 1.1.1.1 $
 */
public class RandomTree extends DistributionClassifier
  implements OptionHandler, WeightedInstancesHandler, Randomizable {

  /** The subtrees appended to this tree (one per branch of the split). */
  protected RandomTree[] m_Successors;

  /** The attribute to split on; -1 marks a leaf node. */
  protected int m_Attribute = -1;

  /** The split point (threshold for numeric attributes). */
  protected double m_SplitPoint = Double.NaN;

  /** The class distribution from the training data. */
  protected double[][] m_Distribution = null;

  /** The header information (attribute/class structure, no instances). */
  protected Instances m_Info = null;

  /** The proportions of training instances going down each branch;
      used to weight subtree predictions when a split value is missing. */
  protected double[] m_Prop = null;

  /** Class probabilities from the training data (leaf prediction). */
  protected double[] m_ClassProbs = null;

  /** Minimum number of instances for leaf. */
  protected double m_MinNum = 1.0;

  /** Whether debugging info is printed during training. */
  protected boolean m_Debug = false;

  /** The number of attributes considered for a split (the "K" in K random features). */
  protected int m_KValue = 1;

  /** Random number generator. */
  protected Random m_random;

  /** The random seed to use. */
  protected int m_randomSeed = 1;

  /**
   * Get the value of MinNum.
   *
   * @return Value of MinNum.
   */
  public double getMinNum() {

    return m_MinNum;
  }

  /**
   * Set the value of MinNum.
   *
   * @param newMinNum Value to assign to MinNum.
   */
  public void setMinNum(double newMinNum) {

    m_MinNum = newMinNum;
  }

  /**
   * Get the value of K, the number of randomly chosen attributes
   * considered at each split.
   *
   * @return Value of K.
   */
  public int getKValue() {

    return m_KValue;
  }

  /**
   * Set the value of K.
   *
   * @param k Value to assign to K.
   */
  public void setKValue(int k) {

    m_KValue = k;
  }

  /**
   * Get the value of Debug.
   *
   * @return Value of Debug.
   */
  public boolean getDebug() {

    return m_Debug;
  }

  /**
   * Set the value of Debug.
   *
   * @param newDebug Value to assign to Debug.
   */
  public void setDebug(boolean newDebug) {

    m_Debug = newDebug;
  }

  /**
   * Set the seed for random number generation.
   *
   * @param seed the seed
   */
  public void setSeed(int seed) {

    m_randomSeed = seed;
  }

  /**
   * Gets the seed for the random number generations
   *
   * @return the seed for the random number generation
   */
  public int getSeed() {

    return m_randomSeed;
  }

  /**
   * Lists the command-line options for this classifier.
   *
   * @return an enumeration of Option objects (-K, -M, -D, -S)
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(6);

    newVector.
      addElement(new Option("\tNumber of attributes to randomly investigate.",
                            "K", 1, "-K <number of attributes>"));
    newVector.
      addElement(new Option("\tSet minimum number of instances per leaf.",
                            "M", 1, "-M <minimum number of instances>"));
    newVector.
      addElement(new Option("\tTurns debugging info on.",
                            "D", 0, "-D"));
    newVector
      .addElement(new Option("\tSeed for random number generator.\n"
                             + "\t(default 1)",
                             "S", 1, "-S"));

    return newVector.elements();
  }

  /**
   * Gets options from this classifier.
   *
   * @return the current settings as a command-line option array;
   *         unused slots are padded with empty strings
   */
  public String[] getOptions() {

    String [] options = new String [10];
    int current = 0;
    options[current++] = "-K";
    options[current++] = "" + getKValue();
    options[current++] = "-M";
    options[current++] = "" + getMinNum();
    options[current++] = "-S";
    options[current++] = "" + getSeed();
    if (getDebug()) {
      options[current++] = "-D";
    }
    // Pad the remaining slots so the array never contains nulls.
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Parses a given list of options.
   *
   * Missing options fall back to their defaults (K=1, M=1, S=1).
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception{

    String kValueString = Utils.getOption('K', options);
    if (kValueString.length() != 0) {
      m_KValue = Integer.parseInt(kValueString);
    } else {
      m_KValue = 1;
    }
    String minNumString = Utils.getOption('M', options);
    if (minNumString.length() != 0) {
      m_MinNum = (double)Integer.parseInt(minNumString);
    } else {
      m_MinNum = 1;
    }
    String seed = Utils.getOption('S', options);
    if (seed.length() != 0) {
      setSeed(Integer.parseInt(seed));
    } else {
      setSeed(1);
    }
    m_Debug = Utils.getFlag('D', options);
    Utils.checkForRemainingOptions(options);
  }

  /**
   * Builds classifier.
   *
   * Precomputes, per attribute, the instance indices (sorted for numeric
   * attributes, missing-last for nominal ones) and their weights, then
   * delegates tree construction to buildTree.
   *
   * @param data the training instances
   * @exception Exception if the classifier cannot be built
   */
  public void buildClassifier(Instances data) throws Exception {

    // Make sure K value is in range
    if (m_KValue > data.numAttributes()-1) m_KValue = data.numAttributes()-1;

    // Delete instances with missing class (works on a copy so the
    // caller's dataset is left untouched)
    data = new Instances(data);
    data.deleteWithMissingClass();
    Instances train = data;

    // Create array of sorted indices and weights
    int[][] sortedIndices = new int[train.numAttributes()][0];
    double[][] weights = new double[train.numAttributes()][0];
    double[] vals = new double[train.numInstances()];
    for (int j = 0; j < train.numAttributes(); j++) {
      if (j != train.classIndex()) {
        weights[j] = new double[train.numInstances()];
        if (train.attribute(j).isNominal()) {

          // Handling nominal attributes. Putting indices of
          // instances with missing values at the end.
          sortedIndices[j] = new int[train.numInstances()];
          int count = 0;
          for (int i = 0; i < train.numInstances(); i++) {
            Instance inst = train.instance(i);
            if (!inst.isMissing(j)) {
              sortedIndices[j][count] = i;
              weights[j][count] = inst.weight();
              count++;
            }
          }
          for (int i = 0; i < train.numInstances(); i++) {
            Instance inst = train.instance(i);
            if (inst.isMissing(j)) {
              sortedIndices[j][count] = i;
              weights[j][count] = inst.weight();
              count++;
            }
          }
        } else {

          // Sorted indices are computed for numeric attributes
          for (int i = 0; i < train.numInstances(); i++) {
            Instance inst = train.instance(i);
            vals[i] = inst.value(j);
          }
          sortedIndices[j] = Utils.sort(vals);
          for (int i = 0; i < train.numInstances(); i++) {
            weights[j][i] = train.instance(sortedIndices[j][i]).weight();
          }
        }
      }
    }

    // Compute initial class counts (weighted)
    double[] classProbs = new double[train.numClasses()];
    for (int i = 0; i < train.numInstances(); i++) {
      Instance inst = train.instance(i);
      classProbs[(int)inst.classValue()] += inst.weight();
    }

    // Create the attribute indices window
    int[] attIndicesWindow = new int[data.numAttributes()-1];
    int j=0;
    for (int i=0; i<attIndicesWindow.length; i++) {
      if (j == data.classIndex()) j++; // do not include the class
      attIndicesWindow[i] = j++;
    }

    // Build tree
    // NOTE(review): buildTree is not visible in this excerpt (the source
    // is truncated); presumably defined later in this class — verify
    // against the full file.
    buildTree(sortedIndices, weights, train, classProbs,
              new Instances(train, 0), m_MinNum, m_Debug,
              attIndicesWindow);
  }

  /**
   * Computes class distribution of an instance using the decision tree.
   *
   * Missing split values are handled by descending every branch and
   * combining the results weighted by the training proportions m_Prop.
   *
   * @param instance the instance to classify
   * @return the class probability distribution, or the training
   *         distribution m_ClassProbs at a leaf / empty successor
   * @exception Exception if the distribution cannot be computed
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    double[] returnedDist = null;

    if (m_Attribute > -1) {

      // Node is not a leaf
      if (instance.isMissing(m_Attribute)) {

        // Value is missing
        returnedDist = new double[m_Info.numClasses()];

        // Split instance up: blend all subtree predictions by branch proportion
        for (int i = 0; i < m_Successors.length; i++) {
          double[] help = m_Successors[i].distributionForInstance(instance);
          if (help != null) {
            for (int j = 0; j < help.length; j++) {
              returnedDist[j] += m_Prop[i] * help[j];
            }
          }
        }
      } else if (m_Info.attribute(m_Attribute).isNominal()) {

        // For nominal attributes: follow the branch matching the value
        returnedDist =  m_Successors[(int)instance.value(m_Attribute)].
          distributionForInstance(instance);
      } else {

        // For numeric attributes: branch 0 is "< split point", branch 1 is ">="
        if (Utils.sm(instance.value(m_Attribute), m_SplitPoint)) {
          returnedDist = m_Successors[0].distributionForInstance(instance);
        } else {
          returnedDist = m_Successors[1].distributionForInstance(instance);
        }
      }
    }
    if ((m_Attribute == -1) || (returnedDist == null)) {

      // Node is a leaf or successor is empty
      return m_ClassProbs;
    } else {
      return returnedDist;
    }
  }

  /**
   * Outputs the decision tree as a graph
   *
   * @return the tree in GraphViz "dot" format, or null if generation fails
   */
  public String toGraph() {

    try {
      StringBuffer resultBuff = new StringBuffer();
      toGraph(resultBuff, 0);
      String result = "digraph Tree {\n" + "edge [style=bold]\n" + resultBuff.toString()
        + "\n}\n";
      return result;
    } catch (Exception e) {
      return null;
    }
  }

  /**
   * Outputs one node for graph.
   *
   * Appends this node (and, recursively, its subtrees) to the dot output,
   * using the object's identity hash as the node name.
   *
   * @param text buffer the dot description is appended to
   * @param num the number of the last node output so far
   * @return the number of the last node output after this subtree
   * @exception Exception if the graph cannot be generated
   */
  public int toGraph(StringBuffer text, int num) throws Exception {

    int maxIndex = Utils.maxIndex(m_ClassProbs);
    String classValue = m_Info.classAttribute().value(maxIndex);

    num++;
    if (m_Attribute == -1) {
      // Leaf: boxed node labelled with the majority class
      text.append("N" + Integer.toHexString(hashCode()) +
                  " [label=\"" + num + ": " + classValue + "\"" +
                  "shape=box]\n");
    }else {
      text.append("N" + Integer.toHexString(hashCode()) +
                  " [label=\"" + num + ": " + classValue + "\"]\n");
      for (int i = 0; i < m_Successors.length; i++) {
        text.append("N" + Integer.toHexString(hashCode())
                    + "->" +
                    "N" + Integer.toHexString(m_Successors[i].hashCode())  +
                    " [label=\"" + m_Info.attribute(m_Attribute).name());
        if (m_Info.attribute(m_Attribute).isNumeric()) {
          if (i == 0) {
            text.append(" < " +
                        Utils.doubleToString(m_SplitPoint, 2));
          } else {
            text.append(" >= " +
                        Utils.doubleToString(m_SplitPoint, 2));
          }
        } else {
          text.append(" = " + m_Info.attribute(m_Attribute).value(i));
        }
        text.append("\"]\n");
        num = m_Successors[i].toGraph(text, num);
      }
    }

    return num;
    // NOTE(review): the source is truncated here (page 1 of 2 of the
    // scrape) — the closing braces of this method and of the class, plus
    // the remaining methods (buildTree, splitData, toString, ...), are on
    // the missing page.

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -