📄 reptree.java
字号:
totalSumOfWeights = currSumOfWeights[1];
sums[1] = currSums[1];
sumSquared[1] = currSumSquared[1];
sumOfWeights[1] = currSumOfWeights[1];
// Try all possible split points
double currSplit = data.instance(sortedIndices[0]).value(att);
double currVal, bestVal = Double.MAX_VALUE;
for (i = 0; i < sortedIndices.length; i++) {
Instance inst = data.instance(sortedIndices[i]);
if (inst.isMissing(att)) {
break;
}
if (inst.value(att) > currSplit) {
currVal = variance(currSums, currSumSquared, currSumOfWeights);
if (currVal < bestVal) {
bestVal = currVal;
splitPoint = (inst.value(att) + currSplit) / 2.0;
for (int j = 0; j < 2; j++) {
sums[j] = currSums[j];
sumSquared[j] = currSumSquared[j];
sumOfWeights[j] = currSumOfWeights[j];
}
}
}
currSplit = inst.value(att);
double classVal = inst.classValue() * weights[i];
double classValSquared = inst.classValue() * classVal;
currSums[0] += classVal;
currSumSquared[0] += classValSquared;
currSumOfWeights[0] += weights[i];
currSums[1] -= classVal;
currSumSquared[1] -= classValSquared;
currSumOfWeights[1] -= weights[i];
}
}
// Compute weights
props[att] = new double[sums.length];
for (int k = 0; k < props[att].length; k++) {
props[att][k] = sumOfWeights[k];
}
if (!(Utils.sum(props[att]) > 0)) {
for (int k = 0; k < props[att].length; k++) {
props[att][k] = 1.0 / (double)props[att].length;
}
} else {
Utils.normalize(props[att]);
}
// Distribute counts for missing values
while (i < sortedIndices.length) {
Instance inst = data.instance(sortedIndices[i]);
for (int j = 0; j < sums.length; j++) {
sums[j] += props[att][j] * inst.classValue() * weights[i];
sumSquared[j] += props[att][j] * inst.classValue() *
inst.classValue() * weights[i];
sumOfWeights[j] += props[att][j] * weights[i];
}
totalSum += inst.classValue() * weights[i];
totalSumSquared +=
inst.classValue() * inst.classValue() * weights[i];
totalSumOfWeights += weights[i];
i++;
}
// Compute final distribution
dist = new double[sums.length][data.numClasses()];
for (int j = 0; j < sums.length; j++) {
if (sumOfWeights[j] > 0) {
dist[j][0] = sums[j] / sumOfWeights[j];
} else {
dist[j][0] = totalSum / totalSumOfWeights;
}
}
// Compute variance gain
double priorVar =
singleVariance(totalSum, totalSumSquared, totalSumOfWeights);
double var = variance(sums, sumSquared, sumOfWeights);
double gain = priorVar - var;
// Return distribution and split point
subsetWeights[att] = sumOfWeights;
dists[att] = dist;
vals[att] = gain;
return splitPoint;
}
/**
 * Computes the total variance over all subsets of a candidate split.
 *
 * @param s per-subset sums of class values
 * @param sS per-subset sums of squared class values
 * @param sumOfWeights per-subset totals of instance weights
 * @return the summed variance contributions of all non-empty subsets
 */
protected double variance(double[] s, double[] sS,
double[] sumOfWeights) {

  double total = 0;
  for (int j = 0; j < s.length; j++) {
    // Empty subsets contribute nothing (and would divide by zero).
    if (sumOfWeights[j] > 0) {
      total += singleVariance(s[j], sS[j], sumOfWeights[j]);
    }
  }
  return total;
}
/**
 * Computes the (unnormalized) variance of a single subset from its
 * sufficient statistics: sS - s*s / weight.
 *
 * @param s sum of class values in the subset
 * @param sS sum of squared class values in the subset
 * @param weight total weight of the subset (must be positive)
 * @return the subset's variance contribution
 */
protected double singleVariance(double s, double sS, double weight) {
  double meanTerm = (s * s) / weight;
  return sS - meanTerm;
}
/**
 * Computes the value of the splitting criterion before any split is
 * made, i.e. the entropy over the class distribution columns.
 *
 * @param dist the class distribution
 * @return the entropy over the columns of the distribution
 */
protected double priorVal(double[][] dist) {
  final double entropy = ContingencyTables.entropyOverColumns(dist);
  return entropy;
}
/**
 * Computes the value of the splitting criterion after the split, as the
 * reduction in entropy relative to the prior value.
 *
 * @param dist the post-split class distribution
 * @param priorVal the criterion value before the split
 * @return the information gain achieved by the split
 */
protected double gain(double[][] dist, double priorVal) {
  final double posterior = ContingencyTables.entropyConditionedOnRows(dist);
  return priorVal - posterior;
}
/**
 * Prunes the tree bottom-up using the hold-out error recorded at each
 * node: a subtree is collapsed into a leaf whenever keeping it does not
 * strictly reduce the hold-out error.
 *
 * @return the hold-out error of this (possibly pruned) subtree
 * @throws Exception if pruning a successor fails
 */
