/**
 * @(#)ContinuousAttributeDelegate.java      1.5.0 09/01/18
 */

package ml.classifier.dt;

import java.util.BitSet;
import ml.dataset.ContinuousAttribute;
import ml.util.Statistics;

/**
 * A delegate of a continuous attribute, containing some essential processed
 * information of the continuous attribute to speed up the tree building process.
 *
 * @author Ping He
 * @author Xiaohua Xu
 * @see ml.classifier.dt.DiscreteAttributeDelegate
 * @see ml.classifier.dt.DecisionTree#build()
 */
public class ContinuousAttributeDelegate extends AttributeDelegate{
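	// The values of the corresponding attribute, left in ascending order by the constructor;
	// missing values are encoded as Float.NEGATIVE_INFINITY and therefore come first
	private float[] sortedData;

	// id[r] is the original index (in the data array) of the value with rank r
	private int[] id;
	// rank[k] is the rank of the k-th original value, or -1 if that value is missing
	private int[] rank;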
	// The buckets for bucket-sorting in tree construction
	private BitSet buckets;

	// The following two structures are actually shared by all attributes and generated by TreeBuilder
	// The (reordered) indices of the training cases
	private int[] cases;
	// The weight of each training case
	private float[] weight;

	/**
	 * Initialize a delegate for the specified continuous attribute.
	 * <p>
	 * The initialization of a continuous attribute delegate is a preprocessing
	 * step over the values of the continuous attribute.<br>
	 * It parses the values, indirectly sorts them, and records the mapping between
	 * the original data sequence and the sorted data sequence in both directions.
	 * </p>
	 * @param attribute The corresponding continuous attribute
	 */
	public ContinuousAttributeDelegate(ContinuousAttribute attribute) {
		super();

		String[] data = attribute.getData();
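		// Parse the String[] attribute values to float[]; missing values ("?") are encoded
		// as Float.NEGATIVE_INFINITY so that they sort before all known values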
		this.sortedData = new float[data.length];
		for(int i = 0; i < sortedData.length; i ++) {
			if(data[i].equals("?")) {
				sortedData[i] = Float.NEGATIVE_INFINITY;
				setHasMissingData(true);
			}
			else sortedData[i] = Float.parseFloat(data[i]);
		}

		// Indirect-sort the values: id[r] is the original index of the value with rank r,
		// and sortedData is left in ascending order (the loops below rely on this)
		this.id = Statistics.indirectSort(sortedData);
		// rank is the inverse mapping of id for the known values
		this.rank = new int[sortedData.length];
		// Missing values (sorted to the front) all get rank -1
		int knownIndex = 0;
		for(int i = 0; i < sortedData.length && sortedData[i] == Float.NEGATIVE_INFINITY; i ++) {
			knownIndex ++;
			rank[id[i]] = -1;
		}
		// For the rest of the known data, their rank values start from the knownIndex
		// to make sure rank[id[j]] = j, so that bucket sorting can be correctly executed
		for(int j = knownIndex; j < id.length; j ++) {
			rank[id[j]] = j;
		}

		// Prepare for the bucket sorting for each tree node construction
		this.buckets = new BitSet(sortedData.length);
	}
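
	/*
	 * Illustrative sketch (hypothetical helper, not called by the tree builder):
	 * a self-contained version of the rank computation done by the constructor
	 * above. It assumes missing values are already encoded as
	 * Float.NEGATIVE_INFINITY, and it sorts the indices with java.util.Arrays
	 * instead of ml.util.Statistics.indirectSort. It returns rank[k], the sorted
	 * position of value k, or -1 for a missing value. For example,
	 * {3.0f, NEGATIVE_INFINITY, 1.0f} yields {2, -1, 1}.
	 */
	private static int[] sketchRanks(final float[] values) {
		Integer[] order = new Integer[values.length];
		for (int k = 0; k < values.length; k++) {
			order[k] = Integer.valueOf(k);
		}
		// Stable sort of the indices by value; NEGATIVE_INFINITY (missing) sorts first
		java.util.Arrays.sort(order, new java.util.Comparator<Integer>() {
			public int compare(Integer a, Integer b) {
				return Float.compare(values[a.intValue()], values[b.intValue()]);
			}
		});

		int[] ranks = new int[values.length];
		for (int r = 0; r < order.length; r++) {
			int k = order[r].intValue();
			// Missing values get rank -1; known values keep their sorted position r
			ranks[k] = (values[k] == Float.NEGATIVE_INFINITY) ? -1 : r;
		}
		return ranks;
	}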

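	/**
	 * Set the case indices and case weights prepared by the tree builder.
	 * Both arrays are shared by all attribute delegates, so references are stored, not copies.
	 *
	 * @param casesValue  the indices of the training cases
	 * @param weightValue the weight of each training case
	 */
	public void setCasesWeight(int[] casesValue, float[] weightValue) {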
		this.cases = casesValue;
		this.weight = weightValue;
	}

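	/**
	 * Evaluate this attribute as the test attribute for the cases
	 * <code>cases[first]</code> to <code>cases[last-1]</code>, searching for the best binary split value.
	 *
	 * @param first the index of the first case to evaluate (inclusive)
	 * @param last the index of the last case to evaluate (exclusive)
	 * @param classAttributeDelegate the delegate of the class attribute
	 * @return If the attribute is evaluated as an invalid test attribute, then <i>null</i> is returned;<br>
	 *         Otherwise, a 1-by-4 float array with<br>
	 *         &nbsp;&nbsp;&nbsp; the 1<sup>st</sup> element recording the adjusted Gain,<br>
	 *         &nbsp;&nbsp;&nbsp; the 2<sup>nd</sup> element recording the splitInfo and<br>
	 *         &nbsp;&nbsp;&nbsp; the 3<sup>rd</sup> and 4<sup>th</sup> elements recording the two ranks
	 *                            (splitRank and preSplitRank) whose corresponding values average to the best split value.
	 *
	 * @see ml.classifier.dt.GainCalculator
	 */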
	public float[] evaluate(int first, int last, AttributeDelegate classAttributeDelegate) {
		// The total weight of the cases in [first, last)
		float totalWeight = 0.0f;
		// The weight distribution of the cases in [first, last) over the branches:
		// branch 0 holds missing values, branch 1 the left (<= split) side and branch 2 the right (> split) side
		float[] branchDistri = new float[3];

		// The weight distribution of the cases in [first, last)
		// over the classes within each branch of the attribute
		float[][] branchClassDistribution = new float[3][classAttributeDelegate.getBranchCount()];
		// The minimal weight of the known cases
		float MINKNOWNWEIGHT = Parameter.MINWEIGHT;
		double PRECISION = Parameter.PRECISION;

		// Minimal and maximal rank values of the known cases in [first, last)
		int minRank = sortedData.length;
		int maxRank = -1;

		/* Initialize branchDistri and branchClassDistribution,
		 * i.e. put every known case in [first, last) into the right branch (branch 2),
		 * put every missing case into branch 0, and compute their class distributions.
		 *
		 * At the same time, bucket-sort the known cases with the BitSet buckets.
		 */
        for(int i = first ; i < last; i ++)  {
        	totalWeight += weight[cases[i]];
        	// The class attribute has no missing value
    		int classLabel = classAttributeDelegate.getClassBranch(cases[i]);

			// rank < 0 means missing data
    		if(rank[cases[i]] < 0 ) {
    			branchDistri[0] += weight[cases[i]];
                branchClassDistribution[0][classLabel] += weight[cases[i]];
            }
            else {
                branchDistri[2] += weight[cases[i]];
            	branchClassDistribution[2][classLabel] += weight[cases[i]];

	        	// BucketSort the [first last) cases
	        	buckets.set(rank[cases[i]]);
				// Find the minimal and maximal rank values
				if(rank[cases[i]] < minRank) minRank = rank[cases[i]];
				if(rank[cases[i]] > maxRank) maxRank = rank[cases[i]];
            }
    	}

		// Compute the weight of the known cases and its ratio
	    float knownWeight = totalWeight - branchDistri[0];
	    float unknownRatio = branchDistri[0] / totalWeight;

		// If there is too much missing data on this attribute, just try the next attribute
        if(knownWeight < 2 * MINKNOWNWEIGHT) {
        	// If there is any known case, clear the BitSet
			if(minRank <= maxRank) buckets.clear(minRank, maxRank+1);
        	return null;
        }

		// Compute the entropy of the tree node as a Leaf
        float stateEntropy = GainCalculator.computeStateEntropy(branchClassDistribution, knownWeight);

		// Set the minimum weight for each branch of the attribute
        float minBranchWeight = 0.1f * knownWeight / classAttributeDelegate.getBranchCount();
		minBranchWeight  = (minBranchWeight < MINKNOWNWEIGHT) ? MINKNOWNWEIGHT :
		                   (minBranchWeight > 25) ?  25 : minBranchWeight;

		// Ready to record the maximal Gain and its corresponding splitInfo and two ranks for continuous attribute
		float maxGain = Float.NEGATIVE_INFINITY, bestSplitInfo = -1;
		int bestSplitRank = -1, bestPreSplitRank = -1;

		// The number of candidate split values evaluated
		int tries = 0;
		// The rank of the value just below the candidate split (the first rank of a run of equal values)
		int preRank = minRank;
		// The rank of the case being moved from the right branch to the left branch
		int currentRank = preRank;
		int nextRank;

		do{
			// Update the branch distribution and its class distribution
			int caseIndex = id[currentRank];
			float caseWeight = weight[caseIndex];
			int classLabel = classAttributeDelegate.getClassBranch(caseIndex);
	        branchDistri[1] += caseWeight;
			branchDistri[2] -= caseWeight;
			branchClassDistribution[1][classLabel] += caseWeight;
			branchClassDistribution[2][classLabel] -= caseWeight;

	        float currentValue = sortedData[currentRank];
	        nextRank = buckets.nextSetBit(currentRank+1);
	        float nextValue = sortedData[nextRank];

			// Each branch weight must be equal to or greater than minBranchWeight.
			// While the left branch is still too light, skip the candidate split
			if(branchDistri[1] <= minBranchWeight - PRECISION){
				// If the two values are not the same, update preRank
				if(currentValue != nextValue) {
					preRank = nextRank;
				}
			}
			// If the current value is (nearly) equal to the next one, no split can be placed between them
			else if(currentValue > nextValue - PRECISION){
				// do nothing
			}
			// Once the right branch becomes too light, no later split can be valid either
			else if(branchDistri[2] <= minBranchWeight - PRECISION) {
				break;
			}
			else{
				// Begin to evaluate the current split value
				tries ++;
				// Compute Gain for the current branch weight distribution
				float tempGain = GainCalculator.computeGain(stateEntropy, branchDistri, branchClassDistribution, unknownRatio);
				if(tempGain >= maxGain + PRECISION) {
					maxGain = tempGain;
					bestSplitInfo = GainCalculator.computeSplitInfo(branchDistri, totalWeight);
					bestSplitRank = nextRank;
					bestPreSplitRank = preRank;
				}
				preRank = nextRank;
			}
			currentRank = nextRank;
		}
		while(nextRank < maxRank);

		// Clear the sort result so the shared BitSet can be reused
		buckets.clear(minRank,maxRank+1);

		// Penalize the gain for the number of candidate split values tried (an MDL-style correction)
		float threshCost = GainCalculator.log(tries)/totalWeight;
		// The adjusted Gain is the maximal Gain minus this penalty
		float adjustedGain = maxGain - threshCost;

        float[] result = null;
        // If the adjustedGain is still valid, record the related information
        if(adjustedGain > 0) {
        	result = new float[]{adjustedGain, bestSplitInfo, bestSplitRank, bestPreSplitRank};
        }
        // Otherwise the attribute has no positive adjusted Gain, and null is returned
        return result;
	}
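
	/*
	 * Illustrative sketch (hypothetical helper, not called by the tree builder):
	 * the BitSet "bucket sort" used by evaluate(). Each known case has a distinct
	 * rank, so evaluate() can set one bit per known case. Walking the set bits
	 * with nextSetBit() then visits the cases in ascending attribute-value order
	 * without comparing any values. Here subsetRanks plays the role of
	 * rank[cases[i]] for i in [first, last), and n is the number of attribute
	 * values. Missing cases (rank -1) are skipped, as in evaluate().
	 */
	private static int[] sketchBucketOrder(int[] subsetRanks, int n) {
		java.util.BitSet bits = new java.util.BitSet(n);
		int known = 0;
		for (int k = 0; k < subsetRanks.length; k++) {
			if (subsetRanks[k] >= 0) {   // rank -1 marks a missing value
				bits.set(subsetRanks[k]);
				known++;
			}
		}
		// Collect the set bits in ascending order, i.e. the known ranks sorted
		int[] ordered = new int[known];
		int pos = 0;
		for (int r = bits.nextSetBit(0); r >= 0; r = bits.nextSetBit(r + 1)) {
			ordered[pos++] = r;
		}
		return ordered;
	}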

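	/**
	 * Partition the cases in [begin, last) so that the cases whose rank is at most
	 * <code>groupBranch</code> come first, adding their weights to branchDistri.
	 * With <code>groupBranch == -1</code> the missing cases are grouped (weights go to branch 0);
	 * otherwise <code>groupBranch</code> is a cut rank and the left-branch cases are grouped
	 * (weights go to branch 1).
	 *
	 * @return the index of the first case that was not grouped
	 */
	public int groupForward(int begin, int last, int groupBranch, float[] branchDistri) {
		// rank -1 is reserved for missing data
		int branchIndex = (groupBranch == -1) ? 0 : 1;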
		int cutRank = groupBranch;

		int i, j;
		for(i = begin, j = last - 1; i <= j; ) {
			while(i <= j && rank[cases[i]] <= cutRank) {
				branchDistri[branchIndex] += weight[cases[i]];
				i ++;
			}
			while(i <= j && rank[cases[j]] > cutRank) {
				j --;
			}

			if(i <= j) {
				int tmp = cases[i];
				cases[i] = cases[j];
				cases[j] = tmp;

				branchDistri[branchIndex] += weight[cases[i]];
				i ++;
				j --;
			}
		}
		return i;
	}
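
	/*
	 * Illustrative sketch (hypothetical helper, not called by the tree builder):
	 * the two-pointer partition performed by groupForward(), written on plain
	 * arrays and without the weight bookkeeping. Cases in idx[begin, end) whose
	 * rank is <= cutRank are moved to the front. The returned index is the first
	 * case of the other group. Calling it first with cutRank = -1 and then with a
	 * real cut rank on the remaining range produces the layout
	 * [missing | value <= cut | value > cut]. That is one possible usage and an
	 * assumption, not necessarily how the tree builder calls groupForward().
	 */
	private static int sketchPartition(int[] idx, int begin, int end, int[] rankOf, int cutRank) {
		int i = begin, j = end - 1;
		// Invariant: idx[begin, i) are low (rank <= cutRank), idx(j, end) are high
		while (i <= j) {
			if (rankOf[idx[i]] <= cutRank) {
				i++;                      // already on the low side
			} else if (rankOf[idx[j]] > cutRank) {
				j--;                      // already on the high side
			} else {
				int tmp = idx[i];         // idx[i] is high and idx[j] is low: swap them
				idx[i] = idx[j];
				idx[j] = tmp;
				i++;
				j--;
			}
		}
		return i;
	}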

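	/**
	 * Partition the cases in [begin, last) so that the known cases come first
	 * and the cases with missing values (rank -1) come last.
	 *
	 * @return the index of the first case with a missing value
	 */
	public int groupBackward(int begin, int last) {
		int i, j;
		int cutRank = -1;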

		for(i = last-1, j = begin; i >= j; ) {
			while(i >= j && rank[cases[i]] <= cutRank) {
				i --;
			}
			while(i >= j && rank[cases[j]] > cutRank) {
				j ++;
			}

			if(i >= j) {
				int tmp = cases[i];
				cases[i] = cases[j];
				cases[j] = tmp;

				i --;
				j ++;
			}
		}
		return i + 1;
	}

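	/**
	 * Find the rank of the cut value in the test attribute: the largest rank in
	 * [preSplitRank, splitRank] whose value does not exceed the average of the
	 * values at preSplitRank and splitRank.
	 */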
	public int findCutRank(int splitRank, int preSplitRank) {
		float localThreshold = (sortedData[preSplitRank] + sortedData[splitRank])/2;
		int low = preSplitRank, high = splitRank;

		while(low <= high) {
			int mid = (low + high)/2;
			if(sortedData[mid] > localThreshold) high = mid-1;
			else low = mid + 1;
		}
		return low-1;
	}
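
	/*
	 * Illustrative sketch (hypothetical helper, not called by the tree builder):
	 * findCutRank() followed by findCut(), as one static method on a plain
	 * ascending array. The reported threshold is not the midpoint itself. It is
	 * the largest data value that does not exceed the midpoint. For example, with
	 * sorted values {1.0f, 2.0f, 2.0f, 5.0f}, preSplitRank = 1 and splitRank = 3,
	 * the midpoint is 3.5f, the last rank at or below it is 2, and the cut value
	 * is 2.0f.
	 */
	private static float sketchCutValue(float[] sorted, int preSplitRank, int splitRank) {
		float midpoint = (sorted[preSplitRank] + sorted[splitRank]) / 2;
		int low = preSplitRank, high = splitRank;
		// Upper-bound search: afterwards, low is the first rank in the range whose value exceeds midpoint
		while (low <= high) {
			int mid = (low + high) >>> 1;
			if (sorted[mid] > midpoint) {
				high = mid - 1;
			} else {
				low = mid + 1;
			}
		}
		// sorted[preSplitRank] <= midpoint always holds, so low - 1 >= preSplitRank
		return sorted[low - 1];
	}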

	/**
	 * Find the cut value of the test attribute given its rank.
	 * Cases whose values do not exceed this value fall into the left branch.
	 */
	public float findCut(int cutRank) {
		return sortedData[cutRank];
	}

	/**
	 * A continuous attribute always yields a binary split.
	 */
	public int getBranchCount(){
		return 2;
	}

}
