📄 adtree.java
字号:
m_examplesCounted += posInstances.numInstances() + negInstances.numInstances();
// evaluate static splitters (nominal)
for (int i=0; i<m_nominalAttIndices.length; i++)
evaluateNominalSplitSingle(m_nominalAttIndices[i], currentNode,
posInstances, negInstances);
// evaluate dynamic splitters (numeric)
if (m_numericAttIndices.length > 0) {
// merge the two sets of instances into one
Instances allInstances = new Instances(posInstances);
for (Enumeration e = negInstances.emerateInstances(); e.hasMoreElements(); )
allInstances.add((Instance) e.nextElement());
// use method of finding the optimal Z split-point
for (int i=0; i<m_numericAttIndices.length; i++)
evaluateNumericSplitSingle(m_numericAttIndices[i], currentNode,
posInstances, negInstances, allInstances);
}
if (currentNode.getChildren().size() == 0) return;
// keep searching
switch (m_searchPath) {
case SEARCHPATH_ALL:
goDownAllPathsSingle(currentNode, posInstances, negInstances);
break;
case SEARCHPATH_HEAVIEST:
goDownHeaviestPathSingle(currentNode, posInstances, negInstances);
break;
case SEARCHPATH_ZPURE:
goDownZpurePathSingle(currentNode, posInstances, negInstances);
break;
case SEARCHPATH_RANDOM:
goDownRandomPathSingle(currentNode, posInstances, negInstances);
break;
}
}
/**
 * Continues the single (two-class optimized) search exhaustively: for every
 * splitter attached to currentNode, and for every branch of that splitter, the
 * instances are routed down the branch and the search recurses into the
 * branch's child node.
 *
 * @param currentNode the root of the subtree to be searched
 * @param posInstances the positive-class instances that apply at this node
 * @param negInstances the negative-class instances that apply at this node
 * @exception Exception if search fails
 */
private void goDownAllPathsSingle(PredictionNode currentNode,
                                  Instances posInstances, Instances negInstances)
  throws Exception {

  Enumeration splitters = currentNode.children();
  while (splitters.hasMoreElements()) {
    Splitter splitter = (Splitter) splitters.nextElement();
    int branch = 0;
    while (branch < splitter.getNumOfBranches()) {
      // partition both class sets by this branch, then search below it
      PredictionNode child = splitter.getChildForBranch(branch);
      Instances posSubset = splitter.instancesDownBranch(branch, posInstances);
      Instances negSubset = splitter.instancesDownBranch(branch, negInstances);
      searchForBestTestSingle(child, posSubset, negSubset);
      branch++;
    }
  }
}
/**
 * Continues single (two-class optimized) search by investigating only the path
 * with the most heavily weighted instances.
 *
 * Every branch of every splitter under currentNode is scored by the summed
 * weight of the positive and negative instances that go down it. The search
 * then recurses into the single branch with the largest score. If no branch
 * has a combined weight strictly greater than zero, nothing is searched.
 *
 * @param currentNode the root of the subtree to be searched
 * @param posInstances the positive-class instances that apply at this node
 * @param negInstances the negative-class instances that apply at this node
 * @exception Exception if search fails
 */