protected double reducedErrorPrune() throws Exception {

  // A leaf cannot be pruned any further.
  if (m_Attribute == -1) {
    return m_HoldOutError;
  }

  // Recursively prune the successors and sum their resulting errors.
  double subtreeError = 0;
  for (int k = 0; k < m_Successors.length; k++) {
    subtreeError += m_Successors[k].reducedErrorPrune();
  }

  // Keep the subtree only if it is strictly better than a leaf here;
  // otherwise collapse this node into a leaf.
  if (subtreeError < m_HoldOutError) {
    return subtreeError;
  }
  m_Attribute = -1;
  m_Successors = null;
  return m_HoldOutError;
}
/**
 * Inserts every instance of the hold-out set into the tree, starting at
 * this node.
 *
 * @param data the hold-out instances
 * @throws Exception if inserting an instance fails
 */
protected void insertHoldOutSet(Instances data) throws Exception {
  final int numInsts = data.numInstances();
  for (int idx = 0; idx < numInsts; idx++) {
    Instance inst = data.instance(idx);
    insertHoldOutInstance(inst, inst.weight(), this);
  }
}
/**
 * Inserts a single hold-out instance into the tree, accumulating the
 * hold-out class distribution and hold-out error at every node visited,
 * then routing the instance down the matching branch(es).
 *
 * @param inst the hold-out instance
 * @param weight the (possibly fractional) weight of the instance
 * @param parent the parent node, whose class probabilities are used as a
 *   fallback when this node has none
 * @throws Exception if insertion into a successor fails
 */
protected void insertHoldOutInstance(Instance inst, double weight,
Tree parent) throws Exception {

  // Record the instance in this node's hold-out statistics.
  if (inst.classAttribute().isNominal()) {
    // Nominal class: count the instance in its class bin, and charge a
    // misclassification if the node's prediction disagrees.
    m_HoldOutDist[(int)inst.classValue()] += weight;
    double[] probs =
      (m_ClassProbs == null) ? parent.m_ClassProbs : m_ClassProbs;
    if (Utils.maxIndex(probs) != (int)inst.classValue()) {
      m_HoldOutError += weight;
    }
  } else {
    // Numeric class: accumulate weighted squared prediction error.
    m_HoldOutDist[0] += weight;
    double prediction =
      (m_ClassProbs == null) ? parent.m_ClassProbs[0] : m_ClassProbs[0];
    double residual = prediction - inst.classValue();
    m_HoldOutError += residual * residual * weight;
  }

  // A leaf has no successors: done.
  if (m_Attribute == -1) {
    return;
  }

  // Route the instance further down the tree.
  if (inst.isMissing(m_Attribute)) {
    // Split value missing: distribute the instance across all branches,
    // weighted by the branch proportions.
    for (int b = 0; b < m_Successors.length; b++) {
      if (m_Prop[b] > 0) {
        m_Successors[b].insertHoldOutInstance(inst, weight * m_Prop[b],
          this);
      }
    }
  } else if (m_Info.attribute(m_Attribute).isNominal()) {
    // Nominal split: follow the branch indexed by the attribute value.
    m_Successors[(int)inst.value(m_Attribute)]
      .insertHoldOutInstance(inst, weight, this);
  } else if (inst.value(m_Attribute) < m_SplitPoint) {
    // Numeric split: left branch below the split point...
    m_Successors[0].insertHoldOutInstance(inst, weight, this);
  } else {
    // ...right branch otherwise.
    m_Successors[1].insertHoldOutInstance(inst, weight, this);
  }
}
/**
 * Backfits every instance of the hold-out set into the tree, starting
 * at this node.
 *
 * @param data the hold-out instances
 * @throws Exception if backfitting an instance fails
 */
protected void backfitHoldOutSet(Instances data) throws Exception {
  final int numInsts = data.numInstances();
  for (int idx = 0; idx < numInsts; idx++) {
    Instance inst = data.instance(idx);
    backfitHoldOutInstance(inst, inst.weight(), this);
  }
}
/**
 * Backfits a single hold-out instance into the tree: the class
 * probability estimate (m_ClassProbs) at every node the instance passes
 * through is updated to incorporate the instance's class value, then the
 * instance is routed down the matching branch(es).
 *
 * @param inst the hold-out instance to incorporate
 * @param weight the (possibly fractional) weight of the instance
 * @param parent the parent node (passed through to recursive calls)
 * @throws Exception if backfitting into a successor fails
 */
protected void backfitHoldOutInstance(Instance inst, double weight,
Tree parent) throws Exception {
// Update this node's class estimate with the instance
if (inst.classAttribute().isNominal()) {
// Nominal case: start from the stored distribution, add the instance's
// weight to its class bin, and renormalize to probabilities
if (m_ClassProbs == null) {
m_ClassProbs = new double[inst.numClasses()];
}
System.arraycopy(m_Distribution, 0, m_ClassProbs, 0, inst.numClasses());
m_ClassProbs[(int)inst.classValue()] += weight;
Utils.normalize(m_ClassProbs);
} else {
// Numeric case: fold the instance into the running weighted mean held
// in m_ClassProbs[0].
// NOTE(review): m_Distribution[1] appears to be the total weight
// backing the current estimate, and it is not incremented here —
// confirm against how the training code maintains m_Distribution.
if (m_ClassProbs == null) {
m_ClassProbs = new double[1];
}
m_ClassProbs[0] *= m_Distribution[1];
m_ClassProbs[0] += weight * inst.classValue();
m_ClassProbs[0] /= (m_Distribution[1] + weight);
}
// The process is recursive
if (m_Attribute != -1) {
// If node is not a leaf
if (inst.isMissing(m_Attribute)) {
// Split value missing: distribute the instance across all branches,
// weighted by the branch proportions m_Prop
for (int i = 0; i < m_Successors.length; i++) {
if (m_Prop[i] > 0) {
m_Successors[i].backfitHoldOutInstance(inst, weight *
m_Prop[i], this);
}
}
} else {
if (m_Info.attribute(m_Attribute).isNominal()) {
// Nominal split: follow the branch indexed by the attribute value
m_Successors[(int)inst.value(m_Attribute)].
backfitHoldOutInstance(inst, weight, this);
} else {
// Numeric split: left branch below the split point, right otherwise
if (inst.value(m_Attribute) < m_SplitPoint) {
m_Successors[0].backfitHoldOutInstance(inst, weight, this);
} else {
m_Successors[1].backfitHoldOutInstance(inst, weight, this);
}
}
}
}
}
}
/** The tree object (root of the grown tree). */
protected Tree m_Tree = null;
/** Number of folds for reduced error pruning. */
protected int m_NumFolds = 3;
/** Seed for random data shuffling. */
protected int m_Seed = 1;
/** If true, reduced-error pruning is not performed. */
protected boolean m_NoPruning = false;
/** The minimum total weight of the instances required in a leaf. */
protected double m_MinNum = 2;
/** The minimum proportion of the total variance (over all the data)
required for split. */
protected double m_MinVarianceProp = 1e-3;
/** Upper bound on the tree depth (-1 presumably means unrestricted — confirm). */
protected int m_MaxDepth = -1;
/**
 * Returns the tip text for this property.
 *
 * @return tip text suitable for displaying in the
 * explorer/experimenter gui
 */
public String noPruningTipText() {
  return "Whether pruning is performed.";
}

/**
 * Gets whether pruning is switched off.
 *
 * @return true if no pruning is performed
 */
public boolean getNoPruning() {
  return m_NoPruning;
}

/**
 * Sets whether pruning is switched off.
 *
 * @param newNoPruning true to disable pruning
 */
public void setNoPruning(boolean newNoPruning) {
  m_NoPruning = newNoPruning;
}
/**
 * Returns the tip text for this property.
 *
 * @return tip text suitable for displaying in the
 * explorer/experimenter gui
 */
public String minNumTipText() {
  return "The minimum total weight of the instances in a leaf.";
}

/**
 * Gets the minimum total instance weight required in a leaf.
 *
 * @return the current value of MinNum
 */
public double getMinNum() {
  return m_MinNum;
}

/**
 * Sets the minimum total instance weight required in a leaf.
 *
 * @param newMinNum the new value for MinNum
 */
public void setMinNum(double newMinNum) {
  m_MinNum = newMinNum;
}
/**
 * Returns the tip text for this property.
 *
 * @return tip text suitable for displaying in the
 * explorer/experimenter gui
 */
public String minVariancePropTipText() {
  return "The minimum proportion of the variance on all the data " +
    "that needs to be present at a node in order for splitting to " +
    "be performed in regression trees.";
}

/**
 * Gets the minimum proportion of the total variance required for a split.
 *
 * @return the current value of MinVarianceProp
 */
public double getMinVarianceProp() {
  return m_MinVarianceProp;
}

/**
 * Sets the minimum proportion of the total variance required for a split.
 *
 * @param newMinVarianceProp the new value for MinVarianceProp
 */
public void setMinVarianceProp(double newMinVarianceProp) {
  m_MinVarianceProp = newMinVarianceProp;
}
/**
 * Returns the tip text for this property.
 *
 * @return tip text suitable for displaying in the
 * explorer/experimenter gui
 */
public String seedTipText() {
  return "The seed used for randomizing the data.";
}

/**
 * Gets the seed used for randomizing the data.
 *
 * @return the current value of Seed
 */
public int getSeed() {
  return m_Seed;
}
/**
* Set the value of Seed.
*
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -