📄 reptree.java
字号:
// NOTE(review): this span is the continuation of a numeric-attribute split
// evaluation method — its signature, setup, and the declarations of data,
// sortedIndices, weights, att, i, splitPoint, sums/currSums etc. appear
// above this excerpt. Line breaks restored: in the scraped form the "//"
// comments swallowed the code that followed them on the same physical line.

      // Try all possible split points
      double currSplit = data.instance(sortedIndices[0]).value(att);
      double currVal, bestVal = Double.MAX_VALUE;
      for (i = 0; i < sortedIndices.length; i++) {
        Instance inst = data.instance(sortedIndices[i]);
        if (inst.isMissing(att)) {

          // Instances are sorted on att, so missing values come last
          break;
        }
        if (inst.value(att) > currSplit) {

          // Attribute value changed: evaluate a split between currSplit
          // and this value
          currVal = variance(currSums, currSumSquared, currSumOfWeights);
          if (currVal < bestVal) {
            bestVal = currVal;

            // Split point is the midpoint between the two adjacent values
            splitPoint = (inst.value(att) + currSplit) / 2.0;
            for (int j = 0; j < 2; j++) {
              sums[j] = currSums[j];
              sumSquared[j] = currSumSquared[j];
              sumOfWeights[j] = currSumOfWeights[j];
            }
          }
        }
        currSplit = inst.value(att);

        // Shift this instance's weighted class statistics from the right
        // subset (index 1) into the left subset (index 0)
        double classVal = inst.classValue() * weights[i];
        double classValSquared = inst.classValue() * classVal;
        currSums[0] += classVal;
        currSumSquared[0] += classValSquared;
        currSumOfWeights[0] += weights[i];
        currSums[1] -= classVal;
        currSumSquared[1] -= classValSquared;
        currSumOfWeights[1] -= weights[i];
      }
    }

    // Compute weights: proportion of total weight that fell into each branch
    props[att] = new double[sums.length];
    for (int k = 0; k < props[att].length; k++) {
      props[att][k] = sumOfWeights[k];
    }
    if (!(Utils.sum(props[att]) > 0)) {

      // No weight anywhere: fall back to uniform proportions
      for (int k = 0; k < props[att].length; k++) {
        props[att][k] = 1.0 / (double)props[att].length;
      }
    } else {
      Utils.normalize(props[att]);
    }

    // Distribute counts for missing values (i still indexes the first
    // missing-valued instance, where the loop above broke off)
    while (i < sortedIndices.length) {
      Instance inst = data.instance(sortedIndices[i]);
      for (int j = 0; j < sums.length; j++) {
        sums[j] += props[att][j] * inst.classValue() * weights[i];
        sumSquared[j] += props[att][j] * inst.classValue() * inst.classValue() * weights[i];
        sumOfWeights[j] += props[att][j] * weights[i];
      }
      totalSum += inst.classValue() * weights[i];
      totalSumSquared += inst.classValue() * inst.classValue() * weights[i];
      totalSumOfWeights += weights[i];
      i++;
    }

    // Compute final distribution: each branch predicts its weighted mean
    // class value (regression: only column 0 is used)
    dist = new double[sums.length][data.numClasses()];
    for (int j = 0; j < sums.length; j++) {
      if (sumOfWeights[j] > 0) {
        dist[j][0] = sums[j] / sumOfWeights[j];
      } else {

        // Empty branch: fall back to the overall mean
        dist[j][0] = totalSum / totalSumOfWeights;
      }
    }

    // Compute variance gain: variance before the split minus the summed
    // variance of the subsets after the split
    double priorVar = singleVariance(totalSum, totalSumSquared, totalSumOfWeights);
    double var = variance(sums, sumSquared, sumOfWeights);
    double gain = priorVar - var;

    // Return distribution and split point
    subsetWeights[att] = sumOfWeights;
    dists[att] = dist;
    vals[att] = gain;

    return splitPoint;
  }

  /**
   * Computes variance for subsets.
   *
   * @param s the per-subset sums of weighted class values
   * @param sS the per-subset sums of squared weighted class values
   * @param sumOfWeights the per-subset sums of weights
   * @return the total variance summed over all non-empty subsets
   */
  protected double variance(double[] s, double[] sS, double[] sumOfWeights) {

    double var = 0;

    for (int i = 0; i < s.length; i++) {

      // Skip empty subsets to avoid division by zero in singleVariance
      if (sumOfWeights[i] > 0) {
        var += singleVariance(s[i], sS[i], sumOfWeights[i]);
      }
    }

    return var;
  }

  /**
   * Computes the variance for a single set
   *
   * @param s the sum of weighted class values
   * @param sS the sum of squared weighted class values
   * @param weight the weight
   * @return the variance (sum of squares minus squared sum over weight)
   */
  protected double singleVariance(double s, double sS, double weight) {

    return sS - ((s * s) / weight);
  }

  /**
   * Computes value of splitting criterion before split.
   *
   * @param dist the class distribution
   * @return the splitting criterion (entropy over the columns of dist)
   */
  protected double priorVal(double[][] dist) {

    return ContingencyTables.entropyOverColumns(dist);
  }

  /**
   * Computes value of splitting criterion after split.
   *
   * @param dist the class distribution after the split
   * @param priorVal the splitting criterion before the split
   * @return the gain after splitting (prior entropy minus conditional entropy)
   */
  protected double gain(double[][] dist, double priorVal) {

    return priorVal - ContingencyTables.entropyConditionedOnRows(dist);
  }

  /**
   * Prunes the tree using the hold-out data (bottom-up).
   *
   * @return the error
   * @throws Exception if pruning fails for some reason
   */
  protected double reducedErrorPrune() throws Exception {

    // Is node leaf ?
// NOTE(review): this span continues reducedErrorPrune() — its javadoc and
// signature are at the end of the previous chunk. Line breaks restored from
// the scraped single-line form.
    if (m_Attribute == -1) {
      return m_HoldOutError;
    }

    // Prune all sub trees
    double errorTree = 0;
    for (int i = 0; i < m_Successors.length; i++) {
      errorTree += m_Successors[i].reducedErrorPrune();
    }

    // Replace sub tree with leaf if error doesn't get worse
    if (errorTree >= m_HoldOutError) {

      // m_Attribute == -1 marks this node as a leaf from now on
      m_Attribute = -1;
      m_Successors = null;
      return m_HoldOutError;
    } else {
      return errorTree;
    }
  }

  /**
   * Inserts hold-out set into tree.
   *
   * @param data the data to insert
   * @throws Exception if something goes wrong
   */
  protected void insertHoldOutSet(Instances data) throws Exception {

    for (int i = 0; i < data.numInstances(); i++) {
      insertHoldOutInstance(data.instance(i), data.instance(i).weight(), this);
    }
  }

  /**
   * Inserts an instance from the hold-out set into the tree, accumulating
   * the hold-out error at every node the instance passes through.
   *
   * @param inst the instance to insert
   * @param weight the weight of the instance
   * @param parent the parent of the node
   * @throws Exception if insertion fails
   */
  protected void insertHoldOutInstance(Instance inst, double weight,
                                       Tree parent) throws Exception {

    // Insert instance into hold-out class distribution
    if (inst.classAttribute().isNominal()) {

      // Nominal case
      m_HoldOutDist[(int)inst.classValue()] += weight;
      int predictedClass = 0;

      // If this node has no class distribution of its own, predict with
      // the parent's distribution
      if (m_ClassProbs == null) {
        predictedClass = Utils.maxIndex(parent.m_ClassProbs);
      } else {
        predictedClass = Utils.maxIndex(m_ClassProbs);
      }
      if (predictedClass != (int)inst.classValue()) {
        m_HoldOutError += weight;
      }
    } else {

      // Numeric case: accumulate weighted squared error
      m_HoldOutDist[0] += weight;
      double diff = 0;
      if (m_ClassProbs == null) {
        diff = parent.m_ClassProbs[0] - inst.classValue();
      } else {
        diff = m_ClassProbs[0] - inst.classValue();
      }
      m_HoldOutError += diff * diff * weight;
    }

    // The process is recursive
    if (m_Attribute != -1) {

      // If node is not a leaf
      if (inst.isMissing(m_Attribute)) {

        // Distribute instance across all branches, weighted by the
        // training-set proportions m_Prop
        for (int i = 0; i < m_Successors.length; i++) {
          if (m_Prop[i] > 0) {
            m_Successors[i].insertHoldOutInstance(inst, weight * m_Prop[i], this);
          }
        }
      } else {

        if (m_Info.attribute(m_Attribute).isNominal()) {

          // Treat nominal attributes
          m_Successors[(int)inst.value(m_Attribute)].
            insertHoldOutInstance(inst, weight, this);
        } else {

          // Treat numeric attributes
          if (inst.value(m_Attribute) < m_SplitPoint) {
            m_Successors[0].insertHoldOutInstance(inst, weight, this);
          } else {
            m_Successors[1].insertHoldOutInstance(inst, weight, this);
          }
        }
      }
    }
  }

  /**
   * Inserts hold-out set into tree.
   *
   * @param data the data to insert
   * @throws Exception if insertion fails
   */
  protected void backfitHoldOutSet(Instances data) throws Exception {

    for (int i = 0; i < data.numInstances(); i++) {
      backfitHoldOutInstance(data.instance(i), data.instance(i).weight(), this);
    }
  }

  /**
   * Inserts an instance from the hold-out set into the tree, updating the
   * class distributions at each node it passes through (backfitting).
   *
   * @param inst the instance to insert
   * @param weight the weight of the instance
   * @param parent the parent node
   * @throws Exception if insertion fails
   */
  protected void backfitHoldOutInstance(Instance inst, double weight,
                                        Tree parent) throws Exception {

    // Insert instance into hold-out class distribution
    if (inst.classAttribute().isNominal()) {

      // Nominal case
      if (m_ClassProbs == null) {
        m_ClassProbs = new double[inst.numClasses()];
      }
      System.arraycopy(m_Distribution, 0, m_ClassProbs, 0, inst.numClasses());
      m_ClassProbs[(int)inst.classValue()] += weight;
      Utils.normalize(m_ClassProbs);
    } else {

      // Numeric case: fold the instance into a running weighted mean.
      // NOTE(review): m_Distribution[1] presumably holds the accumulated
      // weight at this node — confirm against where m_Distribution is built.
      if (m_ClassProbs == null) {
        m_ClassProbs = new double[1];
      }
      m_ClassProbs[0] *= m_Distribution[1];
      m_ClassProbs[0] += weight * inst.classValue();
      m_ClassProbs[0] /= (m_Distribution[1] + weight);
    }

    // The process is recursive
    if (m_Attribute != -1) {

      // If node is not a leaf
      if (inst.isMissing(m_Attribute)) {

        // Distribute instance
        for (int i = 0; i < m_Successors.length; i++) {
          if (m_Prop[i] > 0) {
            m_Successors[i].backfitHoldOutInstance(inst, weight * m_Prop[i], this);
          }
        }
      } else {

        if (m_Info.attribute(m_Attribute).isNominal()) {

          // Treat nominal attributes
          m_Successors[(int)inst.value(m_Attribute)].
// NOTE(review): the first statement below completes a call split across the
// chunk boundary:
//   m_Successors[(int)inst.value(m_Attribute)].backfitHoldOutInstance(...)
            backfitHoldOutInstance(inst, weight, this);
        } else {

          // Treat numeric attributes
          if (inst.value(m_Attribute) < m_SplitPoint) {
            m_Successors[0].backfitHoldOutInstance(inst, weight, this);
          } else {
            m_Successors[1].backfitHoldOutInstance(inst, weight, this);
          }
        }
      }
    }
  }
} // end of inner class Tree

/** The Tree object */
protected Tree m_Tree = null;

/** Number of folds for reduced error pruning. */
protected int m_NumFolds = 3;

/** Seed for random data shuffling. */
protected int m_Seed = 1;

/** Don't prune */
protected boolean m_NoPruning = false;

/** The minimum number of instances per leaf. */
protected double m_MinNum = 2;

/** The minimum proportion of the total variance (over all the data) required for split. */
protected double m_MinVarianceProp = 1e-3;

/** Upper bound on the tree depth (-1 means no bound) */
protected int m_MaxDepth = -1;

/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String noPruningTipText() {
  return "Whether pruning is performed.";
}

/**
 * Get the value of NoPruning.
 *
 * @return Value of NoPruning.
 */
public boolean getNoPruning() {
  return m_NoPruning;
}

/**
 * Set the value of NoPruning.
 *
 * @param newNoPruning Value to assign to NoPruning.
 */
public void setNoPruning(boolean newNoPruning) {
  m_NoPruning = newNoPruning;
}

/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String minNumTipText() {
  return "The minimum total weight of the instances in a leaf.";
}

/**
 * Get the value of MinNum.
 *
 * @return Value of MinNum.
 */
public double getMinNum() {
  return m_MinNum;
}

/**
 * Set the value of MinNum.
 *
 * @param newMinNum Value to assign to MinNum.
 */
public void setMinNum(double newMinNum) {
  m_MinNum = newMinNum;
}

/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String minVariancePropTipText() {
  return "The minimum proportion of the variance on all the data " +
    "that needs to be present at a node in order for splitting to " +
    "be performed in regression trees.";
}

/**
 * Get the value of MinVarianceProp.
 *
 * @return Value of MinVarianceProp.
 */
public double getMinVarianceProp() {
  return m_MinVarianceProp;
}

/**
 * Set the value of MinVarianceProp.
 *
 * @param newMinVarianceProp Value to assign to MinVarianceProp.
 */
public void setMinVarianceProp(double newMinVarianceProp) {
  m_MinVarianceProp = newMinVarianceProp;
}

/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String seedTipText() {
  return "The seed used for randomizing the data.";
}

/**
 * Get the value of Seed.
 *
 * @return Value of Seed.
 */
public int getSeed() {
  return m_Seed;
}

/**
 * Set the value of Seed.
 *
 * @param newSeed Value to assign to Seed.
 */
public void setSeed(int newSeed) {
  m_Seed = newSeed;
}

/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
// NOTE(review): this method's closing brace lies beyond the end of this
// chunk — the body below is intentionally left open.
public String numFoldsTipText() {
  return "Determines the amount of data used for pruning. One fold is used for " +
    "pruning, the rest for growing the rules.";
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -