graftsplit.java

来自「Weka」· Java 代码 · 共 541 行 · 第 1/2 页

JAVA
541
字号
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *  GraftSplit.java *  Copyright (C) 2007 Geoff Webb & Janice Boughton *  a split object for nodes added to a tree during grafting. *  (used in classifier J48g). */package weka.classifiers.trees.j48;import weka.core.*;/** * Class implementing a split for nodes added to a tree during grafting. 
*
 * A GraftSplit records a single attribute test (m_attIndex, m_splitPoint,
 * m_testType) plus the class distributions used to label the two leaves it
 * produces (see getLeaf() / getOtherLeaf()).
 *
 * @author Janice Boughton (jrbought@infotech.monash.edu.au)
 * @version $Revision 1.0 $
 */
public class GraftSplit extends ClassifierSplitModel implements Comparable {

  /**
   * the distribution for graft values, from cases in atbop.
   * In the code visible here it is only assigned in buildClassifier(), as a
   * two-bag distribution.
   * NOTE(review): buildClassifier() describes it as "not counting cases in
   * atbop, only orig leaf", which disagrees with this description -- confirm
   * which one is correct.
   */
  private Distribution m_graftdistro;

  /** the attribute we are splitting on (used as the index for isMissing()) */
  private int m_attIndex;

  /** value of split point (if numeric attribute) */
  private double m_splitPoint;

  /**
   * dominant class of the subset specified by m_testType. It is either
   * passed in (first constructor) or computed from the counts (second
   * constructor).
   */
  private int m_maxClass;

  /**
   * dominant class of the subset not specified by m_testType.
   * Only the counts-based constructor assigns it; otherwise it keeps the
   * int default.
   */
  private int m_otherLeafMaxClass;

  /** laplace value of the subset specified by m_testType for m_maxClass */
  private double m_laplace;

  /**
   * single-bag class distribution for the leaf covering the subset
   * specified by m_testType. Only the counts-based constructor sets it.
   */
  private Distribution m_leafdistro;

  /**
   * type of test:
   * 0: <= test
   * 1: > test
   * 2: = test
   * 3: != test
   */
  private int m_testType;

  /**
   * constructor
   *
   * Records only the test and the precomputed class/laplace values.
   * It builds no distributions and does not set m_numSubsets,
   * m_otherLeafMaxClass or m_leafdistro.
   *
   * @param a the attribute to split on
   * @param v the value of a where split occurs
   * @param t the test type (0 is <=, 1 is >, 2 is =, 3 is !=)
   * @param c the class to label the leaf node pointed to by test as
   *          (passed as a double and truncated to an int)
   * @param l the laplace value (needed when sorting GraftSplits)
   */
  public GraftSplit(int a, double v, int t, double c, double l) {
    m_attIndex = a;
    m_splitPoint = v;
    m_testType = t;
    m_maxClass = (int)c;
    m_laplace = l;
  }

  /**
   * constructor
   *
   * Builds the split from per-subset class counts. m_distribution is
   * created from counts. The leaf distribution, majority class and laplace
   * value are then derived from the subset of interest.
   *
   * @param a the attribute to split on
   * @param v the value of a where split occurs
   * @param t the test type (0 is <=, 1 is >, 2 is =, 3 is !=)
   * @param oC the class to label the leaf node not pointed to by test as
   *           (passed as a double and truncated to an int).
* @param counts the distribution for this split   */  public GraftSplit(int a, double v, int t, double oC, double [][] counts)                                                           throws Exception {    m_attIndex = a;    m_splitPoint = v;    m_testType = t;    m_otherLeafMaxClass = (int)oC;    // only deal with binary cuts (<= and >; = and !=)    m_numSubsets = 2;    // which subset are we looking at for the graft?    int subset = subsetOfInterest();  // this is the subset for m_leaf    // create graft distribution, based on counts    m_distribution = new Distribution(counts);    // create a distribution object for m_leaf    double [][] lcounts = new double[1][m_distribution.numClasses()];    for(int c = 0; c < lcounts[0].length; c++) {       lcounts[0][c] = counts[subset][c];    }    m_leafdistro = new Distribution(lcounts);    // set the max class    m_maxClass = m_distribution.maxClass(subset);     // set the laplace value (assumes binary class) for subset of interest    m_laplace = (m_distribution.perClassPerBag(subset, m_maxClass) + 1.0)                / (m_distribution.perBag(subset) + 2.0);  }  /**   * deletes the cases in data that belong to leaf pointed to by   * the test (i.e. the subset of interest).  this is useful so   * the instances belonging to that leaf aren't passed down the   * other branch.   
*
   * The removal happens in place through data.delete().
   *
   * @param data the instances to delete from (modified in place)
   */
  public void deleteGraftedCases(Instances data) {
    int subOfInterest = subsetOfInterest();
    for(int x = 0; x < data.numInstances(); x++) {
       if(whichSubset(data.instance(x)) == subOfInterest) {
          // step x back so the same index is checked again after the delete
          // (presumably delete() shifts later instances down -- confirm)
          data.delete(x--);
       }
    }
  }

  /**
   * builds m_graftdistro using the passed data
   *
   * Builds a two-bag distribution in up to three steps.
   * 1. Instances with a known value for m_attIndex go into the bag chosen by
   *    whichSubset(). Instances for which it returns -1 are skipped, but
   *    their weight still counts toward knownCases.
   * 2. Instances with a missing value are reconsidered in a second pass,
   *    using a weighting factor.
   * 3. Any bag that is still empty gets a 0.01 count for its designated
   *    class: m_maxClass for the subset of interest, m_otherLeafMaxClass for
   *    the other bag.
   *
   * @param data the instances to use when creating the distribution
   */
  public void buildClassifier(Instances data) throws Exception {
    // distribution for the graft, not counting cases in atbop, only orig leaf
    m_graftdistro = new Distribution(2, data.numClasses());

    // which subset are we looking at for the graft?
    int subset = subsetOfInterest();  // this is the subset for m_leaf

    double thisNodeCount = 0;
    double knownCases = 0;
    boolean allKnown = true;
    // populate distribution
    for(int x = 0; x < data.numInstances(); x++) {
       Instance instance = data.instance(x);
       if(instance.isMissing(m_attIndex)) {
          allKnown = false;
          continue;
       }
       knownCases += instance.weight();
       int subst = whichSubset(instance);
       if(subst == -1)
          continue;
       m_graftdistro.add(subst, instance);
       if(subst == subset) {  // instance belongs at m_leaf
          thisNodeCount += instance.weight();
       }
    }
    // fraction of known-value weight that landed in the subset of interest
    // (1/2 when no instance had a known value). The second pass applies it
    // to every missing-valued instance, whichever bag it ends up in.
    double factor = (knownCases == 0) ? (1.0 / (double)2.0)
                                      : (thisNodeCount / knownCases);
    if(!allKnown) {
       for(int x = 0; x < data.numInstances(); x++) {
          if(data.instance(x).isMissing(m_attIndex)) {
             Instance instance = data.instance(x);
             // NOTE(review): whichSubset() is not visible here.
             // - If it returns -1 for missing values, this pass adds nothing.
             // - If it does not, setWeight() below permanently rescales the
             //   caller's instance, and repeated calls compound the scaling.
             // Confirm which behavior is intended.
             int subst = whichSubset(instance);
             if(subst == -1)
                continue;
             instance.setWeight(instance.weight() * factor);
             m_graftdistro.add(subst, instance);
          }
       }
    }
    // if there are no cases at the leaf, make sure the desired
    // class is chosen, by setting counts to 0.01
    if(m_graftdistro.perBag(subset) == 0) {
       double [] counts = new double[data.numClasses()];
       counts[m_maxClass] = 0.01;
       m_graftdistro.add(subset, counts);
    }
    // same safeguard for the other bag, using the other leaf's class
    if(m_graftdistro.perBag((subset == 0) ? 1 : 0) == 0) {
       double [] counts = new double[data.numClasses()];
       counts[(int)m_otherLeafMaxClass] = 0.01;
       m_graftdistro.add((subset == 0) ? 1 : 0, counts);
    }
  }

  /**
   * Wraps m_leafdistro in a NoSplit. In the code visible here,
   * m_leafdistro is only set by the counts-based constructor.
   *
   * @return the NoSplit object for the leaf pointed to by m_testType branch
   */
  public NoSplit getLeaf() {
    return new NoSplit(m_leafdistro);
  }

  /**
   * Builds a single-bag leaf from the m_graftdistro counts of the bag that
   * the test does not point to. This reads m_graftdistro, which in the
   * visible code is only assigned by buildClassifier().
   *
   * @return the NoSplit object for the leaf not pointed to by m_testType branch
   */
  public NoSplit getOtherLeaf() {
    // the bag (subset) that isn't pointed to by m_testType branch
    int bag = (subsetOfInterest() == 0) ? 1 : 0;
    double [][] counts = new double[1][m_graftdistro.numClasses()];
    double totals = 0;
    for(int c = 0; c < counts[0].length; c++) {
       counts[0][c] = m_graftdistro.perClassPerBag(bag, c);
       totals += counts[0][c];
    }
    // if empty, make sure proper class gets chosen
    // (buildClassifier() already seeds an empty bag with 0.01, so this
    // looks defensive once buildClassifier() has run)
    if(totals == 0) {
       counts[0][m_otherLeafMaxClass] += 0.01;
    }
    return new NoSplit(new Distribution(counts));
  }

  /**
   * Returns the display label for subset index of instances (eg class).
*   * @param index the bag to dump label for   * @param data to get attribute names and such   * @return the label as a string   * @exception Exception if something goes wrong   */  public final String dumpLabelG(int index, Instances data) throws Exception {    StringBuffer text;    text = new StringBuffer();    text.append(((Instances)data).classAttribute().       value((index==subsetOfInterest()) ? m_maxClass : m_otherLeafMaxClass));    text.append(" ("+Utils.roundDouble(m_graftdistro.perBag(index),1));    if(Utils.gr(m_graftdistro.numIncorrect(index),0))       text.append("/"        +Utils.roundDouble(m_graftdistro.numIncorrect(index),2));    // show the graft values, only if this is subsetOfInterest()    if(index == subsetOfInterest()) {       text.append("|"+Utils.roundDouble(m_distribution.perBag(index),2));       if(Utils.gr(m_distribution.numIncorrect(index),0))          text.append("/"             +Utils.roundDouble(m_distribution.numIncorrect(index),2));    }    text.append(")");    return text.toString();  }  /**   * @return the subset that is specified by the test type

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?