c45pruneableclassifiertreeg.java

来自「Weka」· Java 代码 · 共 1,217 行 · 第 1/3 页

JAVA
1,217
字号
  }  /**   * Method just exists to make program easier to read.   */  private C45PruneableClassifierTreeG son(int index){    return (C45PruneableClassifierTreeG)m_sons[index];  }  /**   * Initializes variables for grafting.   * sets up limits array (for numeric attributes) and calls    * the recursive function traverseTree.   *   * @param data the data for the tree   * @throws exception if anything goes wrong   */  public void doGrafting(Instances data) throws Exception {    // 2d array for the limits    double [][] limits = new double[data.numAttributes()][2];    // 2nd dimension: index 0 == lower limit, index 1 == upper limit    // initialise to no limit    for(int i = 0; i < data.numAttributes(); i++) {       limits[i][0] = Double.NEGATIVE_INFINITY;       limits[i][1] = Double.POSITIVE_INFINITY;    }    // use an index instead of creating new Insances objects all the time    // instanceIndex[0] == array for weights at leaf    // instanceIndex[1] == array for weights in atbop    double [][] instanceIndex = new double[2][data.numInstances()];    // initialize the weight for each instance    for(int x = 0; x < data.numInstances(); x++) {        instanceIndex[0][x] = 1;        instanceIndex[1][x] = 1;  // leaf instances are in atbop    }    // first call to graft    traverseTree(data, instanceIndex, limits, this, 0, -1);  }  /**   * recursive function.   * if this node is a leaf then calls findGraft, otherwise sorts    * the two sets of instances (tracked in iindex array) and calls   * sortInstances for each of the child nodes (which then calls   * this method).   *   * @param fulldata all instances   * @param iindex array the tracks the weight of each instance in   *        the atbop and at the leaf (0.0 if not present)   * @param limits array specifying current upper/lower limits for numeric atts   * @param parent the node immediately before the current one   * @param pL laplace for node, as calculated by parent (in case leaf is empty)   * @param nodeClass class of node, determined by parent (in case leaf empty)   */  private void traverseTree(Instances fulldata, double [][] iindex,      double[][] limits, C45PruneableClassifierTreeG parent,      double pL, int nodeClass) throws Exception {        if(m_isLeaf) {       findGraft(fulldata, iindex, limits,                  (ClassifierTree)parent, pL, nodeClass);    } else {       // traverse each branch       for(int i = 0; i < localModel().numSubsets(); i++) {          double [][] newiindex = new double[2][fulldata.numInstances()];          for(int x = 0; x < 2; x++)             System.arraycopy(iindex[x], 0, newiindex[x], 0, iindex[x].length);          sortInstances(fulldata, newiindex, limits, i);       }    }  }  /**   * sorts/deletes instances into/from node and atbop according to    * the test for subset, then calls traverseTree for subset's node.   *   * @param fulldata all instances   * @param iindex array the tracks the weight of each instance in   *        the atbop and at the leaf (0.0 if not present)   * @param limits array specifying current upper/lower limits for numeric atts   * @param subset the subset for which to sort instances into inode & iatbop   */  private void sortInstances(Instances fulldata, double [][] iindex,                    double [][] limits, int subset) throws Exception {    C45Split test = (C45Split)localModel();    // update the instances index for subset    double knownCases = 0;    double thisSubsetCount = 0;    for(int x = 0; x < iindex[0].length; x++) {       if(iindex[0][x] == 0 && iindex[1][x] == 0) // skip "discarded" instances          continue;       if(!fulldata.instance(x).isMissing(test.attIndex())) {          knownCases += iindex[0][x];          if(test.whichSubset(fulldata.instance(x)) != subset) {             if(iindex[0][x] > 0) {                // move to atbop, delete from leaf                iindex[1][x] = iindex[0][x];                iindex[0][x] = 0;             } else {                if(iindex[1][x] > 0) {                   // instance is now "discarded"                   iindex[1][x] = 0;                }             }          } else {             thisSubsetCount += iindex[0][x];          }       }    }    // work out proportions of weight for missing values for leaf and atbop    double lprop = (knownCases == 0) ? (1.0 / (double)test.numSubsets())                                 : (thisSubsetCount / (double)knownCases);    // add in the instances that have missing value for attIndex     for(int x = 0; x < iindex[0].length; x++) {       if(iindex[0][x] == 0 && iindex[1][x] == 0)          continue;     // skip "discarded" instances       if(fulldata.instance(x).isMissing(test.attIndex())) {          iindex[1][x] -= (iindex[1][x] - iindex[0][x]) * (1-lprop);          iindex[0][x] *= lprop;       }    }    int nodeClass = localModel().distribution().maxClass(subset);    double pL = (localModel().distribution().perClass(nodeClass) + 1.0)               / (localModel().distribution().total() + 2.0);    // call traerseTree method for the child node    son(subset).traverseTree(fulldata, iindex,          test.minsAndMaxs(fulldata, limits, subset), this, pL, nodeClass);  }  /**   * finds new nodes that improve accuracy and grafts them onto the tree   *   * @param fulldata the instances in whole trainset   * @param iindex records num tests each instance has failed up to this node   * @param limits the upper/lower limits for numeric attributes   * @param parent the node immediately before the current one   * @param pLaplace laplace for leaf, calculated by parent (in case leaf empty)   * @param pLeafClass class of leaf, determined by parent (in case leaf empty)   */  private void findGraft(Instances fulldata, double [][] iindex,    double [][] limits, ClassifierTree parent, double pLaplace,    int pLeafClass) throws Exception {    // get the class for this leaf    int leafClass = (m_isEmpty)                       ? pLeafClass                       :  localModel().distribution().maxClass();    // get the laplace value for this leaf    double leafLaplace = (m_isEmpty)                            ? pLaplace                            : laplaceLeaf(leafClass);    // sort the instances into those at the leaf, those in atbop, and discarded    Instances l = new Instances(fulldata, fulldata.numInstances());    Instances n = new Instances(fulldata, fulldata.numInstances());    int lcount = 0;    int acount = 0;    for(int x = 0; x < fulldata.numInstances(); x++) {       if(iindex[0][x] <= 0 && iindex[1][x] <= 0)          continue;       if(iindex[0][x] != 0) {          l.add(fulldata.instance(x));          l.instance(lcount).setWeight(iindex[0][x]);          // move instance's weight in iindex to same index as in l          iindex[0][lcount++] = iindex[0][x];       }       if(iindex[1][x] > 0) {          n.add(fulldata.instance(x));          n.instance(acount).setWeight(iindex[1][x]);          // move instance's weight in iindex to same index as in n          iindex[1][acount++] = iindex[1][x];       }    }    boolean graftPossible = false;    double [] classDist = new double[n.numClasses()];    for(int x = 0; x < n.numInstances(); x++) {       if(iindex[1][x] > 0 && !n.instance(x).classIsMissing())          classDist[(int)n.instance(x).classValue()] += iindex[1][x];    }    for(int cVal = 0; cVal < n.numClasses(); cVal++) {       double theLaplace = (classDist[cVal] + 1.0) / (classDist[cVal] + 2.0);       if(cVal != leafClass && (theLaplace > leafLaplace) &&         (biprob(classDist[cVal], classDist[cVal], leafLaplace)         > m_BiProbCrit)) {          graftPossible = true;          break;       }    }    if(!graftPossible) {       return;    }    // 1. Initialize to {} a set of tuples t containing potential tests    ArrayList t = new ArrayList();    // go through each attribute    for(int a = 0; a < n.numAttributes(); a++) {       if(a == n.classIndex())          continue;   // skip the class       // sort instances in atbop by $a       int [] sorted = sortByAttribute(n, a);       // 2. For each continuous attribute $a:       if(n.attribute(a).isNumeric()) {          // find min and max values for this attribute at the leaf          boolean prohibited = false;          double minLeaf = Double.POSITIVE_INFINITY;          double maxLeaf = Double.NEGATIVE_INFINITY;          for(int i = 0; i < l.numInstances(); i++) {             if(l.instance(i).isMissing(a)) {                if(l.instance(i).classValue() == leafClass) {                   prohibited = true;                   break;                }             }             double value = l.instance(i).value(a);             if(!m_relabel || l.instance(i).classValue() == leafClass) {                if(value < minLeaf)                   minLeaf = value;                if(value > maxLeaf)                   maxLeaf = value;             }          }          if(prohibited) {             continue;	  }          // (a) find values of          //    $n: instances in atbop (already have that, actually)          //    $v: a value for $a that exists for a case in the atbop, where          //       $v is < the min value for $a for a case at the leaf which          //       has the class $c, and $v is > the lowerlimit of $a at          //       the leaf.          //       (note: error in original paper stated that $v must be          //       smaller OR EQUAL TO the min value).          //    $k: $k is a class          //  that maximize L' = Laplace({$x: $x contained in cases($n)          //    & value($a,$x) <= $v & value($a,$x) > lowerlim($l,$a)}, $k).          double minBestClass = Double.NaN;          double minBestLaplace = leafLaplace;          double minBestVal = Double.NaN;          double minBestPos = Double.NaN;          double minBestTotal = Double.NaN;          double [][] minBestCounts = null;          double [][] counts = new double[2][n.numClasses()];          for(int x = 0; x < n.numInstances(); x++) {             if(n.instance(sorted[x]).isMissing(a))                break;   // missing are sorted to end: no more valid vals             double theval = n.instance(sorted[x]).value(a);             if(m_Debug)                System.out.println("\t " + theval);             if(theval <= limits[a][0]) {                if(m_Debug)                   System.out.println("\t  <= lowerlim: continuing...");                continue;             }             // note: error in paper would have this read "theVal > minLeaf)             if(theval >= minLeaf) {                if(m_Debug)                   System.out.println("\t  >= minLeaf; breaking...");                break;             }             counts[0][(int)n.instance(sorted[x]).classValue()]                += iindex[1][sorted[x]];             if(x != n.numInstances() - 1) {                int z = x + 1;                while(z < n.numInstances()                 && n.instance(sorted[z]).value(a) == theval) {                   z++; x++;                   counts[0][(int)n.instance(sorted[x]).classValue()]                     += iindex[1][sorted[x]];                }             }             // work out the best laplace/class (for <= theval)             double total = Utils.sum(counts[0]);             for(int c = 0; c < n.numClasses(); c++) {                double temp = (counts[0][c]+1.0)/(total+2.0);                if(temp > minBestLaplace) {                   minBestPos = counts[0][c];                   minBestTotal = total;                   minBestLaplace = temp;                   minBestClass = c;                   minBestCounts = copyCounts(counts);                   minBestVal = (x == n.numInstances()-1)                       ? theval                      : ((theval + n.instance(sorted[x+1]).value(a)) / 2.0);                }             }          }          // (b) add to t tuple <n,a,v,k,L',"<=">          if(!Double.isNaN(minBestVal)             && biprob(minBestPos, minBestTotal, leafLaplace) > m_BiProbCrit) {             GraftSplit gsplit = null;             try {                gsplit = new GraftSplit(a, minBestVal, 0,                                        leafClass, minBestCounts);             } catch (Exception e) {                System.err.println("graftsplit error: "+e.getMessage());                System.exit(1);             }             t.add(gsplit);	  }          // free space          minBestCounts = null;          // (c) find values of          //    n: instances in atbop (already have that, actually)          //    $v: a value for $a that exists for a case in the atbop, where          //       $v is > the max value for $a for a case at the leaf which          //       has the class $c, and $v is <= the upperlimit of $a at          //       the leaf.          //    k: k is a class          //   that maximize L' = Laplace({x: x contained in cases(n)          //       & value(a,x) > v & value(a,x) <= upperlim(l,a)}, k).          double maxBestClass = -1;          double maxBestLaplace = leafLaplace;          double maxBestVal = Double.NaN;          double maxBestPos = Double.NaN;          double maxBestTotal = Double.NaN;          double [][] maxBestCounts = null;          for(int c = 0; c < n.numClasses(); c++) {  // zero the counts             counts[0][c] = 0;             counts[1][c] = 0;  // shouldn't need to do this ...          }          // check smallest val for a in atbop is < upper limit          if(n.numInstances() >= 1           && n.instance(sorted[0]).value(a) < limits[a][1]) {             for(int x = n.numInstances() - 1; x >= 0; x--) {                if(n.instance(sorted[x]).isMissing(a))                   continue;                double theval = n.instance(sorted[x]).value(a);                if(m_Debug)                   System.out.println("\t " + theval);                if(theval > limits[a][1]) {                   if(m_Debug)                      System.out.println("\t  >= upperlim; continuing...");                   continue;                }                if(theval <= maxLeaf) {                   if(m_Debug)                      System.out.println("\t  < maxLeaf; breaking...");                   break;                }                // increment counts                counts[1][(int)n.instance(sorted[x]).classValue()]                    += iindex[1][sorted[x]];                if(x != 0 && !n.instance(sorted[x-1]).isMissing(a)) {                   int z = x - 1;                   while(z >= 0 && n.instance(sorted[z]).value(a) == theval) {                      z--; x--;                      counts[1][(int)n.instance(sorted[x]).classValue()]                         += iindex[1][sorted[x]];                   }                }                // work out best laplace for > theval                double total = Utils.sum(counts[1]);                for(int c = 0; c < n.numClasses(); c++) {                   double temp = (counts[1][c]+1.0)/(total+2.0);                   if(temp > maxBestLaplace ) {                      maxBestPos = counts[1][c];                      maxBestTotal = total;                      maxBestLaplace = temp;                      maxBestClass = c;                      maxBestCounts = copyCounts(counts);                      maxBestVal = (x == 0)                         ? theval                        : ((theval + n.instance(sorted[x-1]).value(a)) / 2.0);                   }                }             }             // (d) add to t tuple <n,a,v,k,L',">">             if(!Double.isNaN(maxBestVal)

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?