📄 bftree.java
字号:
// If gini gain is less than that of last node in FastVector if (gGain<((Double)(lastNode.elementAt(3))).doubleValue()) { BestFirstElements.insertElementAt(splitInfo, vectorSize); } else { for (int j=0; j<vectorSize; j++) { FastVector node = (FastVector)BestFirstElements.elementAt(j); double nodeGain = ((Double)(node.elementAt(3))).doubleValue(); if (gGain>=nodeGain) { BestFirstElements.insertElementAt(splitInfo, j); break; } } } } } } /** * Compute sorted indices, weights and class probabilities for a given * dataset. Return total weights of the data at the node. * * @param data training data * @param sortedIndices sorted indices of instances at the node * @param weights weights of instances at the node * @param classProbs class probabilities at the node * @return total weights of instances at the node * @throws Exception if something goes wrong */ protected double computeSortedInfo(Instances data, int[][] sortedIndices, double[][] weights, double[] classProbs) throws Exception { // Create array of sorted indices and weights double[] vals = new double[data.numInstances()]; for (int j = 0; j < data.numAttributes(); j++) { if (j==data.classIndex()) continue; weights[j] = new double[data.numInstances()]; if (data.attribute(j).isNominal()) { // Handling nominal attributes. Putting indices of // instances with missing values at the end. sortedIndices[j] = new int[data.numInstances()]; int count = 0; for (int i = 0; i < data.numInstances(); i++) { Instance inst = data.instance(i); if (!inst.isMissing(j)) { sortedIndices[j][count] = i; weights[j][count] = inst.weight(); count++; } } for (int i = 0; i < data.numInstances(); i++) { Instance inst = data.instance(i); if (inst.isMissing(j)) { sortedIndices[j][count] = i; weights[j][count] = inst.weight(); count++; } } } else { // Sorted indices are computed for numeric attributes // missing values instances are put to end (through Utils.sort() method) for (int i = 0; i < data.numInstances(); i++) { Instance inst = data.instance(i); vals[i] = inst.value(j); } sortedIndices[j] = Utils.sort(vals); for (int i = 0; i < data.numInstances(); i++) { weights[j][i] = data.instance(sortedIndices[j][i]).weight(); } } } // Compute initial class counts and total weight double totalWeight = 0; for (int i = 0; i < data.numInstances(); i++) { Instance inst = data.instance(i); classProbs[(int)inst.classValue()] += inst.weight(); totalWeight += inst.weight(); } return totalWeight; } /** * Compute the best splitting attribute, split point or subset and the best * gini gain or iformation gain for a given dataset. * * @param node node to be split * @param data training data * @param sortedIndices sorted indices of the instances * @param weights weights of the instances * @param dists class distributions for each attribute * @param props proportions of two branches * @param totalSubsetWeights total weight of two subsets * @param useHeuristic if use heuristic search for nominal attributes * in multi-class problem * @param useGini if use Gini index as splitting criterion * @return split information about the node * @throws Exception if something is wrong */ protected FastVector computeSplitInfo(BFTree node, Instances data, int[][] sortedIndices, double[][] weights, double[][][] dists, double[][] props, double[][] totalSubsetWeights, boolean useHeuristic, boolean useGini) throws Exception { double[] splits = new double[data.numAttributes()]; String[] splitString = new String[data.numAttributes()]; double[] gains = new double[data.numAttributes()]; for (int i = 0; i < data.numAttributes(); i++) { if (i==data.classIndex()) continue; Attribute att = data.attribute(i); if (att.isNumeric()) { // numeric attribute splits[i] = numericDistribution(props, dists, att, sortedIndices[i], weights[i], totalSubsetWeights, gains, data, useGini); } else { // nominal attribute splitString[i] = nominalDistribution(props, dists, att, sortedIndices[i], weights[i], totalSubsetWeights, gains, data, useHeuristic, useGini); } } int index = Utils.maxIndex(gains); double mBestGain = gains[index]; Attribute att = data.attribute(index); double mValue =Double.NaN; String mString = null; if (att.isNumeric()) mValue= splits[index]; else { mString = splitString[index]; if (mString==null) mString = ""; } // split information FastVector splitInfo = new FastVector(); splitInfo.addElement(node); splitInfo.addElement(att); if (att.isNumeric()) splitInfo.addElement(new Double(mValue)); else splitInfo.addElement(mString); splitInfo.addElement(new Double(mBestGain)); return splitInfo; } /** * Compute distributions, proportions and total weights of two successor nodes for * a given numeric attribute. * * @param props proportions of each two branches for each attribute * @param dists class distributions of two branches for each attribute * @param att numeric att split on * @param sortedIndices sorted indices of instances for the attirubte * @param weights weights of instances for the attirbute * @param subsetWeights total weight of two branches split based on the attribute * @param gains Gini gains or information gains for each attribute * @param data training instances * @param useGini if use Gini index as splitting criterion * @return Gini gain or information gain for the given attribute * @throws Exception if something goes wrong */ protected double numericDistribution(double[][] props, double[][][] dists, Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights, double[] gains, Instances data, boolean useGini) throws Exception { double splitPoint = Double.NaN; double[][] dist = null; int numClasses = data.numClasses(); int i; // differ instances with or without missing values double[][] currDist = new double[2][numClasses]; dist = new double[2][numClasses]; // Move all instances without missing values into second subset double[] parentDist = new double[numClasses]; int missingStart = 0; for (int j = 0; j < sortedIndices.length; j++) { Instance inst = data.instance(sortedIndices[j]); if (!inst.isMissing(att)) { missingStart ++; currDist[1][(int)inst.classValue()] += weights[j]; } parentDist[(int)inst.classValue()] += weights[j]; } System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length); // Try all possible split points double currSplit = data.instance(sortedIndices[0]).value(att); double currGain; double bestGain = -Double.MAX_VALUE; for (i = 0; i < sortedIndices.length; i++) { Instance inst = data.instance(sortedIndices[i]); if (inst.isMissing(att)) { break; } if (inst.value(att) > currSplit) { double[][] tempDist = new double[2][numClasses]; for (int k=0; k<2; k++) { //tempDist[k] = currDist[k]; System.arraycopy(currDist[k], 0, tempDist[k], 0, tempDist[k].length); } double[] tempProps = new double[2]; for (int k=0; k<2; k++) { tempProps[k] = Utils.sum(tempDist[k]); } if (Utils.sum(tempProps) !=0) Utils.normalize(tempProps); // split missing values int index = missingStart; while (index < sortedIndices.length) { Instance insta = data.instance(sortedIndices[index]); for (int j = 0; j < 2; j++) { tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index]; } index++; } if (useGini) currGain = computeGiniGain(parentDist,tempDist); else currGain = computeInfoGain(parentDist,tempDist); if (currGain > bestGain) { bestGain = currGain; // clean split point splitPoint = Math.rint((inst.value(att) + currSplit)/2.0*100000)/100000.0; for (int j = 0; j < currDist.length; j++) { System.arraycopy(tempDist[j], 0, dist[j], 0, dist[j].length); } } } currSplit = inst.value(att); currDist[0][(int)inst.classValue()] += weights[i]; currDist[1][(int)inst.classValue()] -= weights[i]; } // Compute weights int attIndex = att.index(); props[attIndex] = new double[2]; for (int k = 0; k < 2; k++) { props[attIndex][k] = Utils.sum(dist[k]); } if (Utils.sum(props[attIndex]) != 0) Utils.normalize(props[attIndex]); // Compute subset weights subsetWeights[attIndex] = new double[2]; for (int j = 0; j < 2; j++) { subsetWeights[attIndex][j] += Utils.sum(dist[j]); } // clean gain gains[attIndex] = Math.rint(bestGain*10000000)/10000000.0; dists[attIndex] = dist; return splitPoint; } /** * Compute distributions, proportions and total weights of two successor * nodes for a given nominal attribute. * * @param props proportions of each two branches for each attribute * @param dists class distributions of two branches for each attribute * @param att numeric att split on * @param sortedIndices sorted indices of instances for the attirubte * @param weights weights of instances for the attirbute * @param subsetWeights total weight of two branches split based on the attribute * @param gains Gini gains for each attribute * @param data training instances * @param useHeuristic if use heuristic search * @param useGini if use Gini index as splitting criterion * @return Gini gain for the given attribute * @throws Exception if something goes wrong */ protected String nominalDistribution(double[][] props, double[][][] dists, Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights, double[] gains, Instances data, boolean useHeuristic, boolean useGini) throws Exception { String[] values = new String[att.numValues()]; int numCat = values.length; // number of values of the attribute int numClasses = data.numClasses(); String bestSplitString = ""; double bestGain = -Double.MAX_VALUE; // class frequency for each value int[] classFreq = new int[numCat]; for (int j=0; j<numCat; j++) classFreq[j] = 0; double[] parentDist = new double[numClasses]; double[][] currDist = new double[2][numClasses]; double[][] dist = new double[2][numClasses]; int missingStart = 0; for (int i = 0; i < sortedIndices.length; i++) { Instance inst = data.instance(sortedIndices[i]); if (!inst.isMissing(att)) { missingStart++; classFreq[(int)inst.value(att)] ++; } parentDist[(int)inst.classValue()] += weights[i]; } // count the number of values that class frequency is not 0 int nonEmpty = 0; for (int j=0; j<numCat; j++) { if (classFreq[j]!=0) nonEmpty ++; } // attribute values which class frequency is not 0 String[] nonEmptyValues = new String[nonEmpty]; int nonEmptyIndex = 0; for (int j=0; j<numCat; j++) { if (classFreq[j]!=0) { nonEmptyValues[nonEmptyIndex] = att.value(j); nonEmptyIndex ++; } } // attribute values which class frequency is 0 int empty = numCat - nonEmpty; String[] emptyValues = new String[empty]; int emptyIndex = 0; for (int j=0; j<numCat; j++) { if (classFreq[j]==0) { emptyValues[emptyIndex] = att.value(j); emptyIndex ++; } } if (nonEmpty<=1) { gains[att.index()] = 0; return ""; } // for tow-class probloms if (data.numClasses()==2) { //// Firstly, for attribute values which class frequency is not zero // probability of class 0 for each attribute value double[] pClass0 = new double[nonEmpty]; // class distribution for each attribute value double[][] valDist = new double[nonEmpty][2]; for (int j=0; j<nonEmpty; j++) { for (int k=0; k<2; k++) { valDist[j][k] = 0; } } for (int i = 0; i < sortedIndices.length; i++) { Instance inst = data.instance(sortedIndices[i]); if (inst.isMissing(att)) { break; } for (int j=0; j<nonEmpty; j++) { if (att.value((int)inst.value(att)).compareTo(nonEmptyValues[j])==0) { valDist[j][(int)inst.classValue()] += inst.weight(); break; } }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -