📄 reptree.java
字号:
text.append(m_Info.attribute(m_Attribute).name() + " >= " +
Utils.doubleToString(m_SplitPoint, 2));
text.append(m_Successors[1].toString(level + 1, this));
}
return text.toString();
} catch (Exception e) {
e.printStackTrace();
return "Decision tree: tree can't be printed";
}
}
/**
 * Recursively generates a tree.
 *
 * <p>This node becomes a leaf in any of these cases:
 * <ul>
 *   <li>it receives no training instances;</li>
 *   <li>its total weight is below {@code 2 * minNum};</li>
 *   <li>it is pure (nominal class);</li>
 *   <li>{@code priorVar / totalWeight} is below {@code minVariance} (numeric class);</li>
 *   <li>the depth limit is reached;</li>
 *   <li>no attribute gives a split with positive merit and at least two subsets of weight
 *       {@code >= minNum}.</li>
 * </ul>
 * Otherwise the best attribute is selected and one successor {@code Tree} is built per subset.
 *
 * @param sortedIndices per-attribute instance indices for this node (presumably sorted by
 *        attribute value with missing values last -- TODO confirm against caller)
 * @param weights per-attribute instance weights, parallel to {@code sortedIndices}
 * @param data the training data
 * @param totalWeight total weight of the instances reaching this node
 * @param classProbs class distribution reaching this node; copied into {@code m_ClassProbs}
 * @param header dataset header, stored in {@code m_Info}
 * @param minNum minimum weight a subset needs for the split to be accepted
 * @param minVariance threshold for the numeric-class stopping criterion
 * @param depth depth of this node
 * @param maxDepth maximum tree depth; a negative value means unlimited
 * @throws Exception propagated from the split-evaluation helpers
 */
protected void buildTree(int[][] sortedIndices, double[][] weights,
                         Instances data, double totalWeight,
                         double[] classProbs, Instances header,
                         double minNum, double minVariance,
                         int depth, int maxDepth)
  throws Exception {

  // Store structure of dataset and make space for potential info
  // from pruning data
  m_Info = header;
  m_HoldOutDist = new double[data.numClasses()];

  // Enumerate the instances at this node through the index/weight arrays
  // of a non-class attribute (attribute 0, or 1 if 0 is the class)
  int helpIndex = (data.classIndex() == 0) ? 1 : 0;

  // Make leaf if there are no training instances
  if (sortedIndices[helpIndex].length == 0) {
    // Mark the leaf explicitly, as the other leaf paths below do, instead
    // of relying on the field's initial value
    m_Attribute = -1;
    if (data.classAttribute().isNumeric()) {
      m_Distribution = new double[2];
    } else {
      m_Distribution = new double[data.numClasses()];
    }
    m_ClassProbs = null;
    return;
  }

  double priorVar = 0;
  if (data.classAttribute().isNumeric()) {
    // Compute prior variance from weighted sums of the class value
    double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0;
    for (int i = 0; i < sortedIndices[helpIndex].length; i++) {
      Instance inst = data.instance(sortedIndices[helpIndex][i]);
      double w = weights[helpIndex][i];
      totalSum += inst.classValue() * w;
      totalSumSquared += inst.classValue() * inst.classValue() * w;
      totalSumOfWeights += w;
    }
    priorVar = singleVariance(totalSum, totalSumSquared, totalSumOfWeights);
  }

  m_ClassProbs = new double[classProbs.length];
  System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);

  // Make leaf if node doesn't contain enough instances, is pure,
  // or the maximum tree depth is reached
  if ((totalWeight < (2 * minNum))
      // Nominal case: all weight in a single class
      || (data.classAttribute().isNominal()
          && Utils.eq(m_ClassProbs[Utils.maxIndex(m_ClassProbs)],
                      Utils.sum(m_ClassProbs)))
      // Numeric case: variance below threshold
      || (data.classAttribute().isNumeric()
          && ((priorVar / totalWeight) < minVariance))
      // Depth limit (negative maxDepth = unlimited); checked against the
      // limit passed in so the method honours its own parameter
      || ((maxDepth >= 0) && (depth >= maxDepth))) {
    m_Attribute = -1;
    storeNodeDistribution(data, priorVar, totalWeight);
    return;
  }

  // Compute class distributions and value of splitting criterion
  // for each attribute
  int numAtts = data.numAttributes();
  double[] vals = new double[numAtts];
  double[][][] dists = new double[numAtts][0][0];
  double[][] props = new double[numAtts][0];
  double[][] totalSubsetWeights = new double[numAtts][0];
  double[] splits = new double[numAtts];
  boolean nominalClass = data.classAttribute().isNominal();
  for (int i = 0; i < numAtts; i++) {
    if (i == data.classIndex()) {
      continue;
    }
    if (nominalClass) {
      splits[i] = distribution(props, dists, i, sortedIndices[i],
                               weights[i], totalSubsetWeights, data);
      vals[i] = gain(dists[i], priorVal(dists[i]));
    } else {
      // numericDistribution fills in vals[i] itself
      splits[i] = numericDistribution(props, dists, i, sortedIndices[i],
                                      weights[i], totalSubsetWeights, data,
                                      vals);
    }
  }

  // Find best attribute
  m_Attribute = Utils.maxIndex(vals);
  int numAttVals = dists[m_Attribute].length;

  // Check if there are at least two subsets with the required minimum
  // weight (counting stops once two are found)
  int count = 0;
  for (int i = 0; i < numAttVals && count < 2; i++) {
    if (totalSubsetWeights[m_Attribute][i] >= minNum) {
      count++;
    }
  }

  // Any useful split found?
  if ((vals[m_Attribute] > 0) && (count > 1)) {
    // Build subtrees
    m_SplitPoint = splits[m_Attribute];
    m_Prop = props[m_Attribute];
    int[][][] subsetIndices = new int[numAttVals][numAtts][0];
    double[][][] subsetWeights = new double[numAttVals][numAtts][0];
    splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitPoint,
              sortedIndices, weights, data);
    m_Successors = new Tree[numAttVals];
    for (int i = 0; i < numAttVals; i++) {
      m_Successors[i] = new Tree();
      m_Successors[i].buildTree(subsetIndices[i], subsetWeights[i],
                                data, totalSubsetWeights[m_Attribute][i],
                                dists[m_Attribute][i], header, minNum,
                                minVariance, depth + 1, maxDepth);
    }
  } else {
    // Make leaf
    m_Attribute = -1;
  }

  storeNodeDistribution(data, priorVar, totalWeight);
}

/**
 * Stores this node's distribution.
 *
 * <p>For a nominal class, {@code m_Distribution} receives a copy of the (unnormalized)
 * {@code m_ClassProbs}, which is then normalized in place. For any other class type,
 * {@code m_Distribution} becomes {@code {priorVar, totalWeight}}.
 *
 * @param data the training data (only its class attribute type is read)
 * @param priorVar prior variance of a numeric class at this node
 * @param totalWeight total weight of the instances at this node
 */
private void storeNodeDistribution(Instances data, double priorVar,
                                   double totalWeight) {
  if (data.classAttribute().isNominal()) {
    m_Distribution = new double[m_ClassProbs.length];
    System.arraycopy(m_ClassProbs, 0, m_Distribution, 0, m_ClassProbs.length);
    Utils.normalize(m_ClassProbs);
  } else {
    m_Distribution = new double[2];
    m_Distribution[0] = priorVar;
    m_Distribution[1] = totalWeight;
  }
}
/**
 * Counts the nodes in the subtree rooted at this node, this node included.
 *
 * @return 1 for a leaf; otherwise 1 plus the node counts of all successors
 */
protected int numNodes() {
  int total = 1;
  if (m_Attribute != -1) {
    for (Tree child : m_Successors) {
      total += child.numNodes();
    }
  }
  return total;
}
/**
 * Splits the instances at this node into one subset per branch of the
 * split on attribute {@code att}.
 *
 * <p>Subsets are chosen as follows:
 * <ul>
 *   <li>Nominal split attribute: one subset per attribute value.</li>
 *   <li>Numeric split attribute: two subsets. Values below {@code splitPoint} go to subset 0;
 *       all others go to subset 1.</li>
 *   <li>Missing value for {@code att}: the instance is copied into every subset {@code k} with
 *       {@code m_Prop[k] > 0}, with its weight scaled by {@code m_Prop[k]}.</li>
 * </ul>
 * The order of each attribute's index array is preserved within every subset.
 *
 * @param subsetIndices output, indexed [subset][attribute]; filled with instance indices
 * @param subsetWeights output, indexed [subset][attribute]; filled with instance weights
 * @param att index of the split attribute
 * @param splitPoint threshold used when {@code att} is numeric
 * @param sortedIndices per-attribute instance indices at this node
 * @param weights per-attribute instance weights, parallel to {@code sortedIndices}
 * @param data the training data
 * @throws Exception propagated from data access, if any
 */
protected void splitData(int[][][] subsetIndices,
                         double[][][] subsetWeights,
                         int att, double splitPoint,
                         int[][] sortedIndices, double[][] weights,
                         Instances data) throws Exception {

  // The split attribute and the number of subsets do not depend on which
  // attribute's index arrays are being split, so compute them once
  boolean nominalSplit = data.attribute(att).isNominal();
  int numSubsets = nominalSplit ? data.attribute(att).numValues() : 2;

  // For each attribute
  for (int i = 0; i < data.numAttributes(); i++) {
    if (i == data.classIndex()) {
      continue;
    }
    int[] indices = sortedIndices[i];
    double[] instWeights = weights[i];
    int[] num = new int[numSubsets];
    for (int k = 0; k < numSubsets; k++) {
      subsetIndices[k][i] = new int[indices.length];
      subsetWeights[k][i] = new double[indices.length];
    }

    for (int j = 0; j < indices.length; j++) {
      Instance inst = data.instance(indices[j]);
      if (inst.isMissing(att)) {
        // Split instance up across all branches with non-zero proportion
        for (int k = 0; k < numSubsets; k++) {
          if (m_Prop[k] > 0) {
            subsetIndices[k][i][num[k]] = indices[j];
            subsetWeights[k][i][num[k]] = m_Prop[k] * instWeights[j];
            num[k]++;
          }
        }
      } else {
        int subset = nominalSplit
            ? (int) inst.value(att)
            : ((inst.value(att) < splitPoint) ? 0 : 1);
        subsetIndices[subset][i][num[subset]] = indices[j];
        subsetWeights[subset][i][num[subset]] = instWeights[j];
        num[subset]++;
      }
    }

    // Trim arrays to the number of entries actually written
    for (int k = 0; k < numSubsets; k++) {
      int[] copy = new int[num[k]];
      System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);
      subsetIndices[k][i] = copy;
      double[] copyWeights = new double[num[k]];
      System.arraycopy(subsetWeights[k][i], 0, copyWeights, 0, num[k]);
      subsetWeights[k][i] = copyWeights;
    }
  }
}
/**
 * Computes the class distribution for an attribute.
 *
 * <p>For a nominal attribute the distribution has one row per attribute
 * value. For a numeric attribute it has two rows. The split point with the
 * highest {@code gain} is searched among midpoints between consecutive
 * distinct values. Those scans stop at the first instance whose value for
 * {@code att} is missing; the weight of the remaining instances (presumably
 * all missing) is then spread over the rows in proportion to the weight of
 * the instances already counted.
 *
 * <p>Results are stored in {@code props[att]} (branch proportions),
 * {@code subsetWeights[att]} (branch weights) and {@code dists[att]}.
 *
 * @param props output array of branch proportions, indexed by attribute
 * @param dists output array of class distributions, indexed by attribute
 * @param att the attribute to evaluate
 * @param sortedIndices instance indices for {@code att}; must be non-empty
 *        for a numeric attribute (presumably sorted by value with missing
 *        values last -- TODO confirm)
 * @param weights instance weights, parallel to {@code sortedIndices}
 * @param subsetWeights output array of branch weights, indexed by attribute
 * @param data the training data
 * @return the selected split point for a numeric attribute; {@code NaN} for
 *         a nominal attribute or when no split point was found
 * @throws Exception propagated from the gain computation, if any
 */
protected double distribution(double[][] props,
                              double[][][] dists, int att,
                              int[] sortedIndices,
                              double[] weights,
                              double[][] subsetWeights,
                              Instances data)
  throws Exception {

  double splitPoint = Double.NaN;
  Attribute attribute = data.attribute(att);
  double[][] dist = null;
  // After the scans below: index of the first instance with a missing
  // value for att, or sortedIndices.length if there is none
  int i;

  if (attribute.isNominal()) {
    // For nominal attributes: one row per attribute value
    dist = new double[attribute.numValues()][data.numClasses()];
    for (i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (inst.isMissing(att)) {
        break;
      }
      dist[(int) inst.value(att)][(int) inst.classValue()] += weights[i];
    }
  } else {
    // For numeric attributes: two rows (left / right of the split)
    double[][] currDist = new double[2][data.numClasses()];
    dist = new double[2][data.numClasses()];

    // Move all instances into second subset
    for (int j = 0; j < sortedIndices.length; j++) {
      Instance inst = data.instance(sortedIndices[j]);
      if (inst.isMissing(att)) {
        break;
      }
      currDist[1][(int) inst.classValue()] += weights[j];
    }
    double priorValue = priorVal(currDist);
    System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);

    // Try all possible split points between consecutive distinct values
    double currSplit = data.instance(sortedIndices[0]).value(att);
    double bestVal = -Double.MAX_VALUE;
    for (i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (inst.isMissing(att)) {
        break;
      }
      double value = inst.value(att);
      if (value > currSplit) {
        double currVal = gain(currDist, priorValue);
        if (currVal > bestVal) {
          bestVal = currVal;
          splitPoint = (value + currSplit) / 2.0;
          // Guard against floating-point precision problems. For adjacent
          // doubles the midpoint can round down to currSplit, and a huge
          // sum can overflow to infinity. Instances are later sent left
          // when value < splitPoint, so the split point must lie in
          // (currSplit, value] for that partition to match the
          // distribution recorded here.
          if ((splitPoint <= currSplit) || (splitPoint > value)) {
            splitPoint = value;
          }
          for (int j = 0; j < currDist.length; j++) {
            System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length);
          }
        }
      }
      currSplit = value;
      currDist[0][(int) inst.classValue()] += weights[i];
      currDist[1][(int) inst.classValue()] -= weights[i];
    }
  }

  // Compute branch proportions; fall back to uniform proportions when the
  // counted instances carry no weight
  props[att] = new double[dist.length];
  for (int k = 0; k < props[att].length; k++) {
    props[att][k] = Utils.sum(dist[k]);
  }
  if (!(Utils.sum(props[att]) > 0)) {
    for (int k = 0; k < props[att].length; k++) {
      props[att][k] = 1.0 / (double) props[att].length;
    }
  } else {
    Utils.normalize(props[att]);
  }

  // Distribute counts of the remaining instances over all branches
  while (i < sortedIndices.length) {
    Instance inst = data.instance(sortedIndices[i]);
    for (int j = 0; j < dist.length; j++) {
      dist[j][(int) inst.classValue()] += props[att][j] * weights[i];
    }
    i++;
  }

  // Compute subset weights
  subsetWeights[att] = new double[dist.length];
  for (int j = 0; j < dist.length; j++) {
    subsetWeights[att][j] = Utils.sum(dist[j]);
  }

  // Return distribution and split point
  dists[att] = dist;
  return splitPoint;
}
/**
* Computes class distribution for an attribute.
*/
protected double numericDistribution(double[][] props,
double[][][] dists, int att,
int[] sortedIndices,
double[] weights,
double[][] subsetWeights,
Instances data,
double[] vals)
throws Exception {
double splitPoint = Double.NaN;
Attribute attribute = data.attribute(att);
double[][] dist = null;
double[] sums = null;
double[] sumSquared = null;
double[] sumOfWeights = null;
double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0;
int i;
if (attribute.isNominal()) {
// For nominal attributes
sums = new double[attribute.numValues()];
sumSquared = new double[attribute.numValues()];
sumOfWeights = new double[attribute.numValues()];
int attVal;
for (i = 0; i < sortedIndices.length; i++) {
Instance inst = data.instance(sortedIndices[i]);
if (inst.isMissing(att)) {
break;
}
attVal = (int)inst.value(att);
sums[attVal] += inst.classValue() * weights[i];
sumSquared[attVal] +=
inst.classValue() * inst.classValue() * weights[i];
sumOfWeights[attVal] += weights[i];
}
totalSum = Utils.sum(sums);
totalSumSquared = Utils.sum(sumSquared);
totalSumOfWeights = Utils.sum(sumOfWeights);
} else {
// For numeric attributes
sums = new double[2];
sumSquared = new double[2];
sumOfWeights = new double[2];
double[] currSums = new double[2];
double[] currSumSquared = new double[2];
double[] currSumOfWeights = new double[2];
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -