📄 decisionstump.java
字号:
double[] sumCounts = new double[m_Instances.numClasses()];
double[][] bestDist = new double[3][m_Instances.numClasses()];
int numMissing = 0;
// Compute counts for all the values
for (int i = 0; i < m_Instances.numInstances(); i++) {
Instance inst = m_Instances.instance(i);
if (inst.isMissing(index)) {
numMissing++;
counts[m_Instances.attribute(index).numValues()]
[(int)inst.classValue()] += inst.weight();
} else {
counts[(int)inst.value(index)][(int)inst.classValue()] += inst
.weight();
}
}
// Compute sum of counts
for (int i = 0; i < m_Instances.attribute(index).numValues(); i++) {
for (int j = 0; j < m_Instances.numClasses(); j++) {
sumCounts[j] += counts[i][j];
}
}
// Make split counts for each possible split and evaluate
System.arraycopy(counts[m_Instances.attribute(index).numValues()], 0,
m_Distribution[2], 0, m_Instances.numClasses());
for (int i = 0; i < m_Instances.attribute(index).numValues(); i++) {
for (int j = 0; j < m_Instances.numClasses(); j++) {
m_Distribution[0][j] = counts[i][j];
m_Distribution[1][j] = sumCounts[j] - counts[i][j];
}
currVal = ContingencyTables.entropyConditionedOnRows(m_Distribution);
if (currVal < bestVal) {
bestVal = currVal;
m_SplitPoint = (double)i;
for (int j = 0; j < 3; j++) {
System.arraycopy(m_Distribution[j], 0, bestDist[j], 0,
m_Instances.numClasses());
}
}
}
// No missing values in training data.
if (numMissing == 0) {
System.arraycopy(sumCounts, 0, bestDist[2], 0,
m_Instances.numClasses());
}
m_Distribution = bestDist;
return bestVal;
}
/**
* Finds best split for nominal attribute and numeric class
* and returns value.
*
* @param index attribute index
* @return value of criterion for the best split
* @exception Exception if something goes wrong
*/
private double findSplitNominalNumeric(int index) throws Exception {
  // Slot layout shared by m_Distribution, sqSums and weightSums:
  //   0 = instances whose attribute equals the candidate value,
  //   1 = all other instances with a known attribute value,
  //   2 = instances whose attribute value is missing.
  // NOTE(review): slot 2 of m_Distribution is only ever incremented here, so
  // this presumably relies on the caller handing in a zeroed array - confirm.
  final int valueCount = m_Instances.attribute(index).numValues();
  final int instanceCount = m_Instances.numInstances();

  double lowestCriterion = Double.MAX_VALUE;
  double[] valueSqSums = new double[valueCount];
  double[] valueSums = new double[valueCount];
  double[] valueWeights = new double[valueCount];
  double[] sqSums = new double[3];
  double[] weightSums = new double[3];
  double[][] winner = new double[3][1];
  double allWeight = 0;
  double allSum = 0;

  // Pass 1: accumulate weighted class-value statistics per attribute value,
  // keeping the missing-value instances apart in slot 2.
  for (int n = 0; n < instanceCount; n++) {
    Instance current = m_Instances.instance(n);
    double target = current.classValue();
    double w = current.weight();
    if (current.isMissing(index)) {
      m_Distribution[2][0] += target * w;
      sqSums[2] += target * target * w;
      weightSums[2] += w;
    } else {
      int v = (int) current.value(index);
      valueWeights[v] += w;
      valueSums[v] += target * w;
      valueSqSums[v] += target * target * w;
    }
    allWeight += w;
    allSum += target * w;
  }

  // Nothing to split on when the data carries no weight at all.
  if (allWeight <= 0) {
    return lowestCriterion;
  }

  // Totals over the instances with a known attribute value.
  double knownWeight = 0;
  double knownSqSum = 0;
  double knownSum = 0;
  for (int v = 0; v < valueCount; v++) {
    knownWeight += valueWeights[v];
    knownSqSum += valueSqSums[v];
    knownSum += valueSums[v];
  }

  // Pass 2: try "attribute == v" versus "attribute != v" for every value v
  // and keep the split with the smallest criterion.
  for (int v = 0; v < valueCount; v++) {
    m_Distribution[0][0] = valueSums[v];
    m_Distribution[1][0] = knownSum - valueSums[v];
    sqSums[0] = valueSqSums[v];
    sqSums[1] = knownSqSum - valueSqSums[v];
    weightSums[0] = valueWeights[v];
    weightSums[1] = knownWeight - valueWeights[v];

    double candidate = variance(m_Distribution, sqSums, weightSums);
    if (candidate < lowestCriterion) {
      lowestCriterion = candidate;
      m_SplitPoint = v;
      // Store the weighted mean class value per subset; a subset without
      // positive weight falls back to the mean over all instances.
      for (int slot = 0; slot < 3; slot++) {
        winner[slot][0] = (weightSums[slot] > 0)
            ? m_Distribution[slot][0] / weightSums[slot]
            : allSum / allWeight;
      }
    }
  }

  m_Distribution = winner;
  return lowestCriterion;
}
/**
* Finds best split for numeric attribute and returns value.
*
* @param index attribute index
* @return value of criterion for the best split
* @exception Exception if something goes wrong
*/
private double findSplitNumeric(int index) throws Exception {
  // The class attribute's type picks the search routine: the nominal-class
  // variant scores splits with ContingencyTables.entropyConditionedOnRows,
  // the numeric-class variant with variance(...).
  boolean nominalClass = m_Instances.classAttribute().isNominal();
  return nominalClass
      ? findSplitNumericNominal(index)
      : findSplitNumericNumeric(index);
}
/**
* Finds best split for numeric attribute and nominal class
* and returns value.
*
* @param index attribute index
* @return value of criterion for the best split
* @exception Exception if something goes wrong
*/
private double findSplitNumericNominal(int index) throws Exception {
  // Rows of m_Distribution: 0 = at or below the cut point, 1 = above it,
  // 2 = attribute value missing. Columns are class values.
  // NOTE(review): rows are only incremented/decremented here, so this
  // presumably expects a zeroed m_Distribution from the caller - confirm.
  final int classCount = m_Instances.numClasses();
  final int instanceCount = m_Instances.numInstances();

  double lowestEntropy = Double.MAX_VALUE;
  int missingCount = 0;
  double[] knownTotals = new double[classCount];
  double[][] winner = new double[3][classCount];

  // Start with every non-missing instance in row 1 and every missing one
  // in row 2.
  for (int n = 0; n < instanceCount; n++) {
    Instance current = m_Instances.instance(n);
    int row;
    if (current.isMissing(index)) {
      row = 2;
      missingCount++;
    } else {
      row = 1;
    }
    m_Distribution[row][(int) current.classValue()] += current.weight();
  }
  System.arraycopy(m_Distribution[1], 0, knownTotals, 0, classCount);

  // The unsplit distribution is the fallback if no cut point qualifies.
  for (int row = 0; row < 3; row++) {
    System.arraycopy(m_Distribution[row], 0, winner[row], 0, classCount);
  }

  m_Instances.sort(index);

  // Sweep the sorted data, moving one instance at a time from row 1 to
  // row 0. NOTE(review): the loop bound assumes sort(index) leaves the
  // missing-valued instances at the end - confirm against Instances.sort.
  final int lastCandidate = instanceCount - (missingCount + 1);
  for (int n = 0; n < lastCandidate; n++) {
    Instance lower = m_Instances.instance(n);
    Instance upper = m_Instances.instance(n + 1);
    int cls = (int) lower.classValue();
    double w = lower.weight();
    m_Distribution[0][cls] += w;
    m_Distribution[1][cls] -= w;

    // A cut is only possible between two distinct attribute values.
    double lowerValue = lower.value(index);
    double upperValue = upper.value(index);
    if (lowerValue < upperValue) {
      double entropy =
          ContingencyTables.entropyConditionedOnRows(m_Distribution);
      if (entropy < lowestEntropy) {
        m_SplitPoint = (lowerValue + upperValue) / 2.0;
        lowestEntropy = entropy;
        for (int row = 0; row < 3; row++) {
          System.arraycopy(m_Distribution[row], 0, winner[row], 0,
                           classCount);
        }
      }
    }
  }

  // Without missing values in the training data, row 2 takes the class
  // totals of all instances instead of staying empty.
  if (missingCount == 0) {
    System.arraycopy(knownTotals, 0, winner[2], 0, classCount);
  }

  m_Distribution = winner;
  return lowestEntropy;
}
/**
* Finds best split for numeric attribute and numeric class
* and returns value.
*
* @param index attribute index
* @return value of criterion for the best split
* @exception Exception if something goes wrong
*/
private double findSplitNumericNumeric(int index) throws Exception {
  // Slot layout shared by m_Distribution, sqSums and weightSums:
  //   0 = at or below the cut point, 1 = above it, 2 = value missing.
  // NOTE(review): slots are only incremented/decremented here, so this
  // presumably expects a zeroed m_Distribution from the caller - confirm.
  final int instanceCount = m_Instances.numInstances();

  double lowestCriterion = Double.MAX_VALUE;
  int missingCount = 0;
  double[] sqSums = new double[3];
  double[] weightSums = new double[3];
  double[][] winner = new double[3][1];
  double allSum = 0;
  double allWeight = 0;

  // Start with every non-missing instance in slot 1 and every missing one
  // in slot 2, tracking weighted sum, weighted sum of squares and weight.
  for (int n = 0; n < instanceCount; n++) {
    Instance current = m_Instances.instance(n);
    double w = current.weight();
    double weighted = current.classValue() * w;
    double weightedSq = current.classValue() * current.classValue() * w;
    int slot;
    if (current.isMissing(index)) {
      slot = 2;
      missingCount++;
    } else {
      slot = 1;
    }
    m_Distribution[slot][0] += weighted;
    sqSums[slot] += weightedSq;
    weightSums[slot] += w;
    allWeight += w;
    allSum += weighted;
  }

  // Nothing to split on when the data carries no weight at all.
  if (allWeight <= 0) {
    return lowestCriterion;
  }

  m_Instances.sort(index);

  // Sweep the sorted data, moving one instance at a time from slot 1 to
  // slot 0. NOTE(review): the loop bound assumes sort(index) leaves the
  // missing-valued instances at the end - confirm against Instances.sort.
  final int lastCandidate = instanceCount - (missingCount + 1);
  for (int n = 0; n < lastCandidate; n++) {
    Instance lower = m_Instances.instance(n);
    Instance upper = m_Instances.instance(n + 1);
    double w = lower.weight();
    double weighted = lower.classValue() * w;
    double weightedSq = lower.classValue() * lower.classValue() * w;
    m_Distribution[0][0] += weighted;
    sqSums[0] += weightedSq;
    weightSums[0] += w;
    m_Distribution[1][0] -= weighted;
    sqSums[1] -= weightedSq;
    weightSums[1] -= w;

    // A cut is only possible between two distinct attribute values.
    double lowerValue = lower.value(index);
    double upperValue = upper.value(index);
    if (lowerValue < upperValue) {
      double candidate = variance(m_Distribution, sqSums, weightSums);
      if (candidate < lowestCriterion) {
        m_SplitPoint = (lowerValue + upperValue) / 2.0;
        lowestCriterion = candidate;
        // Store the weighted mean class value per subset; a subset without
        // positive weight falls back to the mean over all instances.
        for (int slot = 0; slot < 3; slot++) {
          winner[slot][0] = (weightSums[slot] > 0)
              ? m_Distribution[slot][0] / weightSums[slot]
              : allSum / allWeight;
        }
      }
    }
  }

  m_Distribution = winner;
  return lowestCriterion;
}
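/**
* Computes the summed squared deviation from the weighted mean over all
* subsets: for each subset with positive weight, the weighted sum of
* squares minus the squared weighted sum divided by the subset's weight.
* The result is not divided by the weight, and subsets whose weight is
* not positive contribute nothing.
*/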
private double variance(double[][] s,double[] sS,double[] sumOfWeights) {
  // s[k][0] is read as subset k's weighted sum, sS[k] as its weighted sum
  // of squares and sumOfWeights[k] as its total weight.
  double total = 0;
  int subset = 0;
  while (subset < s.length) {
    double weight = sumOfWeights[subset];
    // Subsets without positive weight are skipped, which also avoids a
    // division by zero.
    if (weight > 0) {
      double weightedSum = s[subset][0];
      total += sS[subset] - ((weightedSum * weightedSum) / weight);
    }
    subset++;
  }
  return total;
}
/**
* Returns the subset an instance falls into.
*/
private int whichSubset(Instance instance) throws Exception {
  // Subset 2 is reserved for instances whose split attribute is missing.
  if (instance.isMissing(m_AttIndex)) {
    return 2;
  }

  boolean nominal = instance.attribute(m_AttIndex).isNominal();
  double value = instance.value(m_AttIndex);

  // Nominal attribute: subset 0 holds exactly the value stored as the split
  // point. Numeric attribute: subset 0 holds everything at or below it.
  boolean inFirstSubset = nominal
      ? ((int) value == m_SplitPoint)
      : (value <= m_SplitPoint);
  return inFirstSubset ? 0 : 1;
}
/**
* Main method for testing this class.
*
* @param argv the options
*/
public static void main(String [] argv) {
  try {
    Classifier scheme = new DecisionStump();
    // Print whatever Evaluation.evaluateModel returns for the given options.
    System.out.println(Evaluation.evaluateModel(scheme, argv));
  } catch (Exception e) {
    // Throwable.getMessage() may be null (e.g. a bare NullPointerException);
    // printing it directly would emit the literal "null" and hide the
    // exception type. Fall back to toString(), which always names the class.
    String message = e.getMessage();
    System.err.println(message != null ? message : e.toString());
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -