c45pruneableclassifiertreeg.java
来自「Weka」· Java 代码 · 共 1,217 行 · 第 1/3 页
JAVA
1,217 行
} /** * Method just exists to make program easier to read. */ private C45PruneableClassifierTreeG son(int index){ return (C45PruneableClassifierTreeG)m_sons[index]; } /** * Initializes variables for grafting. * sets up limits array (for numeric attributes) and calls * the recursive function traverseTree. * * @param data the data for the tree * @throws exception if anything goes wrong */ public void doGrafting(Instances data) throws Exception { // 2d array for the limits double [][] limits = new double[data.numAttributes()][2]; // 2nd dimension: index 0 == lower limit, index 1 == upper limit // initialise to no limit for(int i = 0; i < data.numAttributes(); i++) { limits[i][0] = Double.NEGATIVE_INFINITY; limits[i][1] = Double.POSITIVE_INFINITY; } // use an index instead of creating new Insances objects all the time // instanceIndex[0] == array for weights at leaf // instanceIndex[1] == array for weights in atbop double [][] instanceIndex = new double[2][data.numInstances()]; // initialize the weight for each instance for(int x = 0; x < data.numInstances(); x++) { instanceIndex[0][x] = 1; instanceIndex[1][x] = 1; // leaf instances are in atbop } // first call to graft traverseTree(data, instanceIndex, limits, this, 0, -1); } /** * recursive function. * if this node is a leaf then calls findGraft, otherwise sorts * the two sets of instances (tracked in iindex array) and calls * sortInstances for each of the child nodes (which then calls * this method). 
* * @param fulldata all instances * @param iindex array the tracks the weight of each instance in * the atbop and at the leaf (0.0 if not present) * @param limits array specifying current upper/lower limits for numeric atts * @param parent the node immediately before the current one * @param pL laplace for node, as calculated by parent (in case leaf is empty) * @param nodeClass class of node, determined by parent (in case leaf empty) */ private void traverseTree(Instances fulldata, double [][] iindex, double[][] limits, C45PruneableClassifierTreeG parent, double pL, int nodeClass) throws Exception { if(m_isLeaf) { findGraft(fulldata, iindex, limits, (ClassifierTree)parent, pL, nodeClass); } else { // traverse each branch for(int i = 0; i < localModel().numSubsets(); i++) { double [][] newiindex = new double[2][fulldata.numInstances()]; for(int x = 0; x < 2; x++) System.arraycopy(iindex[x], 0, newiindex[x], 0, iindex[x].length); sortInstances(fulldata, newiindex, limits, i); } } } /** * sorts/deletes instances into/from node and atbop according to * the test for subset, then calls traverseTree for subset's node. 
*
   * <p>All weight changes are made in place on {@code iindex}.
   *
   * <p>First, each non-discarded instance whose value for the split
   * attribute is known is checked against the test. If it belongs to another
   * subset, its leaf weight (when positive) is moved to the atbop. If it has
   * no leaf weight, its atbop weight is zeroed instead, which discards it.
   *
   * <p>Second, instances with a missing value for the split attribute are
   * split fractionally, using the share of known leaf weight that fell into
   * {@code subset}.
   *
   * @param fulldata all instances
   * @param iindex array that tracks the weight of each instance at the leaf
   *        ({@code iindex[0]}) and in the atbop ({@code iindex[1]}), 0.0 if
   *        not present; modified in place
   * @param limits array specifying current upper/lower limits for numeric atts
   * @param subset the subset for which to sort instances into inode and iatbop
   * @throws Exception if anything goes wrong further down the recursion
   */
  private void sortInstances(Instances fulldata, double [][] iindex,
      double [][] limits, int subset) throws Exception {

    C45Split test = (C45Split)localModel();

    // update the instances index for subset
    // knownCases: leaf weight of all instances whose split value is known
    // thisSubsetCount: the part of that weight that stays in this subset
    double knownCases = 0;
    double thisSubsetCount = 0;
    for(int x = 0; x < iindex[0].length; x++) {
      if(iindex[0][x] == 0 && iindex[1][x] == 0) // skip "discarded" instances
        continue;
      if(!fulldata.instance(x).isMissing(test.attIndex())) {
        // counted before any weight is moved below
        knownCases += iindex[0][x];
        if(test.whichSubset(fulldata.instance(x)) != subset) {
          if(iindex[0][x] > 0) {
            // move to atbop, delete from leaf
            iindex[1][x] = iindex[0][x];
            iindex[0][x] = 0;
          } else {
            if(iindex[1][x] > 0) {
              // instance is now "discarded"
              iindex[1][x] = 0;
            }
          }
        } else {
          thisSubsetCount += iindex[0][x];
        }
      }
    }

    // work out proportions of weight for missing values for leaf and atbop
    // (with no known leaf weight, the weight is spread evenly over all subsets)
    double lprop = (knownCases == 0) ?
      (1.0 / (double)test.numSubsets())
      : (thisSubsetCount / (double)knownCases);

    // add in the instances that have missing value for attIndex
    for(int x = 0; x < iindex[0].length; x++) {
      if(iindex[0][x] == 0 && iindex[1][x] == 0)
        continue; // skip "discarded" instances
      if(fulldata.instance(x).isMissing(test.attIndex())) {
        // New atbop weight = lprop * atbop + (1 - lprop) * leaf.
        // This line must read the leaf weight BEFORE the next line scales it.
        iindex[1][x] -= (iindex[1][x] - iindex[0][x]) * (1-lprop);
        iindex[0][x] *= lprop;
      }
    }

    // Class and Laplace estimate handed to the child; findGraft uses them
    // there if the child leaf turns out to be empty.
    // NOTE(review): maxClass(subset) is taken for this subset, while
    // perClass()/total() appear to cover the whole distribution; confirm
    // that this mix is intended.
    int nodeClass = localModel().distribution().maxClass(subset);
    double pL = (localModel().distribution().perClass(nodeClass) + 1.0)
      / (localModel().distribution().total() + 2.0);

    // call traverseTree method for the child node
    // NOTE(review): minsAndMaxs presumably narrows the limits to this
    // subset's range for the split attribute; confirm against C45Split.
    son(subset).traverseTree(fulldata, iindex,
      test.minsAndMaxs(fulldata, limits, subset), this, pL, nodeClass);
  }

  /**
   * finds new nodes that improve accuracy and grafts them onto the tree
   *
   * @param fulldata the instances in whole trainset
   * @param iindex per-instance weights at the leaf ({@code iindex[0]}) and in
   *        the atbop ({@code iindex[1]}); compacted in place by this method
   * @param limits the upper/lower limits for numeric attributes
   * @param parent the node immediately before the current one
   * @param pLaplace laplace for leaf, calculated by parent (in case leaf empty)
   * @param pLeafClass class of leaf, determined by parent (in case leaf empty)
   * @throws Exception if anything goes wrong
   */
  private void findGraft(Instances fulldata, double [][] iindex,
      double [][] limits, ClassifierTree parent, double pLaplace,
      int pLeafClass) throws Exception {

    // get the class for this leaf (the parent's choice if the leaf is empty)
    int leafClass = (m_isEmpty) ? pLeafClass
      : localModel().distribution().maxClass();

    // get the laplace value for this leaf (the parent's value if empty)
    double leafLaplace = (m_isEmpty) ?
pLaplace : laplaceLeaf(leafClass); // sort the instances into those at the leaf, those in atbop, and discarded Instances l = new Instances(fulldata, fulldata.numInstances()); Instances n = new Instances(fulldata, fulldata.numInstances()); int lcount = 0; int acount = 0; for(int x = 0; x < fulldata.numInstances(); x++) { if(iindex[0][x] <= 0 && iindex[1][x] <= 0) continue; if(iindex[0][x] != 0) { l.add(fulldata.instance(x)); l.instance(lcount).setWeight(iindex[0][x]); // move instance's weight in iindex to same index as in l iindex[0][lcount++] = iindex[0][x]; } if(iindex[1][x] > 0) { n.add(fulldata.instance(x)); n.instance(acount).setWeight(iindex[1][x]); // move instance's weight in iindex to same index as in n iindex[1][acount++] = iindex[1][x]; } } boolean graftPossible = false; double [] classDist = new double[n.numClasses()]; for(int x = 0; x < n.numInstances(); x++) { if(iindex[1][x] > 0 && !n.instance(x).classIsMissing()) classDist[(int)n.instance(x).classValue()] += iindex[1][x]; } for(int cVal = 0; cVal < n.numClasses(); cVal++) { double theLaplace = (classDist[cVal] + 1.0) / (classDist[cVal] + 2.0); if(cVal != leafClass && (theLaplace > leafLaplace) && (biprob(classDist[cVal], classDist[cVal], leafLaplace) > m_BiProbCrit)) { graftPossible = true; break; } } if(!graftPossible) { return; } // 1. Initialize to {} a set of tuples t containing potential tests ArrayList t = new ArrayList(); // go through each attribute for(int a = 0; a < n.numAttributes(); a++) { if(a == n.classIndex()) continue; // skip the class // sort instances in atbop by $a int [] sorted = sortByAttribute(n, a); // 2. 
For each continuous attribute $a: if(n.attribute(a).isNumeric()) { // find min and max values for this attribute at the leaf boolean prohibited = false; double minLeaf = Double.POSITIVE_INFINITY; double maxLeaf = Double.NEGATIVE_INFINITY; for(int i = 0; i < l.numInstances(); i++) { if(l.instance(i).isMissing(a)) { if(l.instance(i).classValue() == leafClass) { prohibited = true; break; } } double value = l.instance(i).value(a); if(!m_relabel || l.instance(i).classValue() == leafClass) { if(value < minLeaf) minLeaf = value; if(value > maxLeaf) maxLeaf = value; } } if(prohibited) { continue; } // (a) find values of // $n: instances in atbop (already have that, actually) // $v: a value for $a that exists for a case in the atbop, where // $v is < the min value for $a for a case at the leaf which // has the class $c, and $v is > the lowerlimit of $a at // the leaf. // (note: error in original paper stated that $v must be // smaller OR EQUAL TO the min value). // $k: $k is a class // that maximize L' = Laplace({$x: $x contained in cases($n) // & value($a,$x) <= $v & value($a,$x) > lowerlim($l,$a)}, $k). 
double minBestClass = Double.NaN; double minBestLaplace = leafLaplace; double minBestVal = Double.NaN; double minBestPos = Double.NaN; double minBestTotal = Double.NaN; double [][] minBestCounts = null; double [][] counts = new double[2][n.numClasses()]; for(int x = 0; x < n.numInstances(); x++) { if(n.instance(sorted[x]).isMissing(a)) break; // missing are sorted to end: no more valid vals double theval = n.instance(sorted[x]).value(a); if(m_Debug) System.out.println("\t " + theval); if(theval <= limits[a][0]) { if(m_Debug) System.out.println("\t <= lowerlim: continuing..."); continue; } // note: error in paper would have this read "theVal > minLeaf) if(theval >= minLeaf) { if(m_Debug) System.out.println("\t >= minLeaf; breaking..."); break; } counts[0][(int)n.instance(sorted[x]).classValue()] += iindex[1][sorted[x]]; if(x != n.numInstances() - 1) { int z = x + 1; while(z < n.numInstances() && n.instance(sorted[z]).value(a) == theval) { z++; x++; counts[0][(int)n.instance(sorted[x]).classValue()] += iindex[1][sorted[x]]; } } // work out the best laplace/class (for <= theval) double total = Utils.sum(counts[0]); for(int c = 0; c < n.numClasses(); c++) { double temp = (counts[0][c]+1.0)/(total+2.0); if(temp > minBestLaplace) { minBestPos = counts[0][c]; minBestTotal = total; minBestLaplace = temp; minBestClass = c; minBestCounts = copyCounts(counts); minBestVal = (x == n.numInstances()-1) ? 
theval : ((theval + n.instance(sorted[x+1]).value(a)) / 2.0); } } } // (b) add to t tuple <n,a,v,k,L',"<="> if(!Double.isNaN(minBestVal) && biprob(minBestPos, minBestTotal, leafLaplace) > m_BiProbCrit) { GraftSplit gsplit = null; try { gsplit = new GraftSplit(a, minBestVal, 0, leafClass, minBestCounts); } catch (Exception e) { System.err.println("graftsplit error: "+e.getMessage()); System.exit(1); } t.add(gsplit); } // free space minBestCounts = null; // (c) find values of // n: instances in atbop (already have that, actually) // $v: a value for $a that exists for a case in the atbop, where // $v is > the max value for $a for a case at the leaf which // has the class $c, and $v is <= the upperlimit of $a at // the leaf. // k: k is a class // that maximize L' = Laplace({x: x contained in cases(n) // & value(a,x) > v & value(a,x) <= upperlim(l,a)}, k). double maxBestClass = -1; double maxBestLaplace = leafLaplace; double maxBestVal = Double.NaN; double maxBestPos = Double.NaN; double maxBestTotal = Double.NaN; double [][] maxBestCounts = null; for(int c = 0; c < n.numClasses(); c++) { // zero the counts counts[0][c] = 0; counts[1][c] = 0; // shouldn't need to do this ... 
} // check smallest val for a in atbop is < upper limit if(n.numInstances() >= 1 && n.instance(sorted[0]).value(a) < limits[a][1]) { for(int x = n.numInstances() - 1; x >= 0; x--) { if(n.instance(sorted[x]).isMissing(a)) continue; double theval = n.instance(sorted[x]).value(a); if(m_Debug) System.out.println("\t " + theval); if(theval > limits[a][1]) { if(m_Debug) System.out.println("\t >= upperlim; continuing..."); continue; } if(theval <= maxLeaf) { if(m_Debug) System.out.println("\t < maxLeaf; breaking..."); break; } // increment counts counts[1][(int)n.instance(sorted[x]).classValue()] += iindex[1][sorted[x]]; if(x != 0 && !n.instance(sorted[x-1]).isMissing(a)) { int z = x - 1; while(z >= 0 && n.instance(sorted[z]).value(a) == theval) { z--; x--; counts[1][(int)n.instance(sorted[x]).classValue()] += iindex[1][sorted[x]]; } } // work out best laplace for > theval double total = Utils.sum(counts[1]); for(int c = 0; c < n.numClasses(); c++) { double temp = (counts[1][c]+1.0)/(total+2.0); if(temp > maxBestLaplace ) { maxBestPos = counts[1][c]; maxBestTotal = total; maxBestLaplace = temp; maxBestClass = c; maxBestCounts = copyCounts(counts); maxBestVal = (x == 0) ? theval : ((theval + n.instance(sorted[x-1]).value(a)) / 2.0); } } } // (d) add to t tuple <n,a,v,k,L',">"> if(!Double.isNaN(maxBestVal)
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?