private void goDownHeaviestPathSingle(PredictionNode currentNode,
                                      Instances posInstances, Instances negInstances)
  throws Exception {

  Splitter heaviestSplit = null;
  int heaviestBranch = 0;
  double largestWeight = 0.0;
  // the partitions of the current winner are remembered so that they do not
  // have to be recomputed once the scan is over
  Instances heaviestPosSplit = null;
  Instances heaviestNegSplit = null;

  for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
    Splitter split = (Splitter) e.nextElement();
    for (int i = 0; i < split.getNumOfBranches(); i++) {
      Instances posSplit = split.instancesDownBranch(i, posInstances);
      Instances negSplit = split.instancesDownBranch(i, negInstances);
      double weight = posSplit.sumOfWeights() + negSplit.sumOfWeights();
      // strict comparison: the first branch reaching the maximum wins ties
      if (weight > largestWeight) {
        heaviestSplit = split;
        heaviestBranch = i;
        largestWeight = weight;
        heaviestPosSplit = posSplit;
        heaviestNegSplit = negSplit;
      }
    }
  }

  if (heaviestSplit != null) {
    searchForBestTestSingle(heaviestSplit.getChildForBranch(heaviestBranch),
                            heaviestPosSplit, heaviestNegSplit);
  }
}
/**
 * Continues the single (two-class optimized) search down at most one branch:
 * the one whose Z-pure value (see calcZpure) is lowest. A branch only
 * qualifies if its Z-pure value is strictly below the smallest Z recorded by
 * the search so far (m_search_smallestZ); when no branch qualifies the search
 * stops here.
 *
 * @param currentNode the root of the subtree to be searched
 * @param posInstances the positive-class instances that apply at this node
 * @param negInstances the negative-class instances that apply at this node
 * @exception Exception if search fails
 */
private void goDownZpurePathSingle(PredictionNode currentNode,
                                   Instances posInstances, Instances negInstances)
  throws Exception {

  // seeding with the current best Z acts as the z-pure cutoff
  double bestZpure = m_search_smallestZ;
  PredictionNode chosenChild = null;
  Instances chosenPos = null;
  Instances chosenNeg = null;

  Enumeration splitters = currentNode.children();
  while (splitters.hasMoreElements()) {
    Splitter splitter = (Splitter) splitters.nextElement();
    for (int branch = 0; branch < splitter.getNumOfBranches(); branch++) {
      Instances posSubset = splitter.instancesDownBranch(branch, posInstances);
      Instances negSubset = splitter.instancesDownBranch(branch, negInstances);
      double candidate = calcZpure(posSubset, negSubset);
      if (!(candidate < bestZpure)) {
        continue; // not better than what we already have
      }
      bestZpure = candidate;
      chosenChild = splitter.getChildForBranch(branch);
      chosenPos = posSubset;
      chosenNeg = negSubset;
    }
  }

  if (chosenChild == null) {
    return;
  }
  searchForBestTestSingle(chosenChild, chosenPos, chosenNeg);
}
/**
 * Continues the single (two-class optimized) search down one path chosen with
 * getRandom: first a splitter is picked among the children of currentNode,
 * then a branch is picked among that splitter's branches.
 *
 * @param currentNode the root of the subtree to be searched
 * @param posInstances the positive-class instances that apply at this node
 * @param negInstances the negative-class instances that apply at this node
 * @exception Exception if search fails
 */
private void goDownRandomPathSingle(PredictionNode currentNode,
                                    Instances posInstances, Instances negInstances)
  throws Exception {

  FastVector splitters = currentNode.getChildren();
  int splitterIndex = getRandom(splitters.size());
  Splitter chosen = (Splitter) splitters.elementAt(splitterIndex);

  int branch = getRandom(chosen.getNumOfBranches());
  PredictionNode child = chosen.getChildForBranch(branch);
  Instances posSubset = chosen.instancesDownBranch(branch, posInstances);
  Instances negSubset = chosen.instancesDownBranch(branch, negInstances);

  searchForBestTestSingle(child, posSubset, negSubset);
}
/**
 * Considers adding a two-way nominal split on the given attribute beneath
 * currentNode. The lowest-Z split value for the attribute is looked up, and if
 * its Z-value is strictly lower than the best found so far, the search state
 * (smallest Z, insertion node, splitter and the instance sets for this path)
 * is overwritten to record it as the new best option.
 *
 * @param attIndex index of a nominal attribute to create a split from
 * @param currentNode the parent under which a split is to be considered
 * @param posInstances the positive-class instances that apply at this node
 * @param negInstances the negative-class instances that apply at this node
 */
private void evaluateNominalSplitSingle(int attIndex, PredictionNode currentNode,
                                        Instances posInstances, Instances negInstances)
{
  double[] indexAndZ = findLowestZNominalSplit(posInstances, negInstances, attIndex);
  double candidateZ = indexAndZ[1];
  if (!(candidateZ < m_search_smallestZ)) {
    return; // no improvement on the best split recorded so far
  }

  int splitValueIndex = (int) indexAndZ[0];
  m_search_smallestZ = candidateZ;
  m_search_bestInsertionNode = currentNode;
  m_search_bestSplitter = new TwoWayNominalSplit(attIndex, splitValueIndex);
  m_search_bestPathPosInstances = posInstances;
  m_search_bestPathNegInstances = negInstances;
}
/**
 * Considers adding a two-way numeric split on the given attribute beneath
 * currentNode. The lowest-Z split-point is computed from the combined
 * instances, and if its Z-value is strictly lower than the best found so far,
 * the search state (smallest Z, insertion node, splitter and the instance sets
 * for this path) is overwritten to record it as the new best option.
 *
 * @param attIndex index of a numeric attribute to create a split from
 * @param currentNode the parent under which a split is to be considered
 * @param posInstances the positive-class instances that apply at this node
 * @param negInstances the negative-class instances that apply at this node
 * @param allInstances all of the instances that apply at this node (pos+neg combined)
 * @exception Exception if evaluating the numeric split fails
 */
private void evaluateNumericSplitSingle(int attIndex, PredictionNode currentNode,
                                        Instances posInstances, Instances negInstances,
                                        Instances allInstances)
  throws Exception {

  double[] splitAndZ = findLowestZNumericSplit(allInstances, attIndex);
  double candidateZ = splitAndZ[1];
  if (!(candidateZ < m_search_smallestZ)) {
    return; // no improvement on the best split recorded so far
  }

  double splitPoint = splitAndZ[0];
  m_search_smallestZ = candidateZ;
  m_search_bestInsertionNode = currentNode;
  m_search_bestSplitter = new TwoWayNumericSplit(attIndex, splitPoint);
  m_search_bestPathPosInstances = posInstances;
  m_search_bestPathNegInstances = negInstances;
}
/**
 * Calculates the prediction value used for a particular set of instances:
 * half the natural logarithm of (positive weight + 1) over (negative weight + 1).
 *
 * @param posInstances the positive-class instances
 * @param negInstances the negative-class instances
 * @return the prediction value
 */
private double calcPredictionValue(Instances posInstances, Instances negInstances) {
  // the +1.0 on each side keeps the ratio finite when one side has no weight
  double smoothedPos = posInstances.sumOfWeights() + 1.0;
  double smoothedNeg = negInstances.sumOfWeights() + 1.0;
  return 0.5 * Math.log(smoothedPos / smoothedNeg);
}
/**
 * Calculates the Z-pure value for a particular set of instances: twice the sum
 * of sqrt(positive weight + 1) and sqrt(negative weight + 1), plus the part of
 * the total training weight (m_trainTotalWeight) not accounted for by these
 * instances.
 *
 * @param posInstances the positive-class instances
 * @param negInstances the negative-class instances
 * @return the Z-pure value
 */
private double calcZpure(Instances posInstances, Instances negInstances) {
  double posWeight = posInstances.sumOfWeights();
  double negWeight = negInstances.sumOfWeights();

  double rootTerm = 2.0 * (Math.sqrt(posWeight + 1.0) + Math.sqrt(negWeight + 1.0));
  double remainingWeight = m_trainTotalWeight - (posWeight + negWeight);
  return rootTerm + remainingWeight;
}
/**
 * Updates the weights of instances that are influenced by a new prediction value.
 * Each positive instance's weight is multiplied by e^(-predictionValue) and each
 * negative instance's weight by e^(predictionValue). The instances are modified
 * in place through setWeight.
 *
 * @param posInstances positive-class instances to which the prediction value applies
 * @param negInstances negative-class instances to which the prediction value applies
 * @param predictionValue the new prediction value
 */
private void updateWeights(Instances posInstances, Instances negInstances,
                           double predictionValue) {

  // do positives
  double weightMultiplier = Math.pow(Math.E, -predictionValue);
  for (Enumeration e = posInstances.enumerateInstances(); e.hasMoreElements(); ) {
    Instance inst = (Instance) e.nextElement();
    inst.setWeight(inst.weight() * weightMultiplier);
  }

  // do negatives (the multiplier is the reciprocal exponent)
  weightMultiplier = Math.pow(Math.E, predictionValue);
  for (Enumeration e = negInstances.enumerateInstances(); e.hasMoreElements(); ) {
    Instance inst = (Instance) e.nextElement();
    inst.setWeight(inst.weight() * weightMultiplier);
  }
}
/**
 * Finds the nominal attribute value to split on that results in the lowest Z-value.
 * Each attribute value is tried as a two-way split (that value versus all other
 * values) and the first value achieving the lowest Z is kept.
 *
 * @param posInstances the positive-class instances to split
 * @param negInstances the negative-class instances to split
 * @param attIndex the index of the nominal attribute to find a split for
 * @return a double array, index[0] contains the value to split on, index[1] contains
 * the Z-value of the split
 */
private double[] findLowestZNominalSplit(Instances posInstances, Instances negInstances,
                                         int attIndex)
{
  // per-value weight totals for each class, and their overall sums
  double[] posByValue = attributeValueWeights(posInstances, attIndex);
  double[] negByValue = attributeValueWeights(negInstances, attIndex);
  double posTotal = Utils.sum(posByValue);
  double negTotal = Utils.sum(negByValue);

  // with exactly two values both candidate splits are mirror images, so one suffices
  int numCandidates = (posByValue.length == 2) ? 1 : posByValue.length;

  double bestZ = Double.MAX_VALUE;
  int bestValue = 0;
  int value = 0;
  while (value < numCandidates) {
    // weights on the "equals value" side (w1, w2) and the "other values" side (w3, w4)
    double w1 = posByValue[value] + 1.0;
    double w2 = negByValue[value] + 1.0;
    double w3 = posTotal - w1 + 2.0;
    double w4 = negTotal - w2 + 2.0;
    // training weight not covered by either side of the split
    double wRemainder = m_trainTotalWeight + 4.0 - (w1 + w2 + w3 + w4);
    double z = (2.0 * (Math.sqrt(w1 * w2) + Math.sqrt(w3 * w4))) + wRemainder;

    if (z < bestZ) {
      bestZ = z;
      bestValue = value;
    }
    value++;
  }

  return new double[] { (double) bestValue, bestZ };
}
/**
 * Simultaneously sums the weights of all attribute values for all instances.
 * Instances whose value for the attribute is missing contribute to no entry.
 *
 * @param instances the instances to get the weights from
 * @param attIndex index of the attribute to be evaluated
 * @return a double array containing the weight of each attribute value
 */
private double[] attributeValueWeights(Instances instances, int attIndex)
{
  // a new double[] is already zero-filled, so no explicit initialisation is needed
  double[] weights = new double[instances.attribute(attIndex).numValues()];
  for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) {
    Instance inst = (Instance) e.nextElement();
    if (!inst.isMissing(attIndex)) {
      // the attribute value is used directly as the index into the totals
      weights[(int) inst.value(attIndex)] += inst.weight();
    }
  }
  return weights;
}
/**
* Finds the numeric split-point that results in the lowest Z-value.
*
* @param instances the instances to find a split for
* @param attIndex the index of the numeric attribute to find a split for
* @return a double array, index[0] contains the split-point, index[1] contains the
* Z-value of the split
*/
private double[] findLowestZNumericSplit(Instances instances, int attIndex)
throws Exception {
double splitPoint = 0.0;
double bestVal = Double.MAX_VALUE, currVal, currCutPoint;
int numMissing = 0;
double[][] distribution = new double[3][instances.numClasses()];
// compute counts for all the values
for (int i = 0; i < instances.numInstances(); i++) {
Instance inst = instances.instance(i);
if (!inst.isMissing(attIndex)) {
distribution[1][(int)inst.classValue()] += inst.weight();
} else {
distribution[2][(int)inst.classValue()] += inst.weight();
numMissing++;
}
}
// sort instances
instances.sort(attIndex);
// make split counts for each possible split and evaluate
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -