📄 randomtree.java
字号:
*/
// Appends this node and, recursively, its whole subtree to the given buffer
// as GraphViz node and edge statements.
//
// text - buffer that receives the statements
// num  - number of nodes emitted so far; this node is labelled num + 1
// Returns the node counter after this node and all of its descendants
// have been numbered.
public int toGraph(StringBuffer text, int num) throws Exception {
  // buildTree stores null class probabilities for a node without training
  // instances; use that node's (all-zero) distribution instead so the graph
  // can still be produced rather than failing with a NullPointerException.
  double[] probs = m_ClassProbs;
  if (probs == null && m_Distribution != null && m_Distribution.length > 0) {
    probs = m_Distribution[0];
  }
  int maxIndex = Utils.maxIndex(probs);
  String classValue = m_Info.classAttribute().value(maxIndex);
  String nodeId = "N" + Integer.toHexString(hashCode());
  num++;

  // A node without successors cannot be drawn as an inner node, so it is
  // rendered as a leaf even if m_Attribute was never reset to -1.
  if (m_Attribute == -1 || m_Successors == null) {
    text.append(nodeId +
        " [label=\"" + num + ": " + classValue + "\"" +
        "shape=box]\n");
    return num;
  }

  text.append(nodeId +
      " [label=\"" + num + ": " + classValue + "\"]\n");
  for (int i = 0; i < m_Successors.length; i++) {
    // Edge from this node to the i-th child, labelled with the branch test.
    text.append(nodeId
        + "->" +
        "N" + Integer.toHexString(m_Successors[i].hashCode()) +
        " [label=\"" + m_Info.attribute(m_Attribute).name());
    if (m_Info.attribute(m_Attribute).isNumeric()) {
      // Numeric splits are binary: child 0 is "< split", the other ">= split".
      if (i == 0) {
        text.append(" < " +
            Utils.doubleToString(m_SplitPoint, 2));
      } else {
        text.append(" >= " +
            Utils.doubleToString(m_SplitPoint, 2));
      }
    } else {
      text.append(" = " + m_Info.attribute(m_Attribute).value(i));
    }
    text.append("\"]\n");
    // The child continues the numbering where this node left off.
    num = m_Successors[i].toGraph(text, num);
  }
  return num;
}
/**
 * Outputs the decision tree.
 *
 * @return a textual description of the tree followed by its size, or a
 *         notice if no model has been built yet
 */
public String toString() {
  // m_Info is assigned at the very start of buildTree, so it is the reliable
  // "has been built" marker. m_Successors alone is not: a tree whose root is
  // a leaf never gets successors and used to be reported as "not built".
  if (m_Info == null && m_Successors == null) {
    return "RandomTree: no model has been built yet.";
  }
  // A root without successors is a single leaf.
  int size = (m_Successors == null) ? 1 : numNodes();
  return
      "\nRandomTree\n==========\n" + toString(0) + "\n" +
      "\nSize of the tree : " + size;
}
/**
 * Outputs a leaf.
 *
 * The text has the form " : CLASS (TOTAL/REST)", where CLASS is the class
 * value with the largest entry in the leaf's distribution, TOTAL is the sum
 * of that distribution and REST is TOTAL minus the largest entry; both
 * numbers are formatted with two decimals.
 *
 * @return the textual description of the leaf
 * @throws Exception declared for callers; propagated from the helpers used
 */
protected String leafString() throws Exception {
  double[] counts = m_Distribution[0];
  int majority = Utils.maxIndex(counts);
  String label = m_Info.classAttribute().value(majority);
  double total = Utils.sum(counts);
  double rest = total - counts[majority];

  StringBuffer out = new StringBuffer();
  out.append(" : ").append(label);
  out.append(" (").append(Utils.doubleToString(total, 2));
  out.append("/").append(Utils.doubleToString(rest, 2));
  out.append(")");
  return out.toString();
}
/**
 * Recursively outputs the tree.
 *
 * Each branch is printed on its own line, prefixed by one "| " per nesting
 * level, followed by the text of the corresponding subtree.
 *
 * @param level the depth of this node, used for the line prefix
 * @return the text for this subtree, or a fixed error message if printing
 *         failed (the stack trace is printed in that case)
 */
