📄 SimpleCart.java
      }

      // calculate the vector of mean class probabilities
      for (int i = 0; i < meanClass.length; i++) {
        meanClass[i] /= numInstances;
      }

      // calculate the covariance matrix of the class-probability vectors
      double[][] covariance = new double[k][k];
      for (int i1 = 0; i1 < k; i1++) {
        for (int i2 = 0; i2 < k; i2++) {
          double element = 0;
          for (int j = 0; j < n; j++) {
            element += (P[j][i2] - meanClass[i2]) * (P[j][i1] - meanClass[i1])
              * numInstancesValue[j];
          }
          covariance[i1][i2] = element;
        }
      }

      Matrix matrix = new Matrix(covariance);
      weka.core.matrix.EigenvalueDecomposition eigen =
        new weka.core.matrix.EigenvalueDecomposition(matrix);
      double[] eigenValues = eigen.getRealEigenvalues();

      // find the index of the largest eigenvalue
      int index = 0;
      double largest = eigenValues[0];
      for (int i = 1; i < eigenValues.length; i++) {
        if (eigenValues[i] > largest) {
          index = i;
          largest = eigenValues[i];
        }
      }

      // calculate the first principal component
      double[] FPC = new double[k];
      Matrix eigenVector = eigen.getV();
      double[][] vectorArray = eigenVector.getArray();
      for (int i = 0; i < FPC.length; i++) {
        FPC[i] = vectorArray[i][index];
      }

      // calculate the first principal component scores
      double[] Sa = new double[n];
      for (int i = 0; i < Sa.length; i++) {
        Sa[i] = 0;
        for (int j = 0; j < k; j++) {
          Sa[i] += FPC[j] * P[i][j];
        }
      }

      // sort the categories according to their scores Sa
      double[] pCopy = new double[n];
      System.arraycopy(Sa, 0, pCopy, 0, n);
      String[] sortedValues = new String[n];
      Arrays.sort(Sa);

      for (int j = 0; j < n; j++) {
        sortedValues[j] = nonEmptyValues[Utils.minIndex(pCopy)];
        pCopy[Utils.minIndex(pCopy)] = Double.MAX_VALUE;
      }

      // evaluate the nonEmpty-1 ordered partitions over the attribute
      // values whose class frequency is not 0
      String tempStr = "";
      for (int j = 0; j < nonEmpty - 1; j++) {
        currDist = new double[2][numClasses];
        if (tempStr.isEmpty()) tempStr = "(" + sortedValues[j] + ")";
        else tempStr += "|" + "(" + sortedValues[j] + ")";

        for (int i = 0; i < sortedIndices.length; i++) {
          Instance inst = data.instance(sortedIndices[i]);
          if (inst.isMissing(att)) {
            break;
          }

          if (tempStr.indexOf(
              "(" + att.value((int)inst.value(att)) + ")") != -1) {
            currDist[0][(int)inst.classValue()] += weights[i];
          } else {
            currDist[1][(int)inst.classValue()] += weights[i];
          }
        }

        double[][] tempDist = new double[2][numClasses];
        for (int kk = 0; kk < 2; kk++) {
          tempDist[kk] = currDist[kk];
        }

        double[] tempProps = new double[2];
        for (int kk = 0; kk < 2; kk++) {
          tempProps[kk] = Utils.sum(tempDist[kk]);
        }

        if (Utils.sum(tempProps) != 0) Utils.normalize(tempProps);

        // distribute the missing-value instances across both branches,
        // weighted by the branch proportions
        int mstart = missingStart;
        while (mstart < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[mstart]);
          for (int jj = 0; jj < 2; jj++) {
            tempDist[jj][(int)insta.classValue()]
              += tempProps[jj] * weights[mstart];
          }
          mstart++;
        }

        double currGiniGain = computeGiniGain(parentDist, tempDist);

        if (currGiniGain > bestGiniGain) {
          bestGiniGain = currGiniGain;
          bestSplitString = tempStr;
          for (int jj = 0; jj < 2; jj++) {
            System.arraycopy(tempDist[jj], 0, dist[jj], 0, dist[jj].length);
          }
        }
      }
    }

    // Compute weights
    int attIndex = att.index();
    props[attIndex] = new double[2];
    for (int k = 0; k < 2; k++) {
      props[attIndex][k] = Utils.sum(dist[k]);
    }
    if (!(Utils.sum(props[attIndex]) > 0)) {
      for (int k = 0; k < props[attIndex].length; k++) {
        props[attIndex][k] = 1.0 / (double)props[attIndex].length;
      }
    } else {
      Utils.normalize(props[attIndex]);
    }

    // Compute subset weights
    subsetWeights[attIndex] = new double[2];
    for (int j = 0; j < 2; j++) {
      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
    }

    // Then, put the attribute values whose class frequency is 0 into the
    // most frequent branch
    for (int j = 0; j < empty; j++) {
      if (props[attIndex][0] >= props[attIndex][1]) {
        if (bestSplitString.isEmpty())
          bestSplitString = "(" + emptyValues[j] + ")";
        else
          bestSplitString += "|" + "(" + emptyValues[j] + ")";
      }
    }

    // round the Gini gain for the attribute to 7 decimal places
    giniGains[attIndex] = Math.rint(bestGiniGain * 10000000) / 10000000.0;

    dists[attIndex] = dist;
    return bestSplitString;
  }
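  // Note on the search above: instead of scoring all 2^(v-1) binary
  // groupings of a v-valued nominal attribute, each category's
  // class-frequency profile is projected onto the first principal component,
  // the categories are sorted by that score, and only the v-1 ordered
  // partitions are evaluated. This is the usual CART-style shortcut for
  // multi-class nominal splits.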
  /**
   * Split data into two subsets and store sorted indices and weights for the
   * two successor nodes.
   *
   * @param subsetIndices sorted indices of instances for each attribute
   *                      for the two successor nodes
   * @param subsetWeights weights of instances for each attribute for the
   *                      two successor nodes
   * @param att           attribute the split is based on
   * @param splitPoint    split point the split is based on if att is numeric
   * @param splitStr      split subset the split is based on if att is nominal
   * @param sortedIndices sorted indices of the instances to be split
   * @param weights       weights of the instances to be split
   * @param data          training data
   * @throws Exception if something goes wrong
   */
  protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights,
      Attribute att, double splitPoint, String splitStr,
      int[][] sortedIndices, double[][] weights, Instances data)
    throws Exception {

    int j;
    // For each attribute
    for (int i = 0; i < data.numAttributes(); i++) {
      if (i == data.classIndex()) continue;
      int[] num = new int[2];
      for (int k = 0; k < 2; k++) {
        subsetIndices[k][i] = new int[sortedIndices[i].length];
        subsetWeights[k][i] = new double[weights[i].length];
      }

      for (j = 0; j < sortedIndices[i].length; j++) {
        Instance inst = data.instance(sortedIndices[i][j]);
        if (inst.isMissing(att)) {
          // Split the instance up between both branches
          for (int k = 0; k < 2; k++) {
            if (m_Props[k] > 0) {
              subsetIndices[k][i][num[k]] = sortedIndices[i][j];
              subsetWeights[k][i][num[k]] = m_Props[k] * weights[i][j];
              num[k]++;
            }
          }
        } else {
          int subset;
          if (att.isNumeric()) {
            subset = (inst.value(att) < splitPoint) ? 0 : 1;
          } else { // nominal attribute
            if (splitStr.indexOf(
                "(" + att.value((int)inst.value(att.index())) + ")") != -1) {
              subset = 0;
            } else {
              subset = 1;
            }
          }
          subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
          subsetWeights[subset][i][num[subset]] = weights[i][j];
          num[subset]++;
        }
      }

      // Trim arrays
      for (int k = 0; k < 2; k++) {
        int[] copy = new int[num[k]];
        System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);
        subsetIndices[k][i] = copy;
        double[] copyWeights = new double[num[k]];
        System.arraycopy(subsetWeights[k][i], 0, copyWeights, 0, num[k]);
        subsetWeights[k][i] = copyWeights;
      }
    }
  }

  /**
   * Updates the numIncorrectModel field for all nodes when a subtree (to be
   * pruned) is rooted. This is needed for calculating the alpha-values.
   *
   * @throws Exception if something goes wrong
   */
  public void modelErrors() throws Exception {
    Evaluation eval = new Evaluation(m_train);

    if (!m_isLeaf) {
      m_isLeaf = true; // temporarily make this node a leaf

      // calculate the distribution for evaluation
      eval.evaluateModel(this, m_train);
      m_numIncorrectModel = eval.incorrect();

      m_isLeaf = false;

      for (int i = 0; i < m_Successors.length; i++)
        m_Successors[i].modelErrors();

    } else {
      eval.evaluateModel(this, m_train);
      m_numIncorrectModel = eval.incorrect();
    }
  }
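  // Pruning bookkeeping: m_numIncorrectModel is the number of training
  // instances a node misclassifies when forced to predict as a leaf (the
  // node error R(t)), while m_numIncorrectTree, filled in by treeErrors()
  // below, is the error of the full subtree R(T_t). calculateAlphas() turns
  // the gap into
  //   alpha = ((R(t) - R(T_t)) / totalTrainInstances) / (numLeaves - 1),
  // the training-error reduction bought per additional leaf.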
  /**
   * Updates the numIncorrectTree field for all nodes. This is needed for
   * calculating the alpha-values.
   *
   * @throws Exception if something goes wrong
   */
  public void treeErrors() throws Exception {
    if (m_isLeaf) {
      m_numIncorrectTree = m_numIncorrectModel;
    } else {
      m_numIncorrectTree = 0;
      for (int i = 0; i < m_Successors.length; i++) {
        m_Successors[i].treeErrors();
        m_numIncorrectTree += m_Successors[i].m_numIncorrectTree;
      }
    }
  }

  /**
   * Updates the alpha field for all nodes.
   *
   * @throws Exception if something goes wrong
   */
  public void calculateAlphas() throws Exception {

    if (!m_isLeaf) {
      double errorDiff = m_numIncorrectModel - m_numIncorrectTree;

      if (errorDiff <= 0) {
        // the split increases the training error (should not normally
        // happen); prune it instantly
        makeLeaf(m_train);
        m_Alpha = Double.MAX_VALUE;
      } else {
        // compute alpha and round it to 10 decimal places
        errorDiff /= m_totalTrainInstances;
        m_Alpha = errorDiff / (double)(numLeaves() - 1);
        long alphaLong = Math.round(m_Alpha * Math.pow(10, 10));
        m_Alpha = (double)alphaLong / Math.pow(10, 10);
        for (int i = 0; i < m_Successors.length; i++) {
          m_Successors[i].calculateAlphas();
        }
      }
    } else {
      // alpha = infinity for leaves (we do not want to prune them)
      m_Alpha = Double.MAX_VALUE;
    }
  }

  /**
   * Find the node with the minimal alpha value. If two nodes have the same
   * alpha, choose the one with more leaf nodes.
   *
   * @param nodeList list of inner nodes
   * @return the node to be pruned
   */
  protected SimpleCart nodeToPrune(Vector nodeList) {
    if (nodeList.size() == 0) return null;
    if (nodeList.size() == 1) return (SimpleCart)nodeList.elementAt(0);
    SimpleCart returnNode = (SimpleCart)nodeList.elementAt(0);
    double baseAlpha = returnNode.m_Alpha;
    for (int i = 1; i < nodeList.size(); i++) {
      SimpleCart node = (SimpleCart)nodeList.elementAt(i);
      if (node.m_Alpha < baseAlpha) {
        baseAlpha = node.m_Alpha;
        returnNode = node;
      } else if (node.m_Alpha == baseAlpha) { // break the tie
        if (node.numLeaves() > returnNode.numLeaves()) {
          returnNode = node;
        }
      }
    }
    return returnNode;
  }
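  // A sketch of how the helpers above combine in one weakest-link pruning
  // step (the actual pruning driver is not part of this excerpt; the list
  // of inner nodes is assumed to be gathered elsewhere):
  //
  //   modelErrors();                        // R(t) for every node
  //   treeErrors();                         // R(T_t) for every subtree
  //   calculateAlphas();                    // alpha for every inner node
  //   SimpleCart weakest = nodeToPrune(innerNodes);
  //   weakest.makeLeaf(weakest.m_train);    // collapse the weakest link
  //
  // Repeating this until only the root remains yields the nested sequence
  // of subtrees from which the final tree is selected.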
  /**
   * Compute sorted indices, weights and class probabilities for a given
   * dataset. Return the total weight of the data at the node.
   *
   * @param data          training data
   * @param sortedIndices sorted indices of instances at the node
   * @param weights       weights of instances at the node
   * @param classProbs    class probabilities at the node
   * @return total weight of instances at the node
   * @throws Exception if something goes wrong
   */
  protected double computeSortedInfo(Instances data, int[][] sortedIndices,
      double[][] weights, double[] classProbs) throws Exception {

    // Create arrays of sorted indices and weights
    double[] vals = new double[data.numInstances()];
    for (int j = 0; j < data.numAttributes(); j++) {
      if (j == data.classIndex()) continue;
      weights[j] = new double[data.numInstances()];

      if (data.attribute(j).isNominal()) {
        // Handling nominal attributes: put the indices of instances with
        // missing values at the end.
        sortedIndices[j] = new int[data.numInstances()];
        int count = 0;
        for (int i = 0; i < data.numInstances(); i++) {
          Instance inst = data.instance(i);
          if (!inst.isMissing(j)) {
            sortedIndices[j][count] = i;
            weights[j][count] = inst.weight();
            count++;
          }
        }
        for (int i = 0; i < data.numInstances(); i++) {
          Instance inst = data.instance(i);
          if (inst.isMissing(j)) {
            sortedIndices[j][count] = i;
            weights[j][count] = inst.weight();
            count++;
          }
        }
      } else {
        // Sorted indices are computed for numeric attributes; instances
        // with missing values are put at the end.
        for (int i = 0; i < data.numInstances(); i++) {
          Instance inst = data.instance(i);
          vals[i] = inst.value(j);
        }
        sortedIndices[j] = Utils.sort(vals);
        for (int i = 0; i < data.numInstances(); i++) {
          weights[j][i] = data.instance(sortedIndices[j][i]).weight();
        }
      }
    }

    // Compute the initial class counts
    double totalWeight = 0;
    for (int i = 0; i < data.numInstances(); i++) {
      Instance inst = data.instance(i);
      classProbs[(int)inst.classValue()] += inst.weight();
      totalWeight += inst.weight();
    }

    return totalWeight;
  }

  /**
   * Compute and return the Gini gain for the given distributions of a node
   * and its successor nodes.
   *
   * @param parentDist class distribution of the parent node
   * @param childDist  class distributions of the successor nodes
   * @return the computed Gini gain
   */
  protected double computeGiniGain(double[] parentDist, double[][] childDist) {
    double totalWeight = Utils.sum(parentDist);
    if (totalWeight == 0) return 0;

    double leftWeight = Utils.sum(childDist[0]);
    double rightWeight = Utils.sum(childDist[1]);

    double parentGini = computeGini(parentDist, totalWeight);
    double leftGini = computeGini(childDist[0], leftWeight);
    double rightGini = computeGini(childDist[1], rightWeight);

    return parentGini - leftWeight / totalWeight * leftGini
      - rightWeight / totalWeight * rightGini;
  }

  /**
   * Compute and return the Gini index for a given class distribution of a
   * node.
   *
   * @param dist  class distribution
   * @param total total weight of the class distribution
   * @return Gini index of the class distribution
   */
  protected double computeGini(double[] dist, double total) {
    if (total == 0) return 0;
    double val = 0;
    for (int i = 0; i < dist.length; i++) {
      val += (dist[i] / total) * (dist[i] / total);
    }
    return 1 - val;
  }
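  // Worked example for the two methods above: for a parent distribution
  // {4, 4} split perfectly into {4, 0} and {0, 4}, the parent Gini index is
  // 1 - (0.5^2 + 0.5^2) = 0.5, both children have Gini index 0, and the
  // Gini gain is therefore 0.5 - (4/8)*0 - (4/8)*0 = 0.5.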
  /**
   * Computes class probabilities for an instance using the decision tree.
   *
   * @param instance the instance for which class probabilities are to be
   *                 computed
   * @return the class probabilities for the given instance
   * @throws Exception if something goes wrong
   */
  public double[] distributionForInstance(Instance instance)
    throws Exception {

    if (!m_isLeaf) {
      // the value of the split attribute is missing: pass the instance down
      // both branches and blend the results by the branch proportions
      if (instance.isMissing(m_Attribute)) {
        double[] returnedDist = new double[m_ClassProbs.length];

        for (int i = 0; i < m_Successors.length; i++) {
          double[] help = m_Successors[i].distributionForInstance(instance);
          if (help != null) {
            for (int j = 0; j < help.length; j++) {
              returnedDist[j] += m_Props[i] * help[j];
            }
          }
        }
        return returnedDist;
      }

      // split attribute is nominal
      else if (m_Attribute.isNominal()) {
        if (m_SplitString.indexOf("(" +
            m_Attribute.value((int)instance.value(m_Attribute)) + ")") != -1)
          return m_Successors[0].distributionForInstance(instance);
        else
          return m_Successors[1].distributionForInstance(instance);
      }

      // split attribute is numeric
      else {
        if (instance.value(m_Attribute) < m_SplitValue)
          return m_Successors[0].distributionForInstance(instance);
        else
          return m_Successors[1].distributionForInstance(instance);
      }
    }

    // leaf node
    else return m_ClassProbs;
  }

  /**
   * Make the node a leaf node.
   *
   * @param data training data
   */
  protected void makeLeaf(Instances data) {
    m_Attribute = null;
    m_isLeaf = true;
    m_ClassValue = Utils.maxIndex(m_ClassProbs);
    m_ClassAttribute = data.classAttribute();
  }

  /**
   * Prints the decision tree using the protected toString method from below.
   *
   * @return a textual description of the classifier
   */
  public String toString() {
    if ((m_ClassProbs == null) && (m_Successors == null)) {
      return "CART Tree: No model built yet.";
    }
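A minimal usage sketch (not part of SimpleCart.java itself): how the classifier is typically driven through the standard Weka API. The file name data.arff and the demo class are hypothetical; SimpleCart lives in weka.classifiers.trees in Weka 3.6 and ships as the separate simpleCART package in later releases.

import java.util.Random;

import weka.classifiers.Evaluation;
import weka.classifiers.trees.SimpleCart;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class SimpleCartDemo {
  public static void main(String[] args) throws Exception {
    // load a dataset and declare the last attribute as the class
    Instances data = DataSource.read("data.arff");  // hypothetical path
    data.setClassIndex(data.numAttributes() - 1);

    // build the CART tree (minimal cost-complexity pruning by default)
    SimpleCart cart = new SimpleCart();
    cart.buildClassifier(data);

    // ten-fold cross-validation, then dump the learned tree via toString()
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(cart, data, 10, new Random(1));
    System.out.println(eval.toSummaryString());
    System.out.println(cart);
  }
}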