📄 RacedIncrementalLogitBoost.java

📁 Java source code for the complete workflow of the data mining software ALPHAMINER
💻 Java
📖 Page 1 of 3
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    RacedIncrementalLogitBoost.java
 *    Copyright (C) 2002 Richard Kirkby, Eibe Frank
 *
 */

package weka.classifiers.meta;

import java.io.Serializable;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.rules.ZeroR;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Classifier for incremental learning of large datasets by way of racing logit-boosted committees. 
 *
 * Valid options are:<p>
 *
 * -C num <br>
 * Set the minimum chunk size (default 500). <p>
 *
 * -M num <br>
 * Set the maximum chunk size (default 8000). <p>
 *
 * -V num <br>
 * Set the validation set size (default 5000). <p>
 *
 * -D <br>
 * Turn on debugging output.<p>
 *
 * -W classname <br>
 * Specify the full class name of a weak learner as the basis for 
 * boosting (required).<p>
 *
 * -Q <br>
 * Use resampling instead of reweighting.<p>
 *
 * -S seed <br>
 * Random number seed for resampling (default 1).<p>
 *
 * -P type <br>
 * The type of pruning to use: 0 = no pruning, 1 = log-likelihood pruning
 * (default 1). <p>
 *
 * Options after -- are passed to the designated learner.<p>
 *
 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision$ 
 */
public class RacedIncrementalLogitBoost extends Classifier
  implements OptionHandler, UpdateableClassifier {
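
  /*
   * Usage sketch (added for illustration; not part of the original source). The class
   * is driven through Weka's standard incremental-learning API: buildClassifier() sets
   * up the committees from an initial dataset, and updateClassifier() feeds instances
   * one at a time. Option values and variable names below are hypothetical examples.
   *
   *   RacedIncrementalLogitBoost rilb = new RacedIncrementalLogitBoost();
   *   rilb.setOptions(weka.core.Utils.splitOptions(
   *       "-C 500 -M 8000 -V 5000 -W weka.classifiers.trees.DecisionStump"));
   *   rilb.buildClassifier(initialData);          // dataset defining attributes and class
   *   while (stream.hasNext()) {                  // hypothetical instance stream
   *     rilb.updateClassifier(stream.next());     // instances are chunked internally
   *   }
   *   double[] dist = rilb.distributionForInstance(testInstance);
   */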

  /** The pruning types */
  public static final int PRUNETYPE_NONE = 0;
  public static final int PRUNETYPE_LOGLIKELIHOOD = 1;
  public static final Tag [] TAGS_PRUNETYPE = {
    new Tag(PRUNETYPE_NONE, "No pruning"),
    new Tag(PRUNETYPE_LOGLIKELIHOOD, "Log likelihood pruning")
  };

  /** The model base classifier to use */
  protected Classifier m_Classifier = new weka.classifiers.trees.DecisionStump();

  /** The committees */   
  protected FastVector m_committees;

  /** The pruning type used */
  protected int m_PruningType = PRUNETYPE_LOGLIKELIHOOD;

  /** Whether to use resampling */
  protected boolean m_UseResampling = false;

  /** Seed for boosting with resampling. */
  protected int m_Seed = 1;

  /** The number of classes */
  protected int m_NumClasses;

  /** A threshold for responses (Friedman suggests between 2 and 4) */
  protected static final double Z_MAX = 4;

  /** Dummy dataset with a numeric class */
  protected Instances m_NumericClassData;

  /** The actual class attribute (for getting class names) */
  protected Attribute m_ClassAttribute;  

  /** The minimum chunk size used for training */
  protected int m_minChunkSize = 500;

  /** The maximum chunk size used for training */
  protected int m_maxChunkSize = 8000;

  /** The size of the validation set */
  protected int m_validationChunkSize = 5000;

  /** The number of instances consumed */  
  protected int m_numInstancesConsumed;

  /** The instances used for validation */    
  protected Instances m_validationSet;

  /** The instances currently in memory for training */   
  protected Instances m_currentSet;

  /** The current best committee */   
  protected Committee m_bestCommittee;

  /** The default scheme used when committees aren't ready */    
  protected ZeroR m_zeroR = null;

  /** Whether the validation set has recently been changed */ 
  protected boolean m_validationSetChanged;

  /** The maximum number of instances required for processing */   
  protected int m_maxBatchSizeRequired;

  /** Whether to output debug messages */     
  protected boolean m_Debug = false;

  /** The random number generator used */
  protected Random m_RandomInstance = null;
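
  /*
   * Note (added): each Committee below is a LogitBoost ensemble whose members are each
   * trained on a chunk of m_chunkSize instances. Committees built with different chunk
   * sizes are raced against one another on m_validationSet; m_bestCommittee holds the
   * current winner and m_zeroR serves as a fallback until the committees are ready.
   * The selection logic itself appears later in the file.
   */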


  /* Class representing a committee of LogitBoosted models */
  protected class Committee implements Serializable {

    protected int m_chunkSize;
    protected int m_instancesConsumed; // number eaten from m_currentSet
    protected FastVector m_models;
    protected double m_lastValidationError;
    protected double m_lastLogLikelihood;
    protected boolean m_modelHasChanged;
    protected boolean m_modelHasChangedLL;
    protected double[][] m_validationFs;
    protected double[][] m_newValidationFs;

    /* constructor */
    public Committee(int chunkSize) {

      m_chunkSize = chunkSize;
      m_instancesConsumed = 0;
      m_models = new FastVector();
      m_lastValidationError = 1.0;
      m_lastLogLikelihood = Double.MAX_VALUE;
      m_modelHasChanged = true;
      m_modelHasChangedLL = true;
      m_validationFs = new double[m_validationChunkSize][m_NumClasses];
      m_newValidationFs = new double[m_validationChunkSize][m_NumClasses];
    } 

    /* update the committee */
    public boolean update() throws Exception {

      boolean hasChanged = false;
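      // One boosting step per full, not-yet-consumed chunk of m_currentSet: train a new
      // committee member on that chunk, then compute candidate F scores for the
      // validation set (accepted or discarded later via keepLastModel()/pruneLastModel()).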
      while (m_currentSet.numInstances() - m_instancesConsumed >= m_chunkSize) {
	Classifier[] newModel = boost(new Instances(m_currentSet, m_instancesConsumed, m_chunkSize));
	for (int i=0; i<m_validationSet.numInstances(); i++) {
	  m_newValidationFs[i] = updateFS(m_validationSet.instance(i), newModel, m_validationFs[i]);
	}
	m_models.addElement(newModel);
	m_instancesConsumed += m_chunkSize;
	hasChanged = true;
      }
      if (hasChanged) {
	m_modelHasChanged = true;
	m_modelHasChangedLL = true;
      }
      return hasChanged;
    }

    /* reset the consumption count */
    public void resetConsumed() {

      m_instancesConsumed = 0;
    }

    /* remove the last model from the committee */
    public void pruneLastModel() {

      if (m_models.size() > 0) {
	m_models.removeElementAt(m_models.size()-1);
	m_modelHasChanged = true;
	m_modelHasChangedLL = true;
      }
    }

    /* decide to keep the last model in the committee */    
    public void keepLastModel() throws Exception {

      m_validationFs = m_newValidationFs;
      m_newValidationFs = new double[m_validationChunkSize][m_NumClasses];
      m_modelHasChanged = true;
      m_modelHasChangedLL = true;
    }

    /* calculate the log likelihood on the validation data */        
    public double logLikelihood() throws Exception {

      if (m_modelHasChangedLL) {

	Instance inst;
	double llsum = 0.0;
	for (int i=0; i<m_validationSet.numInstances(); i++) {
	  inst = m_validationSet.instance(i);
	  llsum += (logLikelihood(m_validationFs[i],(int) inst.classValue()));
	}
	m_lastLogLikelihood = llsum / (double) m_validationSet.numInstances();
	m_modelHasChangedLL = false;
      }
      return m_lastLogLikelihood;
    }

    /* calculate the log likelihood on the validation data after adding the last model */    
    public double logLikelihoodAfter() throws Exception {

	Instance inst;
	double llsum = 0.0;
	for (int i=0; i<m_validationSet.numInstances(); i++) {
	  inst = m_validationSet.instance(i);
	  llsum += (logLikelihood(m_newValidationFs[i],(int) inst.classValue()));
	}
	return llsum / (double) m_validationSet.numInstances();
    }

    
    /* calculates the negative log-likelihood of an instance, given its Fs values */
    private double logLikelihood(double[] Fs, int classIndex) throws Exception {

      return -Math.log(distributionForInstance(Fs)[classIndex]);
    }

    /* calculates the validation error of the committee */
    public double validationError() throws Exception {

      if (m_modelHasChanged) {

	Instance inst;
	int numIncorrect = 0;
	for (int i=0; i<m_validationSet.numInstances(); i++) {
	  inst = m_validationSet.instance(i);
	  if (classifyInstance(m_validationFs[i]) != inst.classValue())
	    numIncorrect++;
	}
	m_lastValidationError = (double) numIncorrect / (double) m_validationSet.numInstances();
	m_modelHasChanged = false;
      }
      return m_lastValidationError;
    }

    /* returns the chunk size used by the committee */
    public int chunkSize() {

      return m_chunkSize;
    }

    /* returns the number of models in the committee */
    public int committeeSize() {

      return m_models.size();
    }

    
    /* classifies an instance (given Fs values) with the committee */
    public double classifyInstance(double[] Fs) throws Exception {
      
      double [] dist = distributionForInstance(Fs);

      double max = 0;
      int maxIndex = 0;
      
      for (int i = 0; i < dist.length; i++) {
	if (dist[i] > max) {
	  maxIndex = i;
	  max = dist[i];
	}
      }
      if (max > 0) {
	return maxIndex;
      } else {
	return Instance.missingValue();
      }
    }

    /* classifies an instance with the committee */    
    public double classifyInstance(Instance instance) throws Exception {
      
      double [] dist = distributionForInstance(instance);
      switch (instance.classAttribute().type()) {
      case Attribute.NOMINAL:
	double max = 0;
	int maxIndex = 0;
	
	for (int i = 0; i < dist.length; i++) {
	  if (dist[i] > max) {
	    maxIndex = i;
	    max = dist[i];
	  }
	}
	if (max > 0) {
	  return maxIndex;
	} else {
	  return Instance.missingValue();
	}
      case Attribute.NUMERIC:
	return dist[0];
      default:
	return Instance.missingValue();
      }
    }

    /* returns the distribution the committee generates for an instance (given Fs values) */
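    // Note (added): RtoP(), defined later in this class, is assumed to implement the
    // standard multiclass LogitBoost response-to-probability transform,
    //   p_j = exp(F_j) / sum_k exp(F_k),
    // so the returned values form a proper probability distribution.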
    public double[] distributionForInstance(double[] Fs) throws Exception {
      
      double [] distribution = new double [m_NumClasses];
      for (int j = 0; j < m_NumClasses; j++) {
	distribution[j] = RtoP(Fs, j);
      }
      return distribution;
    }
    
    /* updates the Fs values given a new model in the committee */
    public double[] updateFS(Instance instance, Classifier[] newModel, double[] Fs) throws Exception {
      
      instance = (Instance)instance.copy();
      instance.setDataset(m_NumericClassData);
      
      double [] Fi = new double [m_NumClasses];
      double Fsum = 0;
      for (int j = 0; j < m_NumClasses; j++) {
	Fi[j] = newModel[j].classifyInstance(instance);
	Fsum += Fi[j];
      }
      Fsum /= m_NumClasses;
      
      double[] newFs = new double[Fs.length];
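      // Standard multiclass LogitBoost update: centre each base model's response on the
      // class mean and scale by (K-1)/K, where K = m_NumClasses, before adding it to the
      // running F scores.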
      for (int j = 0; j < m_NumClasses; j++) {
	newFs[j] = Fs[j] + ((Fi[j] - Fsum) * (m_NumClasses - 1) / m_NumClasses);
      }
      return newFs;
    }

    /* returns the distribution the committee generates for an instance */
    public double[] distributionForInstance(Instance instance) throws Exception {

      instance = (Instance)instance.copy();
      instance.setDataset(m_NumericClassData);
      double [] Fs = new double [m_NumClasses]; 
      for (int i = 0; i < m_models.size(); i++) {
	double [] Fi = new double [m_NumClasses];
	double Fsum = 0;
	Classifier[] model = (Classifier[]) m_models.elementAt(i);
	for (int j = 0; j < m_NumClasses; j++) {
	  Fi[j] = model[j].classifyInstance(instance);
	  Fsum += Fi[j];
	}
	Fsum /= m_NumClasses;
	for (int j = 0; j < m_NumClasses; j++) {
	  Fs[j] += (Fi[j] - Fsum) * (m_NumClasses - 1) / m_NumClasses;
	}
      }
      double [] distribution = new double [m_NumClasses];
      for (int j = 0; j < m_NumClasses; j++) {
	distribution[j] = RtoP(Fs, j);
      }
      return distribution;
    }

    /* performs a boosting iteration, returning a new model for the committee */