protected String toString(int level) {
  try {
    // Leaves are described by leafString() alone.
    if (m_Attribute == -1) {
      return leafString();
    }

    // Line break plus the indentation that starts every branch line.
    StringBuffer prefix = new StringBuffer("\n");
    for (int depth = 0; depth < level; depth++) {
      prefix.append("| ");
    }

    boolean nominal = m_Info.attribute(m_Attribute).isNominal();
    // Nominal splits have one branch per successor, numeric splits two.
    int branches = nominal ? m_Successors.length : 2;

    StringBuffer text = new StringBuffer();
    for (int b = 0; b < branches; b++) {
      text.append(prefix.toString());
      text.append(m_Info.attribute(m_Attribute).name());
      if (nominal) {
        text.append(" = ");
        text.append(m_Info.attribute(m_Attribute).value(b));
      } else {
        text.append(b == 0 ? " < " : " >= ");
        text.append(Utils.doubleToString(m_SplitPoint, 2));
      }
      text.append(m_Successors[b].toString(level + 1));
    }
    return text.toString();
  } catch (Exception e) {
    e.printStackTrace();
    return "RandomTree: tree can't be printed";
  }
}
/**
 * Recursively generates a tree.
 *
 * The node becomes a leaf if it received no instances, if the summed class
 * weights are smaller than twice the minimum number, if all weight belongs
 * to one class, or if none of the randomly examined attributes yields a
 * positive gain. Otherwise the attribute with the largest gain is used to
 * split and a subtree is built for every branch.
 *
 * @param sortedIndices for every attribute, the indices of the instances
 *        that reach this node (NOTE(review): presumably ordered by that
 *        attribute's value with missing values last - confirm against the
 *        caller that creates them)
 * @param weights for every attribute, the instance weights aligned with
 *        sortedIndices
 * @param data the training instances the indices refer to
 * @param classProbs the class weights of the instances at this node; copied,
 *        the argument itself is not modified
 * @param header the dataset structure stored in the node
 * @param minNum the minimum number of instances value stored in the node
 * @param debug the debug flag stored in the node
 * @param attIndicesWindow the candidate attribute indices; the array is
 *        permuted in place while attributes are drawn
 * @param random the random number generator used to draw attributes
 * @throws Exception if building the tree fails
 */
protected void buildTree(int[][] sortedIndices, double[][] weights,
    Instances data, double[] classProbs,
    Instances header, double minNum, boolean debug,
    int[] attIndicesWindow, Random random)
    throws Exception {

  // Store structure of dataset, set minimum number of instances
  m_Info = header;
  m_Debug = debug;
  m_MinNum = minNum;

  // Make leaf if there are no training instances. The class attribute's slot
  // is never filled by splitData (it skips the class index), so probing slot
  // 0 would make every child look empty when the class is attribute 0; probe
  // a non-class attribute instead.
  int probe = (data.classIndex() == 0) ? 1 : 0;
  if (probe >= sortedIndices.length || sortedIndices[probe].length == 0) {
    // Mark the node as a leaf explicitly so that printing, counting and
    // graph output agree even if this object has been built before.
    m_Attribute = -1;
    m_Distribution = new double[1][data.numClasses()];
    m_ClassProbs = null;
    return;
  }

  // Check if node doesn't contain enough instances or is pure
  m_ClassProbs = new double[classProbs.length];
  System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
  if (Utils.sm(Utils.sum(m_ClassProbs), 2 * m_MinNum) ||
      Utils.eq(m_ClassProbs[Utils.maxIndex(m_ClassProbs)],
               Utils.sum(m_ClassProbs))) {

    // Make leaf: keep the raw class weights as distribution, then normalize
    m_Attribute = -1;
    m_Distribution = new double[1][m_ClassProbs.length];
    System.arraycopy(m_ClassProbs, 0, m_Distribution[0], 0,
                     m_ClassProbs.length);
    Utils.normalize(m_ClassProbs);
    return;
  }

  // Compute class distributions and value of splitting
  // criterion for each attribute
  int numAtts = data.numAttributes();
  double[] vals = new double[numAtts];
  double[][][] dists = new double[numAtts][0][0];
  double[][] props = new double[numAtts][0];
  double[] splits = new double[numAtts];

  // Investigate K random attributes; keep drawing beyond K while no
  // attribute with a positive gain has been found and candidates remain
  int windowSize = attIndicesWindow.length;
  int k = m_KValue;
  boolean gainFound = false;
  while ((windowSize > 0) && (k-- > 0 || !gainFound)) {
    int chosenIndex = random.nextInt(windowSize);
    int attIndex = attIndicesWindow[chosenIndex];

    // shift chosen attIndex out of window
    attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize - 1];
    attIndicesWindow[windowSize - 1] = attIndex;
    windowSize--;

    splits[attIndex] = distribution(props, dists, attIndex,
                                    sortedIndices[attIndex],
                                    weights[attIndex], data);
    vals[attIndex] = gain(dists[attIndex], priorVal(dists[attIndex]));
    if (Utils.gr(vals[attIndex], 0)) {
      gainFound = true;
    }
  }

  // Find best attribute
  m_Attribute = Utils.maxIndex(vals);
  m_Distribution = dists[m_Attribute];

  // Any useful split found?
  if (Utils.gr(vals[m_Attribute], 0)) {

    // Build subtrees
    m_SplitPoint = splits[m_Attribute];
    m_Prop = props[m_Attribute];
    int[][][] subsetIndices =
        new int[m_Distribution.length][numAtts][0];
    double[][][] subsetWeights =
        new double[m_Distribution.length][numAtts][0];
    splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitPoint,
              sortedIndices, weights, m_Distribution, data);
    m_Successors = new RandomTree[m_Distribution.length];
    for (int i = 0; i < m_Distribution.length; i++) {
      m_Successors[i] = new RandomTree();
      m_Successors[i].setKValue(m_KValue);
      m_Successors[i].buildTree(subsetIndices[i], subsetWeights[i], data,
                                m_Distribution[i], header, m_MinNum, m_Debug,
                                attIndicesWindow, random);
    }
  } else {

    // Make leaf
    m_Attribute = -1;
    m_Distribution = new double[1][m_ClassProbs.length];
    System.arraycopy(m_ClassProbs, 0, m_Distribution[0], 0,
                     m_ClassProbs.length);
  }

  // Normalize class counts
  Utils.normalize(m_ClassProbs);
}
/**
 * Computes size of the tree.
 *
 * @return the number of nodes in the subtree rooted at this node, counting
 *         this node itself
 */
public int numNodes() {
  // This node always counts; a leaf contributes nothing else.
  int total = 1;
  if (m_Attribute != -1) {
    for (int child = 0; child < m_Successors.length; child++) {
      total += m_Successors[child].numNodes();
    }
  }
  return total;
}
/**
 * Splits instances into subsets.
 *
 * For every attribute except the class attribute, the instance indices and
 * weights of that attribute are distributed over the branches of the split
 * on attribute att, keeping their relative order. An instance whose value
 * for att is missing is added to every branch with a positive proportion in
 * m_Prop, with its weight multiplied by that proportion.
 *
 * @param subsetIndices receives, per branch and attribute, the instance
 *        indices that go down that branch
 * @param subsetWeights receives, per branch and attribute, the matching
 *        weights
 * @param att the index of the attribute to split on
 * @param splitPoint the split point used if att is not nominal: values
 *        smaller than it go to branch 0, all others to branch 1
 * @param sortedIndices for every attribute, the instance indices at this node
 * @param weights for every attribute, the weights aligned with sortedIndices
 * @param dist not used by this method
 * @param data the instances the indices refer to
 * @throws Exception if something goes wrong
 */
protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights,
    int att, double splitPoint,
    int[][] sortedIndices, double[][] weights,
    double[][] dist, Instances data) throws Exception {

  for (int a = 0; a < data.numAttributes(); a++) {
    if (a == data.classIndex()) {
      // The class attribute carries no index/weight arrays of its own.
      continue;
    }

    boolean nominalSplit = data.attribute(att).isNominal();
    // One branch per value for nominal splits, two for numeric ones.
    int branches = nominalSplit ? data.attribute(att).numValues() : 2;
    int weightCapacity =
        nominalSplit ? sortedIndices[a].length : weights[a].length;

    // Allocate worst-case sized buffers; filled[b] counts the used entries.
    int[] filled = new int[branches];
    for (int b = 0; b < branches; b++) {
      subsetIndices[b][a] = new int[sortedIndices[a].length];
      subsetWeights[b][a] = new double[weightCapacity];
    }

    for (int pos = 0; pos < sortedIndices[a].length; pos++) {
      Instance inst = data.instance(sortedIndices[a][pos]);

      if (inst.isMissing(att)) {
        // Split instance up among all branches with positive proportion
        for (int b = 0; b < branches; b++) {
          if (Utils.gr(m_Prop[b], 0)) {
            subsetIndices[b][a][filled[b]] = sortedIndices[a][pos];
            subsetWeights[b][a][filled[b]] = m_Prop[b] * weights[a][pos];
            filled[b]++;
          }
        }
        continue;
      }

      int target;
      if (nominalSplit) {
        target = (int) inst.value(att);
      } else {
        target = Utils.sm(inst.value(att), splitPoint) ? 0 : 1;
      }
      subsetIndices[target][a][filled[target]] = sortedIndices[a][pos];
      subsetWeights[target][a][filled[target]] = weights[a][pos];
      filled[target]++;
    }

    // Trim arrays to the number of entries actually used
    for (int b = 0; b < branches; b++) {
      int[] trimmedIndices = new int[filled[b]];
      System.arraycopy(subsetIndices[b][a], 0, trimmedIndices, 0, filled[b]);
      subsetIndices[b][a] = trimmedIndices;
      double[] trimmedWeights = new double[filled[b]];
      System.arraycopy(subsetWeights[b][a], 0, trimmedWeights, 0, filled[b]);
      subsetWeights[b][a] = trimmedWeights;
    }
  }
}
/**
 * Computes class distribution for an attribute.
 *
 * For a nominal attribute the result has one row per attribute value. For
 * any other attribute it has two rows (below / not below the split point)
 * and the split point with the largest gain among the midpoints between
 * consecutive distinct values is chosen. Scanning stops at the first
 * instance with a missing value for the attribute; the remaining instances
 * are then spread over all rows in proportion to the row weights.
 *
 * @param props receives at index att the normalized branch proportions
 * @param dists receives at index att the resulting distribution
 * @param att the index of the attribute to evaluate
 * @param sortedIndices the instance indices for this attribute
 *        (NOTE(review): the scan presumably relies on ascending order with
 *        missing values last - confirm against the caller)
 * @param weights the instance weights aligned with sortedIndices
 * @param data the instances the indices refer to
 * @return the chosen split point, or NaN if none was chosen (always NaN for
 *         nominal attributes)
 * @throws Exception if something goes wrong
 */
protected double distribution(double[][] props, double[][][] dists, int att,
    int[] sortedIndices,
    double[] weights, Instances data)
    throws Exception {

  Attribute attribute = data.attribute(att);
  double splitPoint = Double.NaN;
  double[][] dist;
  // After the scan below: position of the first instance with a missing
  // value, or sortedIndices.length if there is none.
  int pos;

  if (attribute.isNominal()) {

    // For nominal attributes: tally the weights per value and class
    dist = new double[attribute.numValues()][data.numClasses()];
    for (pos = 0; pos < sortedIndices.length; pos++) {
      Instance inst = data.instance(sortedIndices[pos]);
      if (inst.isMissing(att)) {
        break;
      }
      dist[(int) inst.value(att)][(int) inst.classValue()] += weights[pos];
    }
  } else {

    // For numeric attributes
    double[][] running = new double[2][data.numClasses()];
    dist = new double[2][data.numClasses()];

    // Move all instances into second subset
    for (int j = 0; j < sortedIndices.length; j++) {
      Instance inst = data.instance(sortedIndices[j]);
      if (inst.isMissing(att)) {
        break;
      }
      running[1][(int) inst.classValue()] += weights[j];
    }
    double prior = priorVal(running);
    for (int row = 0; row < running.length; row++) {
      System.arraycopy(running[row], 0, dist[row], 0, dist[row].length);
    }

    // Try all possible split points, shifting one instance at a time from
    // the second subset into the first
    double previous = data.instance(sortedIndices[0]).value(att);
    double bestGain = -Double.MAX_VALUE;
    for (pos = 0; pos < sortedIndices.length; pos++) {
      Instance inst = data.instance(sortedIndices[pos]);
      if (inst.isMissing(att)) {
        break;
      }
      double value = inst.value(att);
      if (Utils.gr(value, previous)) {
        // A boundary between two distinct values: evaluate the split
        double candidate = gain(running, prior);
        if (Utils.gr(candidate, bestGain)) {
          bestGain = candidate;
          splitPoint = (value + previous) / 2.0;
          for (int row = 0; row < running.length; row++) {
            System.arraycopy(running[row], 0, dist[row], 0,
                             dist[row].length);
          }
        }
      }
      previous = value;
      int cls = (int) inst.classValue();
      running[0][cls] += weights[pos];
      running[1][cls] -= weights[pos];
    }
  }

  // Compute weights: row sums, normalized (uniform if everything is zero)
  double[] branchProps = new double[dist.length];
  props[att] = branchProps;
  for (int row = 0; row < branchProps.length; row++) {
    branchProps[row] = Utils.sum(dist[row]);
  }
  if (Utils.eq(Utils.sum(branchProps), 0)) {
    for (int row = 0; row < branchProps.length; row++) {
      branchProps[row] = 1.0 / (double) branchProps.length;
    }
  } else {
    Utils.normalize(branchProps);
  }

  // Distribute counts of the instances with missing values, if any
  for (; pos < sortedIndices.length; pos++) {
    Instance inst = data.instance(sortedIndices[pos]);
    for (int row = 0; row < dist.length; row++) {
      dist[row][(int) inst.classValue()] += branchProps[row] * weights[pos];
    }
  }

  // Return distribution and split point
  dists[att] = dist;
  return splitPoint;
}
/**
 * Computes value of splitting criterion before split.
 *
 * @param dist the class distribution, one row per branch
 * @return the value returned by ContingencyTables.entropyOverColumns for
 *         the given distribution
 */
protected double priorVal(double[][] dist) {
  double entropyBeforeSplit = ContingencyTables.entropyOverColumns(dist);
  return entropyBeforeSplit;
}
/**
 * Computes value of splitting criterion after split.
 *
 * @param dist the class distribution, one row per branch
 * @param priorVal the value of the criterion before the split
 * @return priorVal minus the value returned by
 *         ContingencyTables.entropyConditionedOnRows for the distribution
 */
protected double gain(double[][] dist, double priorVal) {
  double entropyAfterSplit = ContingencyTables.entropyConditionedOnRows(dist);
  return priorVal - entropyAfterSplit;
}
/**
 * Main method for this class.
 *
 * Evaluates a new RandomTree with the given command-line options and prints
 * the result to standard output. On failure the stack trace and the
 * exception message are written to standard error.
 *
 * @param argv the command-line options passed on to the evaluation
 */
public static void main(String[] argv) {
  RandomTree classifier = new RandomTree();
  try {
    System.out.println(Evaluation.evaluateModel(classifier, argv));
  } catch (Exception failure) {
    failure.printStackTrace();
    System.err.println(failure.getMessage());
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -