/**
* @(#)ContinuousAttributeDelegate.java 1.5.0 09/01/18
*/
package ml.classifier.dt;
import java.util.BitSet;
import ml.dataset.ContinuousAttribute;
import ml.util.Statistics;
/**
* A delegate of a continuous attribute, containing some essential processed
* information of the continuous attribute to speed up the tree building process.
*
* @author Ping He
* @author Xiaohua Xu
* @see ml.classifier.dt.DiscreteAttributeDelegate
* @see ml.classifier.dt.DecisionTree#build()
*/
public class ContinuousAttributeDelegate extends AttributeDelegate{
// The sorted attribute values of the corresponding attribute
private float[] sortedData;
// The original ids of the sorted continuous values in the data array
private int[] id;
// The ranks of the continuous values in the original sequence
private int[] rank;
// The buckets for bucket-sorting in tree construction
private BitSet buckets;
// The following two structures are actually shared by all attributes and generated by TreeBuilder
// The sequence of the train data
private int[] cases;
// The weight of each train data
private float[] weight;
/**
* Initialize a delegate for the specified continuous attribute.
* <p>
* The initialization of a continuous attribute delegate is a preprocessing
* of the attribute values on the continuous attribute.<br>
* It mainly indirect-sorts the attribute values and extracts the mapping relationship
* between the original data sequence and the sorted data sequence.
* </p>
* @param attribute The corresponding continuous attribute
*/
public ContinuousAttributeDelegate(ContinuousAttribute attribute) {
super();
String[] data = attribute.getData();
// Parse the String[] type attribute values to float[]
this.sortedData = new float[data.length];
for(int i = 0; i < sortedData.length; i ++) {
if(data[i].equals("?")) {
sortedData[i] = Float.NEGATIVE_INFINITY;
setHasMissingData(true);
}
else sortedData[i] = Float.parseFloat(data[i]);
}
// id[j] records the original index of the value ranked j after sorting
this.id = Statistics.indirectSort(sortedData);
// rank records the rank of each value in the original sequence
this.rank = new int[sortedData.length];
// If there are missing data, their ranks are all set to -1
int knownIndex = 0;
for(int i = 0; i < sortedData.length && sortedData[i] == Float.NEGATIVE_INFINITY; i ++) {
knownIndex ++;
rank[id[i]] = -1;
}
// For the rest of the data (the known values), the rank values start from knownIndex
// so that rank[id[j]] = j and bucket sorting can be executed correctly
for(int j = knownIndex; j < id.length; j ++) {
rank[id[j]] = j;
}
// Prepare for the bucket sorting for each tree node construction
this.buckets = new BitSet(sortedData.length);
}
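/**
* Bind the shared case sequence and case weights generated by the tree builder.
*
* @param casesValue The sequence of the train data
* @param weightValue The weight of each train data
*/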
public void setCasesWeight(int[] casesValue, float[] weightValue) {
this.cases = casesValue;
this.weight = weightValue;
}
/**
* Evaluate the attribute as a candidate test attribute on the cases in [first, last).
*
* @param first The begin index (inclusive) of the evaluated cases in the cases array
* @param last The end index (exclusive) of the evaluated cases in the cases array
* @param classAttributeDelegate The delegate of the class attribute
* @return If the attribute is evaluated as an invalid test attribute, then <i>null</i> is returned;<br>
* otherwise, a 1-by-4 float array with<br>
* the 1<sup>st</sup> element recording the (threshold-adjusted) Gain,<br>
* the 2<sup>nd</sup> element recording the splitInfo and<br>
* the 3<sup>rd</sup> and 4<sup>th</sup> elements recording the two ranks whose
* corresponding attribute values average to the best split value.
*
* @see ml.classifier.dt.GainCalculator
*/
public float[] evaluate(int first, int last, AttributeDelegate classAttributeDelegate) {
// This variable records the total weight of the [first, last) cases
float totalWeight = 0.0f;
// This variable records the weight distribution of the [first, last)
// cases in the different branches of the current attribute
float[] branchDistri = new float[3];
// This variable records the weight distribution of the [first, last)
// cases in the different classes of the different branches of the attribute
float[][] branchClassDistribution = new float[3][classAttributeDelegate.getBranchCount()];
// The required minimum weight of the known cases
float MINKNOWNWEIGHT = Parameter.MINWEIGHT;
double PRECISION = Parameter.PRECISION;
// The minimal and maximal rank values of the cases from first to last (exclusive)
int minRank = sortedData.length;
int maxRank = -1;
/* Initialize branchDistri and its branchClassDistribution,
* i.e. distribute all the [first, last) cases into the right branch
* and compute its class distribution.
*
* At the same time, bucket-sort the cases with the BitSet buckets
*/
for(int i = first ; i < last; i ++) {
totalWeight += weight[cases[i]];
// The class attribute has no missing value
int classLabel = classAttributeDelegate.getClassBranch(cases[i]);
// rank < 0 means missing data
if(rank[cases[i]] < 0 ) {
branchDistri[0] += weight[cases[i]];
branchClassDistribution[0][classLabel] += weight[cases[i]];
}
else {
branchDistri[2] += weight[cases[i]];
branchClassDistribution[2][classLabel] += weight[cases[i]];
// Bucket-sort the [first, last) cases
buckets.set(rank[cases[i]]);
// Find the minimal and maximal rank values
if(rank[cases[i]] < minRank) minRank = rank[cases[i]];
if(rank[cases[i]] > maxRank) maxRank = rank[cases[i]];
}
}
// Compute the total weight of the known cases and the ratio of missing-value weight
float knownWeight = totalWeight - branchDistri[0];
float unknownRatio = branchDistri[0] / totalWeight;
// If there is too much missing data on this attribute, just try the next attribute
if(knownWeight < 2 * MINKNOWNWEIGHT) {
// If there is any known case, clear the BitSet
if(minRank <= maxRank) buckets.clear(minRank, maxRank+1);
return null;
}
// Compute the entropy of the tree node as a Leaf
float stateEntropy = GainCalculator.computeStateEntropy(branchClassDistribution, knownWeight);
// Set the minimum weight for each branch of the attribute
float minBranchWeight = 0.1f * knownWeight / classAttributeDelegate.getBranchCount();
minBranchWeight = (minBranchWeight < MINKNOWNWEIGHT) ? MINKNOWNWEIGHT :
(minBranchWeight > 25) ? 25 : minBranchWeight;
// Record the maximal Gain and its corresponding splitInfo and two ranks for the continuous attribute
float maxGain = Float.NEGATIVE_INFINITY, bestSplitInfo = -1;
int bestSplitRank = -1, bestPreSplitRank = -1;
// The number of tries for finding the best split value
int tries = 0;
// The previous rank and its corresponding attribute value for the computation of split value
int preRank = minRank;
// The rank value for the update of weight distribution
int currentRank = preRank;
int nextRank;
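// Scan the candidate cut points in ascending rank order: in each iteration the case
// with rank currentRank is moved from the right branch (index 2) to the left branch
// (index 1), and the split between currentValue and nextValue is evaluated once
// both branches carry enough weight.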
do{
// Update the branch distribution and its class distribution
int caseIndex = id[currentRank];
float caseWeight = weight[caseIndex];
int classLabel = classAttributeDelegate.getClassBranch(caseIndex);
branchDistri[1] += caseWeight;
branchDistri[2] -= caseWeight;
branchClassDistribution[1][classLabel] += caseWeight;
branchClassDistribution[2][classLabel] -= caseWeight;
float currentValue = sortedData[currentRank];
nextRank = buckets.nextSetBit(currentRank+1);
float nextValue = sortedData[nextRank];
// Each branch weight must be greater than or equal to minBranchWeight
// If the left branch is still too light, skip the first few values
if(branchDistri[1] <= minBranchWeight - PRECISION){
// If the two values are not the same, update preRank
if(currentValue != nextValue) {
preRank = nextRank;
}
}
// If the current value is (almost) equal to the next value, no split is possible here; keep preRank unchanged
else if(currentValue > nextValue - PRECISION){
// do nothing
}
// If the right branch would become too light, no later split can satisfy the minimum either; stop
else if(branchDistri[2] <= minBranchWeight - PRECISION) {
break;
}
else{
// Begin to evaluate the current split value
tries ++;
// Compute Gain for the current branch weight distribution
float tempGain = GainCalculator.computeGain(stateEntropy, branchDistri, branchClassDistribution, unknownRatio);
if(tempGain >= maxGain + PRECISION) {
maxGain = tempGain;
bestSplitInfo = GainCalculator.computeSplitInfo(branchDistri, totalWeight);
bestSplitRank = nextRank;
bestPreSplitRank = preRank;
}
preRank = nextRank;
}
currentRank = nextRank;
}
while(nextRank < maxRank);
// Clear the sort result
buckets.clear(minRank,maxRank+1);
// Compute the threshold cost of the split according to information theory
float threshCost = GainCalculator.log(tries)/totalWeight;
// The adjusted Gain should be the maximal Gain minus the threshold
float adjustedGain = maxGain - threshCost;
float[] result = null;
// If the adjustedGain is still valid, record the related information
if(adjustedGain > 0) {
result = new float[]{adjustedGain, bestSplitInfo, bestSplitRank, bestPreSplitRank};
}
// If the adjustedGain is invalid, this attribute has no Gain
return result;
}
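/**
* Group the cases in [begin, last) so that the cases belonging to the grouped
* branch (the cases whose ranks are not greater than groupBranch, which for
* groupBranch == -1 are exactly the missing-value cases) are moved to the front,
* accumulating their weights into branchDistri[0] for the missing-value branch
* or branchDistri[1] for the left branch.
*
* @return The begin index of the remaining cases
*/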
public int groupForward(int begin, int last, int groupBranch, float[] branchDistri) {
// rank -1 is kept for missing data
int branchIndex = (groupBranch == -1) ? 0 : 1;
int cutRank = groupBranch;
int i, j;
for(i = begin, j = last - 1; i <= j; ) {
while(i <= j && rank[cases[i]] <= cutRank) {
branchDistri[branchIndex] += weight[cases[i]];
i ++;
}
while(i <= j && rank[cases[j]] > cutRank) {
j --;
}
if(i <= j) {
int tmp = cases[i];
cases[i] = cases[j];
cases[j] = tmp;
branchDistri[branchIndex] += weight[cases[i]];
i ++;
j --;
}
}
return i;
}
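/**
* Regroup the cases in [begin, last) so that the cases with known values come
* first and the missing-value cases (rank -1) are moved to the end.
*
* @return The begin index of the missing-value group
*/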
public int groupBackward(int begin, int last) {
int i, j;
int index = 0;
int cutRank = -1;
for(i = last-1, j = begin; i >= j; ) {
while(i >= j && rank[cases[i]] <= cutRank) {
i --;
}
while(i >= j && rank[cases[j]] > cutRank) {
j ++;
}
if(i >= j) {
int tmp = cases[i];
cases[i] = cases[j];
cases[j] = tmp;
i --;
j ++;
}
}
return i + 1;
}
/**
* Find the rank of the cut value in the test attribute.
* The cut value is the average of the attribute values at preSplitRank and
* splitRank; a binary search locates the largest rank whose value does not
* exceed it.
*/
public int findCutRank(int splitRank, int preSplitRank) {
float localThreshold = (sortedData[preSplitRank] + sortedData[splitRank])/2;
int low = preSplitRank, high = splitRank;
while(low <= high) {
int mid = (low + high)/2;
if(sortedData[mid] > localThreshold) high = mid-1;
else if(sortedData[mid] <= localThreshold) low = mid + 1;
}
return low-1;
}
/**
* Find the cut value of the test attribute when provided with its rank.
*/
public float findCut(int cutRank) {
return sortedData[cutRank];
}
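/**
* @return The number of branches of the continuous test attribute, which is
* always 2 (values not greater than the cut value versus values greater than it)
*/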
public int getBranchCount(){
return 2;
}
}
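/*
* A minimal usage sketch (illustrative only, not part of the original class),
* assuming the attribute, class delegate, cases and weight values are supplied
* by the surrounding tree builder:
*
*   ContinuousAttribute attribute = ...;        // a continuous attribute of the data set
*   AttributeDelegate classDelegate = ...;      // the delegate of the class attribute
*   int[] cases = ...;                          // the sequence of the train data
*   float[] weight = ...;                       // the weight of each train data
*
*   ContinuousAttributeDelegate delegate = new ContinuousAttributeDelegate(attribute);
*   delegate.setCasesWeight(cases, weight);
*   float[] eval = delegate.evaluate(0, cases.length, classDelegate);
*   if (eval != null) {
*       int cutRank = delegate.findCutRank((int) eval[2], (int) eval[3]);
*       float cut = delegate.findCut(cutRank);  // the best split value of this attribute
*   }
*/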