c45pruneableclassifiertreeg.java

来自「Weka」· Java 代码 · 共 1,217 行 · 第 1/3 页

JAVA
1,217
字号
               && biprob(maxBestPos,maxBestTotal,leafLaplace) > m_BiProbCrit) {                GraftSplit gsplit = null;                try {                   gsplit = new GraftSplit(a, maxBestVal, 1,                      leafClass, maxBestCounts);                } catch (Exception e) {                   System.err.println("graftsplit error:" + e.getMessage());                   System.exit(1);                }                t.add(gsplit);             }          }       } else {    // must be a nominal attribute          // 3. for each discrete attribute a for which there is no          //    test at an ancestor of l          // skip if this attribute has already been used          if(limits[a][1] == 1) {             continue;          }          boolean [] prohibit = new boolean[l.attribute(a).numValues()];          for(int aval = 0; aval < n.attribute(a).numValues(); aval++) {             for(int x = 0; x < l.numInstances(); x++) {                if((l.instance(x).isMissing(a)                    || l.instance(x).value(a) == aval)                  && (!m_relabel || (l.instance(x).classValue() == leafClass))) {                   prohibit[aval] = true;                   break;                }             }          }          // (a) find values of          //       $n: instances in atbop (already have that, actually)          //       $v: $v is a value for $a          //       $k: $k is a class          //     that maximize L' = Laplace({$x: $x contained in cases($n)          //           & value($a,$x) = $v}, $k).          
double bestVal = Double.NaN;          double bestClass = Double.NaN;          double bestLaplace = leafLaplace;          double [][] bestCounts = null;          double [][] counts = new double[2][n.numClasses()];          for(int x = 0; x < n.numInstances(); x++) {             if(n.instance(sorted[x]).isMissing(a))                continue;             // zero the counts             for(int c = 0; c < n.numClasses(); c++)                counts[0][c] = 0;             double theval = n.instance(sorted[x]).value(a);             counts[0][(int)n.instance(sorted[x]).classValue()]                += iindex[1][sorted[x]];             if(x != n.numInstances() - 1) {                int z = x + 1;                while(z < n.numInstances()                  && n.instance(sorted[z]).value(a) == theval) {                   z++; x++;                   counts[0][(int)n.instance(sorted[x]).classValue()]                      += iindex[1][sorted[x]];                }             }             if(!prohibit[(int)theval]) {                // work out best laplace for > theval                double total = Utils.sum(counts[0]);                bestLaplace = leafLaplace;                bestClass = Double.NaN;                for(int c = 0; c < n.numClasses(); c++) {                   double temp = (counts[0][c]+1.0)/(total+2.0);                   if(temp > bestLaplace                    && biprob(counts[0][c],total,leafLaplace) > m_BiProbCrit) {                      bestLaplace = temp;                      bestClass = c;                      bestVal = theval;                      bestCounts = copyCounts(counts);                   }                }		// add to graft list                if(!Double.isNaN(bestClass)) {                   GraftSplit gsplit = null;                   try {                      gsplit = new GraftSplit(a, bestVal, 2,                         leafClass, bestCounts);                   } catch (Exception e) {                     System.err.println("graftsplit error: 
"+e.getMessage());                     System.exit(1);                   }                   t.add(gsplit);                }             }          }          // (b) add to t tuple <n,a,v,k,L',"=">          // done this already       }    }    // 4. remove from t all tuples <n,a,v,c,L,x> such that L <=    //    Laplace(cases(l),c) or prob(x,n,Laplace(cases(l),c) <= 0.05    //      -- checked this constraint prior to adding a tuple --    // *** step six done before step five for efficiency ***    // 6. for each <n,a,v,k,L,x> in t ordered on L from highest to lowest    // order the tuples from highest to lowest laplace    // (this actually orders lowest to highest)    Collections.sort(t);    // 5. remove from t all tuples <n,a,v,c,L,x> such that there is    //    no tuple <n',a',v',k',L',x'> such that k' != c & L' < L.    for(int x = 0; x < t.size(); x++) {       GraftSplit gs = (GraftSplit)t.get(x);       if(gs.maxClassForSubsetOfInterest() != leafClass) {          break; // reached a graft with class != leafClass, so stop deleting       } else {          t.remove(x);          x--;       }    }    // if no potential grafts were found, do nothing and return    if(t.size() < 1) {       return;    }    // create the distributions for each graft    for(int x = t.size()-1; x >= 0; x--) {       GraftSplit gs = (GraftSplit)t.get(x);       try {          gs.buildClassifier(l);          gs.deleteGraftedCases(l); // so they don't go down the other branch       } catch (Exception e) {          System.err.println("graftsplit build error: " + e.getMessage());       }    }    // add this stuff to the tree    ((C45PruneableClassifierTreeG)parent).setDescendents(t, this);  }  /**   * sorts the int array in ascending order by attribute indexed    * by a in dataset data.     
* @param the data the indices represent   * @param the index of the attribute to sort by   * @return array of sorted indicies   */  private int [] sortByAttribute(Instances data, int a) {    double [] attList = data.attributeToDoubleArray(a);    int [] temp = Utils.sort(attList);    return temp;  }  /**   * deep copy the 2d array of counts   *   * @param src the array to copy   * @return a copy of src   */  private double [][] copyCounts(double [][] src) {    double [][] newArr = new double[src.length][0];    for(int x = 0; x < src.length; x++) {       newArr[x] = new double[src[x].length];       for(int y = 0; y < src[x].length; y++) {          newArr[x][y] = src[x][y];       }    }    return newArr;  }    /**   * Help method for computing class probabilities of   * a given instance.   *   * @throws Exception if something goes wrong   */  private double getProbsLaplace(int classIndex, Instance instance, double weight)       throws Exception {    double [] weights;    double prob = 0;    int treeIndex;    int i,j;    if (m_isLeaf) {       return weight * localModel().classProbLaplace(classIndex, instance, -1);    } else {       treeIndex = localModel().whichSubset(instance);       if (treeIndex == -1) {          weights = localModel().weights(instance);          for (i = 0; i < m_sons.length; i++) {             if (!son(i).m_isEmpty) {                if (!son(i).m_isLeaf) {                   prob += son(i).getProbsLaplace(classIndex, instance,                                                  weights[i] * weight);                } else {                   prob += weight * weights[i] *                     localModel().classProbLaplace(classIndex, instance, i);                }             }          }          return prob;       } else {          if (son(treeIndex).m_isLeaf) {             return weight * localModel().classProbLaplace(classIndex, instance,                                                           treeIndex);          } else {             return 
son(treeIndex).getProbsLaplace(classIndex,instance,weight);          }       }    }  }  /**   * Help method for computing class probabilities of   * a given instance.   *   * @throws Exception if something goes wrong   */  private double getProbs(int classIndex, Instance instance, double weight)      throws Exception {    double [] weights;    double prob = 0;    int treeIndex;    int i,j;    if (m_isLeaf) {       return weight * localModel().classProb(classIndex, instance, -1);    } else {       treeIndex = localModel().whichSubset(instance);       if (treeIndex == -1) {          weights = localModel().weights(instance);          for (i = 0; i < m_sons.length; i++) {             if (!son(i).m_isEmpty) {                prob += son(i).getProbs(classIndex, instance,                                 weights[i] * weight);             }          }          return prob;       } else {          if (son(treeIndex).m_isEmpty) {             return weight * localModel().classProb(classIndex, instance,                                                    treeIndex);          } else {             return son(treeIndex).getProbs(classIndex, instance, weight);          }       }    }  }  /**   * add the grafted nodes at originalLeaf's position in tree.   * a recursive function that terminates when t is empty.   
*    * @param t the list of nodes to graft   * @param originalLeaf the leaf that the grafts are replacing   */  public void setDescendents(ArrayList t,                              C45PruneableClassifierTreeG originalLeaf) {    Instances headerInfo = new Instances(m_train, 0);    boolean end = false;    ClassifierSplitModel splitmod = null;    C45PruneableClassifierTreeG newNode;    if(t.size() > 0) {       splitmod = (ClassifierSplitModel)t.remove(t.size() - 1);       newNode = new C45PruneableClassifierTreeG(m_toSelectModel, headerInfo,                           splitmod, m_pruneTheTree, m_CF, m_subtreeRaising,                           false, m_relabel, m_cleanup);    } else {       // get the leaf for one of newNode's children       NoSplit kLeaf = ((GraftSplit)localModel()).getOtherLeaf();       newNode =              new C45PruneableClassifierTreeG(m_toSelectModel, headerInfo,                           kLeaf, m_pruneTheTree, m_CF, m_subtreeRaising,                           true, m_relabel, m_cleanup);       end = true;    }    // behave differently for parent of original leaf, since we don't    // want to destroy any of its other branches    if(m_sons != null) {       for(int x = 0; x < m_sons.length; x++) {          if(son(x).equals(originalLeaf)) {             m_sons[x] = newNode;  // replace originalLeaf with newNode          }       }    } else {       // allocate space for the children       m_sons = new C45PruneableClassifierTreeG[localModel().numSubsets()];        // get the leaf for one of newNode's children       NoSplit kLeaf = ((GraftSplit)localModel()).getLeaf();       C45PruneableClassifierTreeG kNode =                  new C45PruneableClassifierTreeG(m_toSelectModel, headerInfo,                               kLeaf, m_pruneTheTree, m_CF, m_subtreeRaising,                               true, m_relabel, m_cleanup);        // figure where to put the new node       if(((GraftSplit)localModel()).subsetOfInterest() == 0) {          m_sons[0] = kNode;  
        m_sons[1] = newNode;       } else {          m_sons[0] = newNode;          m_sons[1] = kNode;       }    }    if(!end)       ((C45PruneableClassifierTreeG)newNode).setDescendents                  (t, (C45PruneableClassifierTreeG)originalLeaf);  }  /**   *  class prob with laplace correction (assumes binary class)   */  private double laplaceLeaf(double classIndex) {    double l =  (localModel().distribution().perClass((int)classIndex) + 1.0)               / (localModel().distribution().total() + 2.0);    return l;  }  /**   * Significance test   * @param double  x, double  n, double r.   * @return returns the probability of obtaining x or MORE out of n   * if r proportion of n are positive.   *   * z for normal estimation of binomial probability of obtaining x    * or more out of n, if r proportion of n are positive   */  public double biprob(double x, double n, double r) throws Exception {    return ((((x) - 0.5) - (n) * (r)) / Math.sqrt((n) * (r) * (1.0 - (r))));  }  /**   * Prints tree structure.   */  public String toString() {    try {       StringBuffer text = new StringBuffer();       if(m_isLeaf) {          text.append(": ");          if(m_localModel instanceof GraftSplit)             text.append(((GraftSplit)m_localModel).dumpLabelG(0,m_train));          else             text.append(m_localModel.dumpLabel(0,m_train));       } else          dumpTree(0,text);       text.append("\n\nNumber of Leaves  : \t"+numLeaves()+"\n");       text.append("\nSize of the tree : \t"+numNodes()+"\n");       return text.toString();    } catch (Exception e) {       return "Can't print classification tree.";    }  }  /**   * Help method for printing tree structure.   
* Each child branch is printed on its own line, indented by one "|   "
   * per level of depth; leaf branches are followed by their label and
   * internal branches by their subtree.
   *
   * @param depth the nesting depth of this node
   * @param text the buffer the output is appended to
   * @throws Exception if something goes wrong
   */
  protected void dumpTree(int depth, StringBuffer text) throws Exception {
    for (int branch = 0; branch < m_sons.length; branch++) {
      text.append("\n");
      for (int level = depth; level > 0; level--) {
        text.append("|   ");
      }
      text.append(m_localModel.leftSide(m_train));
      text.append(m_localModel.rightSide(branch, m_train));

      if (!m_sons[branch].m_isLeaf) {
        // internal child: print its subtree one level deeper
        ((C45PruneableClassifierTreeG) m_sons[branch]).dumpTree(depth + 1, text);
        continue;
      }

      // leaf child: grafted splits have their own label formatting
      text.append(": ");
      if (m_localModel instanceof GraftSplit) {
        text.append(((GraftSplit) m_localModel).dumpLabelG(branch, m_train));
      } else {
        text.append(m_localModel.dumpLabel(branch, m_train));
      }
    }
  }
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?