⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 reptree.java

📁 MacroWeka扩展了著名数据挖掘工具weka
💻 JAVA
📖 第 1 页 / 共 4 页
字号:
   * @param newSeed Value to assign to Seed.
   */
  public void setSeed(int newSeed) {
    // Remembered here; buildClassifier() seeds its Random with this value.
    this.m_Seed = newSeed;
  }

  /**
   * Returns the tip text for the numFolds property.
   *
   * The previous text spoke of "growing the rules". This classifier grows a
   * tree on all folds but the held-out pruning fold, so the text now says
   * "tree".
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numFoldsTipText() {
    return "Determines the amount of data used for pruning. One fold is used for "
      + "pruning, the rest for growing the tree.";
  }
  
  /**
   * Returns the number of folds the data is split into when pruning is on.
   *
   * @return the configured number of folds
   */
  public int getNumFolds() {
    return this.m_NumFolds;
  }
  
  /**
   * Sets the number of folds. When pruning is enabled, one fold is held out
   * for reduced-error pruning and the remaining folds are used for growing.
   *
   * @param newNumFolds the new number of folds
   */
  public void setNumFolds(int newNumFolds) {
    this.m_NumFolds = newNumFolds;
  }
  
  /**
   * Returns the tip text for the maxDepth property.
   *
   * @return tip text suitable for displaying in the
   * explorer/experimenter gui
   */
  public String maxDepthTipText() {
    String tip = "The maximum tree depth (-1 for no restriction).";
    return tip;
  }

  /**
   * Returns the maximum depth the tree may be grown to.
   *
   * @return the maximum depth; -1 means no restriction
   */
  public int getMaxDepth() {
    return this.m_MaxDepth;
  }
  
  /**
   * Sets the maximum depth the tree may be grown to.
   *
   * @param newMaxDepth the new maximum depth; -1 means no restriction
   */
  public void setMaxDepth(int newMaxDepth) {
    this.m_MaxDepth = newMaxDepth;
  }
  
  /**
   * Lists the command-line options for this classifier.
   *
   * @return an enumeration of Option objects describing -M, -V, -N, -S,
   * -P and -L
   */
  public Enumeration listOptions() {
    
    // Six options are registered below; size the vector to match.
    Vector newVector = new Vector(6);

    newVector.
      addElement(new Option("\tSet minimum number of instances per leaf " +
                            "(default 2).",
                            "M", 1, "-M <minimum number of instances>"));
    newVector.
      addElement(new Option("\tSet minimum numeric class variance proportion\n" +
                            "\tof train variance for split (default 1e-3).",
                            "V", 1, "-V <minimum variance for split>"));
    newVector.
      addElement(new Option("\tNumber of folds for reduced error pruning " +
                            "(default 3).",
                            "N", 1, "-N <number of folds>"));
    newVector.
      addElement(new Option("\tSeed for random data shuffling (default 1).",
                            "S", 1, "-S <seed>"));
    newVector.
      addElement(new Option("\tNo pruning.",
                            "P", 0, "-P"));
    // -L takes one argument (parsed in setOptions), so the synopsis must
    // show it; previously it read just "-L".
    newVector.
      addElement(new Option("\tMaximum tree depth (default -1, no maximum)",
                            "L", 1, "-L <maximum tree depth>"));

    return newVector.elements();
  } 

  /**
   * Gets the current settings of this classifier as command-line options.
   *
   * @return a 12-element array containing -M, -V, -N, -S and -L with their
   * values, followed by -P when pruning is disabled; any remaining slots
   * hold empty strings
   */
  public String[] getOptions() {

    String[] result = new String[12];
    int pos = 0;

    result[pos++] = "-M";
    result[pos++] = String.valueOf((int) getMinNum());

    result[pos++] = "-V";
    result[pos++] = String.valueOf(getMinVarianceProp());

    result[pos++] = "-N";
    result[pos++] = String.valueOf(getNumFolds());

    result[pos++] = "-S";
    result[pos++] = String.valueOf(getSeed());

    result[pos++] = "-L";
    result[pos++] = String.valueOf(getMaxDepth());

    if (getNoPruning()) {
      result[pos++] = "-P";
    }

    // Pad unused slots with empty strings
    for (; pos < result.length; pos++) {
      result[pos] = "";
    }
    return result;
  }

  /**
   * Parses a given list of options. Any option that is absent is reset to
   * its default: -M 2, -V 1e-3, -N 3, -S 1, -L -1, and pruning enabled
   * unless -P is given.
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported or a value cannot
   * be parsed
   */
  public void setOptions(String[] options) throws Exception {

    String value;

    value = Utils.getOption('M', options);
    m_MinNum = (value.length() == 0) ? 2 : (double) Integer.parseInt(value);

    value = Utils.getOption('V', options);
    m_MinVarianceProp = (value.length() == 0) ? 1e-3 : Double.parseDouble(value);

    value = Utils.getOption('N', options);
    m_NumFolds = (value.length() == 0) ? 3 : Integer.parseInt(value);

    value = Utils.getOption('S', options);
    m_Seed = (value.length() == 0) ? 1 : Integer.parseInt(value);

    m_NoPruning = Utils.getFlag('P', options);

    value = Utils.getOption('L', options);
    m_MaxDepth = (value.length() == 0) ? -1 : Integer.parseInt(value);

    Utils.checkForRemainingOptions(options);
  }
  
  /**
   * Computes size of the tree.
   *
   * @return the number of nodes in the built tree, or 1 when the model
   * fell back to ZeroR because the data had no attributes besides the
   * class (the constant ZeroR prediction is counted as a single leaf)
   */
  public int numNodes() {

    if (m_zeroR != null) {
      // buildClassifier() takes the ZeroR path without growing a tree, so
      // m_Tree is not a valid model here (it may be null or left over from
      // an earlier build).
      return 1;
    }
    return m_Tree.numNodes();
  }

  /**
   * Returns the names of the additional measures this classifier supports.
   *
   * @return an enumeration containing the single name "measureTreeSize"
   */
  public Enumeration enumerateMeasures() {

    Vector names = new Vector(1);
    names.add("measureTreeSize");
    return names.elements();
  }
 
  /**
   * Returns the value of the named additional measure. The name is matched
   * case-insensitively.
   *
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @exception IllegalArgumentException if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {

    if (!additionalMeasureName.equalsIgnoreCase("measureTreeSize")) {
      throw new IllegalArgumentException(additionalMeasureName
                                         + " not supported (REPTree)");
    }
    return (double) numNodes();
  }

  /**
   * Builds the classifier. Grows a tree on the training portion of the data
   * and, unless pruning is turned off, applies reduced-error pruning using a
   * held-out fold. If the data contains only the class attribute, a ZeroR
   * model is built instead.
   *
   * @param data the training instances
   * @exception Exception if the class attribute is neither nominal nor
   * numeric, no instance has a known class value, string attributes are
   * present, or the tree cannot be built
   */
  public void buildClassifier(Instances data) throws Exception {

    Random random = new Random(m_Seed);

    // Only nominal and numeric class attributes are supported
    if (!data.classAttribute().isNominal() && !data.classAttribute().isNumeric()) {
      throw new UnsupportedClassTypeException("REPTree: nominal or numeric class!");
    }

    // Copy the dataset before deleting and reordering instances, then
    // delete instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    // Check for empty datasets
    if (data.numInstances() == 0) {
      throw new IllegalArgumentException("REPTree: zero training instances or all " +
                                         "instances have missing class!");
    }

    if (data.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
    }

    // Discard the results of any previous build. m_Tree must be cleared as
    // well as m_zeroR: the ZeroR fallback below returns without growing a
    // tree, and graph()/toSource() only test m_Tree for null, so a tree
    // from an earlier build would otherwise still be output.
    m_zeroR = null;
    m_Tree = null;
    if (data.numAttributes() == 1) {
      // Only the class attribute is present: there is nothing to split on
      m_zeroR = new ZeroR();
      m_zeroR.buildClassifier(data);
      return;
    }

    // The class was verified above to be either nominal or numeric
    boolean nominalClass = data.classAttribute().isNominal();

    // Randomize, and stratify over m_NumFolds folds for a nominal class
    data.randomize(random);
    if (nominalClass) {
      data.stratify(m_NumFolds);
    }

    // Split data into training and pruning set: fold 0 is held out for
    // pruning and the remaining folds are used for growing
    Instances train = null;
    Instances prune = null;
    if (!m_NoPruning) {
      train = data.trainCV(m_NumFolds, 0, random);
      prune = data.testCV(m_NumFolds, 0);
    } else {
      train = data;
    }

    // Create array of sorted indices and weights, one row per attribute
    // (the class attribute's row stays empty). For every attribute j,
    // weights[j][k] is the weight of instance sortedIndices[j][k].
    int[][] sortedIndices = new int[train.numAttributes()][0];
    double[][] weights = new double[train.numAttributes()][0];
    double[] vals = new double[train.numInstances()];
    for (int j = 0; j < train.numAttributes(); j++) {
      if (j != train.classIndex()) {
        weights[j] = new double[train.numInstances()];
        if (train.attribute(j).isNominal()) {

          // Handling nominal attributes. Putting indices of
          // instances with missing values at the end.
          sortedIndices[j] = new int[train.numInstances()];
          int count = 0;
          for (int i = 0; i < train.numInstances(); i++) {
            Instance inst = train.instance(i);
            if (!inst.isMissing(j)) {
              sortedIndices[j][count] = i;
              weights[j][count] = inst.weight();
              count++;
            }
          }
          for (int i = 0; i < train.numInstances(); i++) {
            Instance inst = train.instance(i);
            if (inst.isMissing(j)) {
              sortedIndices[j][count] = i;
              weights[j][count] = inst.weight();
              count++;
            }
          }
        } else {

          // Sorted indices are computed for numeric attributes
          for (int i = 0; i < train.numInstances(); i++) {
            Instance inst = train.instance(i);
            vals[i] = inst.value(j);
          }
          sortedIndices[j] = Utils.sort(vals);
          for (int i = 0; i < train.numInstances(); i++) {
            weights[j][i] = train.instance(sortedIndices[j][i]).weight();
          }
        }
      }
    }

    // Compute initial class counts: per-class weight totals for a nominal
    // class; for a numeric class, the weighted sum (in classProbs[0]) and
    // weighted sum of squares of the class value
    double[] classProbs = new double[train.numClasses()];
    double totalWeight = 0, totalSumSquared = 0;
    for (int i = 0; i < train.numInstances(); i++) {
      Instance inst = train.instance(i);
      if (nominalClass) {
        classProbs[(int)inst.classValue()] += inst.weight();
        totalWeight += inst.weight();
      } else {
        classProbs[0] += inst.classValue() * inst.weight();
        totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
        totalWeight += inst.weight();
      }
    }
    m_Tree = new Tree();
    double trainVariance = 0;
    if (!nominalClass) {
      // NOTE(review): singleVariance presumably returns the weighted sum of
      // squared deviations, so dividing by totalWeight yields the variance
      trainVariance = m_Tree.
        singleVariance(classProbs[0], totalSumSquared, totalWeight) / totalWeight;
      // classProbs[0] now holds the weighted mean of the class value
      classProbs[0] /= totalWeight;
    }

    // Build tree; the minimum variance for a split is a proportion of the
    // training-set class variance (zero for a nominal class)
    m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs,
                     new Instances(train, 0), m_MinNum, m_MinVarianceProp * 
                     trainVariance, 0, m_MaxDepth);
    
    // Insert pruning data and perform reduced error pruning
    if (!m_NoPruning) {
      m_Tree.insertHoldOutSet(prune);
      m_Tree.reducedErrorPrune();
      m_Tree.backfitHoldOutSet(prune);
    }
  }

  /**
   * Computes the class distribution of an instance, delegating to the
   * ZeroR model when one was built and to the tree otherwise.
   *
   * @param instance the instance to classify
   * @return the predicted class distribution
   * @exception Exception if the underlying model fails
   */
  public double[] distributionForInstance(Instance instance) 
    throws Exception {

    if (m_zeroR == null) {
      return m_Tree.distributionForInstance(instance);
    }
    return m_zeroR.distributionForInstance(instance);
  }


  /**
   * Counter used to give each node a unique ID when outputting the tree
   * source (hash codes are not guaranteed to be unique).
   * NOTE: shared static state, not synchronized.
   */
  private static long PRINTED_NODES = 0;

  /**
   * Returns the current counter value and then advances the counter.
   *
   * @return the next unique node ID.
   */
  protected static long nextID() {

    long id = PRINTED_NODES;
    PRINTED_NODES = id + 1;
    return id;
  }

  /**
   * Resets the node ID counter so that numbering starts again at 0.
   */
  protected static void resetID() {
    PRINTED_NODES = 0;
  }

  /**
   * Returns the tree as a Java class with a static classify method made of
   * if-then statements.
   *
   * @param className the name of the generated class
   * @return the source code of the generated class
   * @exception Exception if no model has been built yet
   */
  public String toSource(String className) 
    throws Exception {

    if (m_Tree == null) {
      throw new Exception("REPTree: No model built yet.");
    }
    StringBuffer[] code = m_Tree.toSource(className, m_Tree);

    StringBuffer out = new StringBuffer();
    out.append("class ").append(className).append(" {\n\n");
    out.append("  public static double classify(Object [] i)\n");
    out.append("    throws Exception {\n\n");
    out.append("    double p = Double.NaN;\n");
    out.append(code[0]);  // assignment code
    out.append("    return p;\n");
    out.append("  }\n");
    out.append(code[1]);  // support code
    out.append("}\n");
    return out.toString();
  }

  /**
   * Tells graph viewers which kind of graph this classifier produces.
   *
   * @return Drawable.TREE
   */
  public int graphType() {
    final int type = Drawable.TREE;
    return type;
  }

  /**
   * Outputs the decision tree as a graph in dot ("digraph") syntax.
   *
   * @return the dot representation of the tree
   * @exception Exception if no model has been built yet
   */
  public String graph() throws Exception {

    if (m_Tree == null) {
      throw new Exception("REPTree: No model built yet.");
    }
    StringBuffer body = new StringBuffer();
    m_Tree.toGraph(body, 0, null);
    return "digraph Tree {\n" + "edge [style=bold]\n" + body + "\n}\n";
  }
  
  /**
   * Outputs a textual description of the model: the ZeroR model when that
   * fallback was used, otherwise the tree followed by its size.
   *
   * @return a description of the model
   */
  public String toString() {

    if (m_zeroR != null) {
      return "No attributes other than class. Using ZeroR.\n\n" + m_zeroR.toString();
    }
    if (m_Tree == null) {
      return "REPTree: No model built yet.";
    }
    StringBuffer text = new StringBuffer();
    text.append("\nREPTree\n============\n");
    text.append(m_Tree.toString(0, null));
    text.append("\n");
    text.append("\nSize of the tree : ");
    text.append(numNodes());
    return text.toString();
  }

  /**
   * Command-line entry point: evaluates a REPTree with the given options
   * and prints the result, or prints the error message on failure.
   *
   * @param argv the command-line options
   */
  public static void main(String[] argv) {

    try {
      String result = Evaluation.evaluateModel(new REPTree(), argv);
      System.out.println(result);
    } catch (Exception e) {
      System.err.println(e.getMessage());
    }
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -