⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 racedincrementallogitboost.java

📁 MacroWeka扩展了著名数据挖掘工具weka
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    RacedIncrementalLogitBoost.java
 *    Copyright (C) 2002 Richard Kirkby, Eibe Frank
 *
 */

package weka.classifiers.meta;

import weka.classifiers.*;
import weka.classifiers.rules.ZeroR;
import weka.core.*;
import java.util.*;
import java.io.Serializable;

/**
 * Classifier for incremental learning of large datasets by way of racing logit-boosted committees. 
 *
 * Valid options are:<p>
 *
 * -C num <br>
 * Set the minimum chunk size (default 500). <p>
 *
 * -M num <br>
 * Set the maximum chunk size (default 2000). <p>
 *
 * -V num <br>
 * Set the validation set size (default 1000). <p>
 *
 * -D <br>
 * Turn on debugging output.<p>
 *
 * -W classname <br>
 * Specify the full class name of a weak learner as the basis for 
 * boosting (required).<p>
 *
 * -Q <br>
 * Use resampling instead of reweighting.<p>
 *
 * -S seed <br>
 * Random number seed for resampling (default 1).<p>
 *
 * -P type <br>
 * The type of pruning to use. <p>
 *
 * Options after -- are passed to the designated learner.<p>
 *
 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 1.1 $ 
 */
public class RacedIncrementalLogitBoost extends RandomizableSingleClassifierEnhancer
  implements UpdateableClassifier {

  /** Pruning type constant: tagged "No pruning" in TAGS_PRUNETYPE */
  public static final int PRUNETYPE_NONE = 0;
  /** Pruning type constant: tagged "Log likelihood pruning" in TAGS_PRUNETYPE */
  public static final int PRUNETYPE_LOGLIKELIHOOD = 1;
  /** Tags pairing each pruning type constant with its readable description */
  public static final Tag [] TAGS_PRUNETYPE = {
    new Tag(PRUNETYPE_NONE, "No pruning"),
    new Tag(PRUNETYPE_LOGLIKELIHOOD, "Log likelihood pruning")
  };

  /** The committees (presumably Committee objects -- confirm against the code that fills it) */
  protected FastVector m_committees;

  /** The pruning type used (defaults to log likelihood pruning) */
  protected int m_PruningType = PRUNETYPE_LOGLIKELIHOOD;

  /** Whether to use resampling instead of reweighting (off by default) */
  protected boolean m_UseResampling = false;

  /** The number of classes */
  protected int m_NumClasses;

  /** A threshold for responses (Friedman suggests between 2 and 4) */
  protected static final double Z_MAX = 4;

  /** Dummy dataset with a numeric class */
  protected Instances m_NumericClassData;

  /** The actual class attribute (for getting class names) */
  protected Attribute m_ClassAttribute;

  /** The minimum chunk size used for training (default 500) */
  protected int m_minChunkSize = 500;

  /** The maximum chunk size used for training (default 2000) */
  protected int m_maxChunkSize = 2000;

  /** The size of the validation set (default 1000) */
  protected int m_validationChunkSize = 1000;

  /** The number of instances consumed */
  protected int m_numInstancesConsumed;

  /** The instances used for validation */
  protected Instances m_validationSet;

  /** The instances currently in memory for training */
  protected Instances m_currentSet;

  /** The current best committee */
  protected Committee m_bestCommittee;

  /** The default scheme used when committees aren't ready */
  protected ZeroR m_zeroR = null;

  /** Whether the validation set has recently been changed */
  protected boolean m_validationSetChanged;

  /** The maximum number of instances required for processing */
  protected int m_maxBatchSizeRequired;

  /** The random number generator used (null until one is created) */
  protected Random m_RandomInstance = null;

    
  /**
   * Creates the classifier with a decision stump installed as the
   * base learner.
   */
  public RacedIncrementalLogitBoost() {

    // Same scheme that defaultClassifierString() names.
    this.m_Classifier = new weka.classifiers.trees.DecisionStump();
  }

  /**
   * Returns the fully qualified class name of the default base classifier.
   *
   * @return the class name of the decision stump learner
   */
  protected String defaultClassifierString() {

    final String stumpClassName = "weka.classifiers.trees.DecisionStump";
    return stumpClassName;
  }


  /**
   * A committee of LogitBoosted models. Every model is an array holding one
   * base classifier per class and is trained on a chunk of m_chunkSize
   * instances taken from m_currentSet. The committee also tracks the F
   * values of the validation instances so that the validation error and the
   * log likelihood can be computed without re-running all models.
   */
  protected class Committee implements Serializable {

    /** Number of training instances used for each model */
    protected int m_chunkSize;

    /** Number of instances of m_currentSet already used for training */
    protected int m_instancesConsumed;

    /** The models; each element is a Classifier[] with one entry per class */
    protected FastVector m_models;

    /** Cached result of validationError() */
    protected double m_lastValidationError;

    /** Cached result of logLikelihood() */
    protected double m_lastLogLikelihood;

    /** True when the cached validation error has to be recomputed */
    protected boolean m_modelHasChanged;

    /** True when the cached log likelihood has to be recomputed */
    protected boolean m_modelHasChangedLL;

    /** F values of the validation instances for the models kept so far */
    protected double[][] m_validationFs;

    /** F values of the validation instances after adding the latest model */
    protected double[][] m_newValidationFs;

    /**
     * Creates an empty committee.
     *
     * @param chunkSize the number of instances each model is trained on
     */
    public Committee(int chunkSize) {

      m_models = new FastVector();
      m_chunkSize = chunkSize;
      m_instancesConsumed = 0;

      m_validationFs = new double[m_validationChunkSize][m_NumClasses];
      m_newValidationFs = new double[m_validationChunkSize][m_NumClasses];

      // Start from the worst possible scores and mark both caches as stale.
      m_lastLogLikelihood = Double.MAX_VALUE;
      m_lastValidationError = 1.0;
      m_modelHasChangedLL = true;
      m_modelHasChanged = true;
    }

    /**
     * Trains and appends a new model for every complete chunk of
     * m_currentSet that has not been consumed yet, storing the resulting
     * validation F values in m_newValidationFs.
     *
     * NOTE(review): every new model is combined with m_validationFs, so if
     * one call consumes several chunks, m_newValidationFs only reflects the
     * last model added on top of the kept ones -- confirm callers never
     * feed more than one chunk per call.
     *
     * @return true if at least one model was added
     * @exception Exception if training or prediction fails
     */
    public boolean update() throws Exception {

      boolean grew = false;
      while (m_currentSet.numInstances() - m_instancesConsumed >= m_chunkSize) {
        Instances chunk = new Instances(m_currentSet, m_instancesConsumed, m_chunkSize);
        Classifier[] members = boost(chunk);
        for (int v = 0; v < m_validationSet.numInstances(); v++) {
          Instance example = m_validationSet.instance(v);
          m_newValidationFs[v] = updateFS(example, members, m_validationFs[v]);
        }
        m_models.addElement(members);
        m_instancesConsumed += m_chunkSize;
        grew = true;
      }

      // Invalidate the cached scores only when something was added.
      m_modelHasChanged |= grew;
      m_modelHasChangedLL |= grew;
      return grew;
    }

    /**
     * Resets the count of consumed instances to zero.
     */
    public void resetConsumed() {

      m_instancesConsumed = 0;
    }

    /**
     * Removes the most recently added model, if there is one.
     */
    public void pruneLastModel() {

      int count = m_models.size();
      if (count <= 0) {
        return;
      }
      m_models.removeElementAt(count - 1);
      m_modelHasChangedLL = true;
      m_modelHasChanged = true;
    }

    /**
     * Accepts the most recently added model: the F values computed with it
     * become the current validation F values.
     *
     * @exception Exception declared for callers; not thrown by this code
     */
    public void keepLastModel() throws Exception {

      m_validationFs = m_newValidationFs;
      m_newValidationFs = new double[m_validationChunkSize][m_NumClasses];
      m_modelHasChangedLL = true;
      m_modelHasChanged = true;
    }

    /**
     * Returns the mean negative log likelihood on the validation data for
     * the current validation F values. The value is cached until the
     * committee changes.
     *
     * @return the mean negative log likelihood
     * @exception Exception if the distribution cannot be computed
     */
    public double logLikelihood() throws Exception {

      if (!m_modelHasChangedLL) {
        return m_lastLogLikelihood;
      }
      m_lastLogLikelihood = meanNegativeLogLikelihood(m_validationFs);
      m_modelHasChangedLL = false;
      return m_lastLogLikelihood;
    }

    /**
     * Returns the mean negative log likelihood on the validation data for
     * the F values that include the last added model. Never cached.
     *
     * @return the mean negative log likelihood
     * @exception Exception if the distribution cannot be computed
     */
    public double logLikelihoodAfter() throws Exception {

      return meanNegativeLogLikelihood(m_newValidationFs);
    }

    /**
     * Averages the negative log likelihood of the true class over all
     * validation instances.
     *
     * @param scores one row of F values per validation instance
     * @return the sum of the per-instance values divided by the number of
     * validation instances
     * @exception Exception if the distribution cannot be computed
     */
    private double meanNegativeLogLikelihood(double[][] scores) throws Exception {

      double total = 0.0;
      for (int v = 0; v < m_validationSet.numInstances(); v++) {
        Instance example = m_validationSet.instance(v);
        total += logLikelihood(scores[v], (int) example.classValue());
      }
      return total / (double) m_validationSet.numInstances();
    }

    /**
     * Calculates the negative log of the probability assigned to one class.
     *
     * @param Fs the F values of the instance
     * @param classIndex the index of the class of interest
     * @return minus the natural log of that class's probability
     * @exception Exception if the distribution cannot be computed
     */
    private double logLikelihood(double[] Fs, int classIndex) throws Exception {

      double[] probs = distributionForInstance(Fs);
      return -Math.log(probs[classIndex]);
    }

    /**
     * Returns the fraction of validation instances whose predicted class
     * differs from the actual class. The value is cached until the
     * committee changes.
     *
     * @return the validation error
     * @exception Exception if an instance cannot be classified
     */
    public double validationError() throws Exception {

      if (!m_modelHasChanged) {
        return m_lastValidationError;
      }
      int mistakes = 0;
      for (int v = 0; v < m_validationSet.numInstances(); v++) {
        Instance example = m_validationSet.instance(v);
        if (classifyInstance(m_validationFs[v]) != example.classValue()) {
          mistakes++;
        }
      }
      m_lastValidationError = (double) mistakes / (double) m_validationSet.numInstances();
      m_modelHasChanged = false;
      return m_lastValidationError;
    }

    /**
     * Returns the chunk size used by the committee.
     *
     * @return the chunk size
     */
    public int chunkSize() {

      return m_chunkSize;
    }

    /**
     * Returns the number of models in the committee.
     *
     * @return the number of models
     */
    public int committeeSize() {

      return m_models.size();
    }

    /**
     * Classifies an instance from its F values.
     *
     * @param Fs the F values of the instance
     * @return the index of the most probable class, or the missing value
     * if no class has a probability above zero
     * @exception Exception if the distribution cannot be computed
     */
    public double classifyInstance(double[] Fs) throws Exception {

      return mostProbableClass(distributionForInstance(Fs));
    }

    /**
     * Classifies an instance with the committee.
     *
     * @param instance the instance to classify
     * @return for a nominal class the index of the most probable class (or
     * the missing value if no class has a probability above zero), for a
     * numeric class the first distribution entry, otherwise the missing
     * value
     * @exception Exception if the distribution cannot be computed
     */
    public double classifyInstance(Instance instance) throws Exception {

      double[] probs = distributionForInstance(instance);
      int classType = instance.classAttribute().type();
      if (classType == Attribute.NOMINAL) {
        return mostProbableClass(probs);
      }
      if (classType == Attribute.NUMERIC) {
        return probs[0];
      }
      return Instance.missingValue();
    }

    /**
     * Finds the first entry holding the largest value above zero.
     *
     * @param probs the class distribution
     * @return the index of that entry, or the missing value if no entry
     * is greater than zero
     */
    private double mostProbableClass(double[] probs) {

      int best = 0;
      double bestProb = 0;
      for (int c = 0; c < probs.length; c++) {
        if (probs[c] > bestProb) {
          best = c;
          bestProb = probs[c];
        }
      }
      return (bestProb > 0) ? (double) best : Instance.missingValue();
    }

    /**
     * Returns the distribution the committee generates for an instance,
     * given its F values.
     *
     * @param Fs the F values of the instance
     * @return one RtoP value per class
     * @exception Exception if the conversion fails
     */
    public double[] distributionForInstance(double[] Fs) throws Exception {

      return toDistribution(Fs);
    }

    /**
     * Converts F values into a distribution by calling RtoP for each class.
     *
     * @param Fs the F values of the instance
     * @return one RtoP value per class
     * @exception Exception if the conversion fails
     */
    private double[] toDistribution(double[] Fs) throws Exception {

      double[] probs = new double[m_NumClasses];
      for (int c = 0; c < m_NumClasses; c++) {
        probs[c] = RtoP(Fs, c);
      }
      return probs;
    }

    /**
     * Computes the amount one model adds to the F values of an instance:
     * the per-class predictions, centred on their mean and scaled by
     * (numClasses - 1) / numClasses.
     *
     * @param model one classifier per class
     * @param inst the instance to predict
     * @return the per-class increments
     * @exception Exception if a prediction fails
     */
    private double[] centredResponses(Classifier[] model, Instance inst) throws Exception {

      double[] increments = new double[m_NumClasses];
      double mean = 0;
      for (int c = 0; c < m_NumClasses; c++) {
        increments[c] = model[c].classifyInstance(inst);
        mean += increments[c];
      }
      mean /= m_NumClasses;
      for (int c = 0; c < m_NumClasses; c++) {
        increments[c] = (increments[c] - mean) * (m_NumClasses - 1) / m_NumClasses;
      }
      return increments;
    }

    /**
     * Computes the F values of an instance after adding a new model,
     * leaving the given F values untouched.
     *
     * @param instance the instance the F values belong to
     * @param newModel the model to add, one classifier per class
     * @param Fs the F values before adding the model
     * @return a new array holding the updated F values
     * @exception Exception if a prediction fails
     */
    public double[] updateFS(Instance instance, Classifier[] newModel, double[] Fs) throws Exception {

      // Predict on a copy attached to the numeric-class dataset.
      Instance numericCopy = (Instance) instance.copy();
      numericCopy.setDataset(m_NumericClassData);
      double[] increments = centredResponses(newModel, numericCopy);

      double[] updated = new double[Fs.length];
      for (int c = 0; c < m_NumClasses; c++) {
        updated[c] = Fs[c] + increments[c];
      }
      return updated;
    }

    /**
     * Returns the distribution the committee generates for an instance by
     * accumulating the F values over all models.
     *
     * @param instance the instance to get the distribution for
     * @return one RtoP value per class
     * @exception Exception if a prediction fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {

      // Predict on a copy attached to the numeric-class dataset.
      Instance numericCopy = (Instance) instance.copy();
      numericCopy.setDataset(m_NumericClassData);

      double[] scores = new double[m_NumClasses];
      for (int m = 0; m < m_models.size(); m++) {
        Classifier[] model = (Classifier[]) m_models.elementAt(m);
        double[] increments = centredResponses(model, numericCopy);
        for (int c = 0; c < m_NumClasses; c++) {
          scores[c] += increments[c];
        }
      }
      return toDistribution(scores);
    }

    /**
     * Computes the response value used as pseudo class for one instance
     * and class, limited to the range allowed by Z_MAX for 0/1 targets.
     *
     * @param actual the 0/1 indicator of the instance belonging to the class
     * @param p the current probability estimate for the class
     * @return the response value
     */
    private double workingResponse(double actual, double p) {

      if (actual == 1) {
        return Math.min(1.0 / p, Z_MAX); // threshold
      }
      if (actual == 0) {
        return Math.max(-1.0 / (1.0 - p), -Z_MAX); // threshold
      }
      return (actual - p) / (p * (1 - p));
    }

    /**
     * Performs a boosting iteration on the given data and returns the new
     * model. The model is not added to the committee here.
     *
     * @param data the training chunk
     * @return the new model, one classifier per class
     * @exception Exception if prediction or training fails
     */
    protected Classifier[] boost(Instances data) throws Exception {

      Classifier[] members = Classifier.makeCopies(m_Classifier, m_NumClasses);

      // Work on a copy of the data without the instances lacking a class.
      Instances working = new Instances(data);
      working.deleteWithMissingClass();
      int numTrain = working.numInstances();

      // Swap the class attribute for a 'pseudo class' attribute at the
      // same position; the class index must be unset while doing so.
      int classPos = data.classIndex();
      working.setClassIndex(-1);
      working.deleteAttributeAt(classPos);
      working.insertAttributeAt(new Attribute("'pseudo class'"), classPos);
      working.setClassIndex(classPos);

      double[][] scores = new double[numTrain][m_NumClasses];
      double[][] targets = new double[numTrain][m_NumClasses];

      // 0/1 class indicators, skipping the instances of the original data
      // that were dropped from the working copy.
      for (int i = 0, src = 0; i < numTrain; i++, src++) {
        while (data.instance(src).classIsMissing()) {
          src++;
        }
        double actualClass = data.instance(src).classValue();
        for (int c = 0; c < m_NumClasses; c++) {
          targets[i][c] = (actualClass == c) ? 1 : 0;
        }
      }

      // Accumulate the F values of the training instances over the
      // models already in the committee.
      for (int m = 0; m < m_models.size(); m++) {
        Classifier[] model = (Classifier[]) m_models.elementAt(m);
        for (int i = 0; i < numTrain; i++) {
          double[] increments = centredResponses(model, working.instance(i));
          for (int c = 0; c < m_NumClasses; c++) {
            scores[i][c] += increments[c];
          }
        }
      }

      for (int c = 0; c < m_NumClasses; c++) {

        // Set the pseudo class and the weight of every instance.
        for (int i = 0; i < numTrain; i++) {
          double prob = RtoP(scores[i], c);
          Instance example = working.instance(i);
          double target = targets[i][c];
          double response = workingResponse(target, prob);
          double weight = (target - prob) / response;
          example.setValue(classPos, response);
          example.setWeight(numTrain * weight);
        }

        // Either train on the weighted data directly or on a sample
        // drawn according to the weights.
        Instances trainingView;
        if (m_UseResampling) {
          double[] weights = new double[working.numInstances()];
          for (int i = 0; i < weights.length; i++) {
            weights[i] = working.instance(i).weight();
          }
          trainingView = working.resampleWithWeights(m_RandomInstance, weights);
        } else {
          trainingView = working;
        }

        members[c].buildClassifier(trainingView);
      }

      return members;
    }

    /**
     * Returns a description of the committee: every base classifier of
     * every model, the number of models and the chunk size.
     *
     * @return the description
     */
    public String toString() {

      StringBuffer out = new StringBuffer();
      int numModels = m_models.size();

      out.append("RacedIncrementalLogitBoost: Best committee on validation data\n")
        .append("Base classifiers: \n");

      for (int m = 0; m < numModels; m++) {
        out.append("\nModel ").append(m + 1);
        Classifier[] members = (Classifier[]) m_models.elementAt(m);
        for (int c = 0; c < m_NumClasses; c++) {
          out.append("\n\tClass ").append(c + 1)
            .append(" (").append(m_ClassAttribute.name())
            .append("=").append(m_ClassAttribute.value(c)).append(")\n\n")
            .append(members[c].toString()).append("\n");
        }
      }
      out.append("Number of models: ").append(m_models.size()).append("\n");
      out.append("Chunk size per model: ").append(m_chunkSize).append("\n");

      return out.toString();
    }
  }

 /**
   * Builds the classifier.
   *
   * @param data the instances to train the classifier with
   * @exception Exception if something goes wrong
   */
  public void buildClassifier(Instances data) throws Exception {

    m_RandomInstance = new Random(m_Seed);

    Instances boostData;
    int classIndex = data.classIndex();

    if (data.classAttribute().isNumeric()) {
      throw new Exception("RacedIncrementalLogitBoost can't handle a numeric class!");
    }
    if (m_Classifier == null) {
      throw new Exception("A base classifier has not been specified!");
    }

    if (!(m_Classifier instanceof WeightedInstancesHandler) &&
	!m_UseResampling) {

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -