⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 reptree.java

📁 MacroWeka扩展了著名数据挖掘工具weka
💻 JAVA
📖 第 1 页 / 共 4 页
字号:
	totalSumOfWeights = currSumOfWeights[1];
	
	sums[1] = currSums[1];
	sumSquared[1] = currSumSquared[1];
	sumOfWeights[1] = currSumOfWeights[1];

	// Try all possible split points
	double currSplit = data.instance(sortedIndices[0]).value(att);
	double currVal, bestVal = Double.MAX_VALUE;
	for (i = 0; i < sortedIndices.length; i++) {
	  Instance inst = data.instance(sortedIndices[i]);
	  if (inst.isMissing(att)) {
	    break;
	  }
	  if (inst.value(att) > currSplit) {
	    currVal = variance(currSums, currSumSquared, currSumOfWeights);
	    if (currVal < bestVal) {
	      bestVal = currVal;
	      splitPoint = (inst.value(att) + currSplit) / 2.0;
	      for (int j = 0; j < 2; j++) {
		sums[j] = currSums[j];
		sumSquared[j] = currSumSquared[j];
		sumOfWeights[j] = currSumOfWeights[j];
	      }
	    } 
	  } 

	  currSplit = inst.value(att);

	  double classVal = inst.classValue() * weights[i];
	  double classValSquared = inst.classValue() * classVal;

	  currSums[0] += classVal;
	  currSumSquared[0] += classValSquared;
	  currSumOfWeights[0] += weights[i];

	  currSums[1] -= classVal;
	  currSumSquared[1] -= classValSquared;
	  currSumOfWeights[1] -= weights[i];
	}
      }

      // Compute weights
      props[att] = new double[sums.length];
      for (int k = 0; k < props[att].length; k++) {
	props[att][k] = sumOfWeights[k];
      }
      if (!(Utils.sum(props[att]) > 0)) {
	for (int k = 0; k < props[att].length; k++) {
	  props[att][k] = 1.0 / (double)props[att].length;
	}
      } else {
	Utils.normalize(props[att]);
      }
    
	
      // Distribute counts for missing values
      while (i < sortedIndices.length) {
	Instance inst = data.instance(sortedIndices[i]);
	for (int j = 0; j < sums.length; j++) {
	  sums[j] += props[att][j] * inst.classValue() * weights[i];
	  sumSquared[j] += props[att][j] * inst.classValue() * 
	    inst.classValue() * weights[i];
	  sumOfWeights[j] += props[att][j] * weights[i];
	}
	totalSum += inst.classValue() * weights[i];
	totalSumSquared += 
	  inst.classValue() * inst.classValue() * weights[i]; 
	totalSumOfWeights += weights[i];
	i++;
      }

      // Compute final distribution
      dist = new double[sums.length][data.numClasses()];
      for (int j = 0; j < sums.length; j++) {
	if (sumOfWeights[j] > 0) {
	  dist[j][0] = sums[j] / sumOfWeights[j];
	} else {
	  dist[j][0] = totalSum / totalSumOfWeights;
	}
      }
      
      // Compute variance gain
      double priorVar =
	singleVariance(totalSum, totalSumSquared, totalSumOfWeights);
      double var = variance(sums, sumSquared, sumOfWeights);
      double gain = priorVar - var;
      
      // Return distribution and split point
      subsetWeights[att] = sumOfWeights;
      dists[att] = dist;
      vals[att] = gain;
      return splitPoint;
    }      

    /**
     * Adds up the result of singleVariance over all subsets, leaving
     * out every subset whose weight is not greater than zero.
     *
     * @param s the sum for each subset
     * @param sS the sum of squares for each subset
     * @param sumOfWeights the total weight of each subset
     * @return the accumulated value over the subsets with positive weight
     */
    protected double variance(double[] s, double[] sS,
                              double[] sumOfWeights) {

      double total = 0;

      for (int subset = 0; subset < s.length; subset++) {
        double subsetWeight = sumOfWeights[subset];

        // Negated "> 0" test so that a NaN weight is skipped as well
        if (!(subsetWeight > 0)) {
          continue;
        }
        total += singleVariance(s[subset], sS[subset], subsetWeight);
      }

      return total;
    }
    
    /**
     * Computes the variance term for a single set: the sum of squares
     * minus the squared sum divided by the weight. The result is not
     * divided by the weight again.
     *
     * The weight is not checked here, so a weight of zero leads to a
     * floating-point division by zero.
     *
     * @param s the sum of the values
     * @param sS the sum of the squared values
     * @param weight the weight the squared sum is divided by
     * @return sS - (s * s) / weight
     */
    protected double singleVariance(double s, double sS, double weight) {

      double squaredSum = s * s;
      double correction = squaredSum / weight;

      return sS - correction;
    }

    /**
     * Computes the value of the splitting criterion before the split
     * by delegating to ContingencyTables.entropyOverColumns.
     *
     * @param dist the distribution handed on to the entropy computation
     * @return the value returned by ContingencyTables.entropyOverColumns
     */
    protected double priorVal(double[][] dist) {

      double entropyBeforeSplit = ContingencyTables.entropyOverColumns(dist);

      return entropyBeforeSplit;
    }

    /**
     * Computes the value of the splitting criterion after the split:
     * the prior value minus the result of
     * ContingencyTables.entropyConditionedOnRows.
     *
     * @param dist the distribution handed on to the entropy computation
     * @param priorVal the value of the criterion before the split
     * @return priorVal reduced by the conditional entropy of dist
     */
    protected double gain(double[][] dist, double priorVal) {

      double entropyAfterSplit =
        ContingencyTables.entropyConditionedOnRows(dist);

      return priorVal - entropyAfterSplit;
    }

    /**
     * Prunes the tree bottom-up using the hold-out error stored in the
     * nodes. All successors are pruned first; the node is then turned
     * into a leaf when the summed error of its successors is at least
     * as large as its own hold-out error.
     *
     * @return the hold-out error of this node if it is (or has just
     * become) a leaf, otherwise the summed error of its successors
     * @throws Exception if pruning a successor fails
     */
    protected double reducedErrorPrune() throws Exception {

      // A leaf has nothing below it that could be pruned
      if (m_Attribute == -1) {
        return m_HoldOutError;
      }

      // Prune the successors first and add up the errors they report
      double subtreeError = 0;
      int child = 0;
      while (child < m_Successors.length) {
        subtreeError += m_Successors[child].reducedErrorPrune();
        child++;
      }

      // Keep the sub tree only if replacing it would make the error worse.
      // The ">=" test is kept as is so that NaN values behave as before.
      boolean leafIsNoWorse = subtreeError >= m_HoldOutError;
      if (!leafIsNoWorse) {
        return subtreeError;
      }

      // Replace the sub tree with a leaf
      m_Attribute = -1;
      m_Successors = null;
      return m_HoldOutError;
    }

    /**
     * Inserts every instance of the hold-out set into the tree, each
     * with its own weight and with this node passed as the parent.
     *
     * @param data the hold-out instances
     * @throws Exception if inserting an instance fails
     */
    protected void insertHoldOutSet(Instances data) throws Exception {

      int index = 0;
      while (index < data.numInstances()) {
        Instance current = data.instance(index);
        insertHoldOutInstance(current, current.weight(), this);
        index++;
      }
    }

    /**
     * Inserts an instance from the hold-out set into the tree. The
     * weight is added to the hold-out distribution of this node, the
     * hold-out error is updated, and the instance is then passed on to
     * the matching successor(s).
     *
     * @param inst the hold-out instance
     * @param weight the weight with which the instance arrives at this node
     * @param parent the node whose class probabilities are used when
     * this node has none of its own
     * @throws Exception if passing the instance on to a successor fails
     */
    protected void insertHoldOutInstance(Instance inst, double weight,
                                         Tree parent) throws Exception {

      if (inst.classAttribute().isNominal()) {

        // Nominal class: count the weight and check the predicted class
        int actualClass = (int) inst.classValue();
        m_HoldOutDist[actualClass] += weight;

        double[] probs =
          (m_ClassProbs != null) ? m_ClassProbs : parent.m_ClassProbs;
        if (Utils.maxIndex(probs) != actualClass) {
          m_HoldOutError += weight;
        }
      } else {

        // Numeric class: count the weight and add the squared error
        m_HoldOutDist[0] += weight;

        double[] probs =
          (m_ClassProbs != null) ? m_ClassProbs : parent.m_ClassProbs;
        double diff = probs[0] - inst.classValue();
        m_HoldOutError += diff * diff * weight;
      }

      // Nothing to pass on when this node is a leaf
      if (m_Attribute == -1) {
        return;
      }

      if (inst.isMissing(m_Attribute)) {

        // Value is missing: every successor with a positive proportion
        // receives the corresponding share of the weight
        for (int branch = 0; branch < m_Successors.length; branch++) {
          if (m_Prop[branch] > 0) {
            m_Successors[branch].insertHoldOutInstance(
              inst, weight * m_Prop[branch], this);
          }
        }
        return;
      }

      // Value is known: pick exactly one successor
      int target;
      if (m_Info.attribute(m_Attribute).isNominal()) {

        // Nominal attribute: the value is the index of the successor
        target = (int) inst.value(m_Attribute);
      } else {

        // Numeric attribute: below the split point goes to successor 0
        target = (inst.value(m_Attribute) < m_SplitPoint) ? 0 : 1;
      }
      m_Successors[target].insertHoldOutInstance(inst, weight, this);
    }
  
    /**
     * Backfits the tree with the hold-out set: every instance is handed
     * to backfitHoldOutInstance with its own weight and with this node
     * passed as the parent.
     *
     * @param data the hold-out instances
     * @throws Exception if backfitting an instance fails
     */
    protected void backfitHoldOutSet(Instances data) throws Exception {

      int index = 0;
      while (index < data.numInstances()) {
        Instance current = data.instance(index);
        backfitHoldOutInstance(current, current.weight(), this);
        index++;
      }
    }
    
    /**
     * Backfits an instance from the hold-out set: the class
     * probabilities of this node are recomputed from m_Distribution
     * plus the instance, and the instance is then passed on to the
     * matching successor(s).
     *
     * NOTE(review): m_Distribution is only read in this method, never
     * updated. Each call therefore starts again from the same base
     * values, so earlier hold-out instances do not accumulate in the
     * nominal case and are re-weighted with the unchanged
     * m_Distribution[1] in the numeric case. Confirm that this is
     * intended.
     *
     * @param inst the hold-out instance
     * @param weight the weight with which the instance arrives at this node
     * @param parent the parent node (not read by this method)
     * @throws Exception if passing the instance on to a successor fails
     */
    protected void backfitHoldOutInstance(Instance inst, double weight,
                                          Tree parent) throws Exception {

      if (inst.classAttribute().isNominal()) {

        // Nominal class: copy the stored distribution, add the
        // instance's weight to its class and normalize
        int numClasses = inst.numClasses();
        if (m_ClassProbs == null) {
          m_ClassProbs = new double[numClasses];
        }
        System.arraycopy(m_Distribution, 0, m_ClassProbs, 0, numClasses);
        m_ClassProbs[(int) inst.classValue()] += weight;
        Utils.normalize(m_ClassProbs);
      } else {

        // Numeric class: weighted mean of the current estimate (weighted
        // with m_Distribution[1]) and the instance's class value
        if (m_ClassProbs == null) {
          m_ClassProbs = new double[1];
        }
        double baseWeight = m_Distribution[1];
        double weightedSum =
          m_ClassProbs[0] * baseWeight + weight * inst.classValue();
        m_ClassProbs[0] = weightedSum / (baseWeight + weight);
      }

      // Nothing to pass on when this node is a leaf
      if (m_Attribute == -1) {
        return;
      }

      if (inst.isMissing(m_Attribute)) {

        // Value is missing: every successor with a positive proportion
        // receives the corresponding share of the weight
        for (int branch = 0; branch < m_Successors.length; branch++) {
          if (m_Prop[branch] > 0) {
            m_Successors[branch].backfitHoldOutInstance(
              inst, weight * m_Prop[branch], this);
          }
        }
        return;
      }

      // Value is known: pick exactly one successor
      int target;
      if (m_Info.attribute(m_Attribute).isNominal()) {

        // Nominal attribute: the value is the index of the successor
        target = (int) inst.value(m_Attribute);
      } else {

        // Numeric attribute: below the split point goes to successor 0
        target = (inst.value(m_Attribute) < m_SplitPoint) ? 0 : 1;
      }
      m_Successors[target].backfitHoldOutInstance(inst, weight, this);
    }
  }

  /** The Tree object; initially null. */
  protected Tree m_Tree = null;
    
  /** Number of folds for reduced error pruning (defaults to 3). */
  protected int m_NumFolds = 3;
    
  /** Seed for random data shuffling (defaults to 1). */
  protected int m_Seed = 1;
    
  /** Don't prune: pruning is switched off when this is true
      (defaults to false). */
  protected boolean m_NoPruning = false;

  /** The minimum number of instances per leaf; the tip text for this
      property describes it as the minimum total weight of the
      instances in a leaf (defaults to 2). */
  protected double m_MinNum = 2;

  /** The minimum proportion of the total variance (over all the data)
      required for split (defaults to 1e-3). */
  protected double m_MinVarianceProp = 1e-3;

  /** Upper bound on the tree depth.
      NOTE(review): the default of -1 presumably means "no limit" --
      confirm against the code that reads this field. */
  protected int m_MaxDepth = -1;
  
  /**
   * Supplies the tip text for the noPruning property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String noPruningTipText() {
    final String tipText = "Whether pruning is performed.";
    return tipText;
  }
  
  /**
   * Reports the current NoPruning setting.
   *
   * @return the value stored in m_NoPruning
   */
  public boolean getNoPruning() {
    return this.m_NoPruning;
  }
  
  /**
   * Stores a new NoPruning setting.
   *
   * @param newNoPruning the value to store in m_NoPruning
   */
  public void setNoPruning(boolean newNoPruning) {
    this.m_NoPruning = newNoPruning;
  }
  
  /**
   * Supplies the tip text for the minNum property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String minNumTipText() {
    final String tipText =
      "The minimum total weight of the instances in a leaf.";
    return tipText;
  }

  /**
   * Reports the current MinNum setting.
   *
   * @return the value stored in m_MinNum
   */
  public double getMinNum() {
    return this.m_MinNum;
  }
  
  /**
   * Stores a new MinNum setting.
   *
   * @param newMinNum the value to store in m_MinNum
   */
  public void setMinNum(double newMinNum) {
    this.m_MinNum = newMinNum;
  }
  
  /**
   * Supplies the tip text for the minVarianceProp property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String minVariancePropTipText() {
    final String tipText =
      "The minimum proportion of the variance on all the data " +
      "that needs to be present at a node in order for splitting to " +
      "be performed in regression trees.";
    return tipText;
  }

  /**
   * Reports the current MinVarianceProp setting.
   *
   * @return the value stored in m_MinVarianceProp
   */
  public double getMinVarianceProp() {
    return this.m_MinVarianceProp;
  }
  
  /**
   * Stores a new MinVarianceProp setting.
   *
   * @param newMinVarianceProp the value to store in m_MinVarianceProp
   */
  public void setMinVarianceProp(double newMinVarianceProp) {
    this.m_MinVarianceProp = newMinVarianceProp;
  }

  /**
   * Supplies the tip text for the seed property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String seedTipText() {
    final String tipText = "The seed used for randomizing the data.";
    return tipText;
  }

  /**
   * Reports the current Seed setting.
   *
   * @return the value stored in m_Seed
   */
  public int getSeed() {
    return this.m_Seed;
  }
  
  /**
   * Set the value of Seed.
   *

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -