/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    LogitBoost.java
 *    Copyright (C) 1999, 2002 Len Trigg, Eibe Frank
 *
 */

package weka.classifiers.meta;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
import weka.classifiers.Sourcable;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.UnsupportedAttributeTypeException;
import weka.core.UnsupportedClassTypeException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Class for performing additive logistic regression.
 * This class performs classification using a regression scheme as the 
 * base learner, and can handle multi-class problems.  For more
 * information, see<p>
 * 
 * Friedman, J., T. Hastie and R. Tibshirani (1998) <i>Additive Logistic
 * Regression: a Statistical View of Boosting</i> 
 * <a href="ftp://stat.stanford.edu/pub/friedman/boost.ps">download 
 * postscript</a>. <p>
 *
 * Valid options are:<p>
 *
 * -D <br>
 * Turn on debugging output.<p>
 *
 * -W classname <br>
 * Specify the full class name of a weak learner as the basis for 
 * boosting (required).<p>
 *
 * -I num <br>
 * Set the number of boost iterations (default 10). <p>
 *
 * -Q <br>
 * Use resampling instead of reweighting.<p>
 *
 * -S seed <br>
 * Random number seed for resampling (default 1).<p>
 *
 * -P num <br>
 * Set the percentage of weight mass used to build classifiers
 * (default 100). <p>
 *
 * -F num <br>
 * Set number of folds for the internal cross-validation
 * (default 0 -- no cross-validation). <p>
 *
 * -R num <br>
 * Set number of runs for the internal cross-validation
 * (default 1). <p>
 *
 * -L num <br> 
 * Set the threshold for the improvement of the
 * average loglikelihood (default -Double.MAX_VALUE). <p>
 *
 * -H num <br> 
 * Set the value of the shrinkage parameter (default 1). <p>
 *
 * Options after -- are passed to the designated learner.<p>
 *
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision$ 
 */
public class LogitBoost extends RandomizableIteratedSingleClassifierEnhancer
  implements Sourcable, WeightedInstancesHandler {

  /** Array for storing the generated base classifiers.
      Note: we are hiding the variable from IteratedSingleClassifierEnhancer. */
  protected Classifier [][] m_Classifiers;

  /** The number of classes */
  protected int m_NumClasses;

  /** The number of successfully generated base classifiers. */
  protected int m_NumGenerated;

  /** The number of folds for the internal cross-validation. */
  protected int m_NumFolds = 0;

  /** The number of runs for the internal cross-validation. */
  protected int m_NumRuns = 1;

  /** Weight thresholding. The percentage of weight mass used in training */
  protected int m_WeightThreshold = 100;

  /** A threshold for responses (Friedman suggests between 2 and 4) */
  protected static final double Z_MAX = 3;

  /** Dummy dataset with a numeric class */
  protected Instances m_NumericClassData;

  /** The actual class attribute (for getting class names) */
  protected Attribute m_ClassAttribute;

  /** Use boosting with resampling instead of reweighting? */
  protected boolean m_UseResampling;

  /** The threshold on the improvement of the likelihood */   
  protected double m_Precision = -Double.MAX_VALUE;

  /** The value of the shrinkage parameter */
  protected double m_Shrinkage = 1;

  /** The random number generator used */
  protected Random m_RandomInstance = null;

  /** The value by which the actual target value for the
      true class is offset. */
  protected double m_Offset = 0.0;
    
  /**
   * Returns a string describing the classifier.
   * @return a description suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {

    return "Class for performing additive logistic regression. "
      + "This class performs classification using a regression scheme as the "
      + "base learner, and can handle multi-class problems.  For more "
      + "information, see\n\n"
      + "Friedman, J., T. Hastie and R. Tibshirani (1998) \"Additive Logistic "
      + "Regression: a Statistical View of Boosting\". Technical report. " 
      + "Stanford University.\n\n"
      + "Can do efficient internal cross-validation to determine "
      + "appropriate number of iterations.";
  }
    
  /**
   * Constructor.
   */
  public LogitBoost() {
    
    m_Classifier = new weka.classifiers.trees.DecisionStump();
  }

  /**
   * String describing default classifier.
   */
  protected String defaultClassifierString() {
    
    return "weka.classifiers.trees.DecisionStump";
  }

  /**
   * Select only instances with weights that contribute to 
   * the specified quantile of the weight distribution
   *
   * @param data the input instances
   * @param quantile the specified quantile, e.g. 0.9 to select
   * 90% of the weight mass
   * @return the selected instances
   */
  protected Instances selectWeightQuantile(Instances data, double quantile) { 

    int numInstances = data.numInstances();
    Instances trainData = new Instances(data, numInstances);
    double [] weights = new double [numInstances];

    double sumOfWeights = 0;
    for (int i = 0; i < numInstances; i++) {
      weights[i] = data.instance(i).weight();
      sumOfWeights += weights[i];
    }
    double weightMassToSelect = sumOfWeights * quantile;
    int [] sortedIndices = Utils.sort(weights);

    // Select the instances
    sumOfWeights = 0;
    for (int i = numInstances-1; i >= 0; i--) {
      Instance instance = (Instance)data.instance(sortedIndices[i]).copy();
      trainData.add(instance);
      sumOfWeights += weights[sortedIndices[i]];
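      // Stop once enough weight mass is selected, but only at a boundary
      // between distinct weights, so instances with tied weights stay together.
      if ((sumOfWeights > weightMassToSelect) &&
          (i > 0) &&
          (weights[sortedIndices[i]] != weights[sortedIndices[i-1]])) {
        break;
      }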
    }
    if (m_Debug) {
      System.err.println("Selected " + trainData.numInstances()
			 + " out of " + numInstances);
    }
    return trainData;
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(6);

    newVector.addElement(new Option(
	      "\tUse resampling for boosting.",
	      "Q", 0, "-Q"));
    newVector.addElement(new Option(
	      "\tPercentage of weight mass to base training on.\n"
	      +"\t(default 100, reduce to around 90 to speed up)",
	      "P", 1, "-P <percent>"));
    newVector.addElement(new Option(
	      "\tNumber of folds for internal cross-validation.\n"
	      +"\t(default 0 -- no cross-validation)",
	      "F", 1, "-F <num>"));
    newVector.addElement(new Option(
	      "\tNumber of runs for internal cross-validation.\n"
	      +"\t(default 1)",
	      "R", 1, "-R <num>"));
    newVector.addElement(new Option(
	      "\tThreshold on the improvement of the likelihood.\n"
	      +"\t(default -Double.MAX_VALUE)",
	      "L", 1, "-L <num>"));
    newVector.addElement(new Option(
	      "\tShrinkage parameter.\n"
	      +"\t(default 1)",
	      "H", 1, "-H <num>"));

    Enumeration em = super.listOptions();
    while (em.hasMoreElements()) {
      newVector.addElement(em.nextElement());
    }
    return newVector.elements();
  }


  /**
   * Parses a given list of options. Valid options are:<p>
   *
   * -D <br>
   * Turn on debugging output.<p>
   *
   * -W classname <br>
   * Specify the full class name of a weak learner as the basis for 
   * boosting (required).<p>
   *
   * -I num <br>
   * Set the number of boost iterations (default 10). <p>
   *
   * -Q <br>
   * Use resampling instead of reweighting.<p>
   *
   * -S seed <br>
   * Random number seed for resampling (default 1).<p>
   *
   * -P num <br>
   * Set the percentage of weight mass used to build classifiers
   * (default 100). <p>
   *
   * -F num <br>
   * Set number of folds for the internal cross-validation
   * (default 0 -- no cross-validation). <p>
   *
   * -R num <br>
   * Set number of runs for the internal cross-validation
   * (default 1). <p>
   *
   * -L num <br> 
   * Set the threshold for the improvement of the
   * average loglikelihood (default -Double.MAX_VALUE). <p>
   *
   * -H num <br> 
   * Set the value of the shrinkage parameter (default 1). <p>
   *
   * Options after -- are passed to the designated learner.<p>
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    
    String numFolds = Utils.getOption('F', options);
    if (numFolds.length() != 0) {
      setNumFolds(Integer.parseInt(numFolds));
    } else {
      setNumFolds(0);
    }
    
    String numRuns = Utils.getOption('R', options);
    if (numRuns.length() != 0) {
      setNumRuns(Integer.parseInt(numRuns));
    } else {
      setNumRuns(1);
    }

    String thresholdString = Utils.getOption('P', options);
    if (thresholdString.length() != 0) {
      setWeightThreshold(Integer.parseInt(thresholdString));
    } else {
      setWeightThreshold(100);
    }

    String precisionString = Utils.getOption('L', options);
    if (precisionString.length() != 0) {
      setLikelihoodThreshold(Double.parseDouble(precisionString));
    } else {
      setLikelihoodThreshold(-Double.MAX_VALUE);
    }

    String shrinkageString = Utils.getOption('H', options);
    if (shrinkageString.length() != 0) {
      setShrinkage(Double.parseDouble(shrinkageString));
    } else {
      setShrinkage(1.0);
    }

    setUseResampling(Utils.getFlag('Q', options));
    if (m_UseResampling && (thresholdString.length() != 0)) {
      throw new Exception("Weight pruning with resampling"+
