📄 racedincrementallogitboost.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * RacedIncrementalLogitBoost.java * Copyright (C) 2002 Richard Kirkby, Eibe Frank * */package weka.classifiers.meta;import weka.classifiers.*;import weka.classifiers.rules.ZeroR;import weka.core.*;import java.util.*;import java.io.Serializable;/** * Classifier for incremental learning of large datasets by way of racing logit-boosted committees. * * Valid options are:<p> * * -C num <br> * Set the minimum chunk size (default 500). <p> * * -M num <br> * Set the maximum chunk size (default 8000). <p> * * -V num <br> * Set the validation set size (default 5000). <p> * * -D <br> * Turn on debugging output.<p> * * -W classname <br> * Specify the full class name of a weak learner as the basis for * boosting (required).<p> * * -Q <br> * Use resampling instead of reweighting.<p> * * -S seed <br> * Random number seed for resampling (default 1).<p> * * -P type <br> * The type of pruning to use. <p> * * Options after -- are passed to the designated learner.<p> * * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) * @author Eibe Frank (eibe@cs.waikato.ac.nz) * @version $Revision: 1.1.1.1 $ */public class RacedIncrementalLogitBoost extends DistributionClassifier implements OptionHandler, UpdateableClassifier { /** The pruning types */ public static final int PRUNETYPE_NONE = 0; public static final int PRUNETYPE_LOGLIKELIHOOD = 1; public static final Tag [] TAGS_PRUNETYPE = { new Tag(PRUNETYPE_NONE, "No pruning"), new Tag(PRUNETYPE_LOGLIKELIHOOD, "Log likelihood pruning") }; /** The model base classifier to use */ protected Classifier m_Classifier = new weka.classifiers.trees.DecisionStump(); /** The committees */ protected FastVector m_committees; /** The pruning type used */ protected int m_PruningType = PRUNETYPE_LOGLIKELIHOOD; /** Whether to use resampling */ protected boolean m_UseResampling = false; /** Seed for boosting with resampling. */ protected int m_Seed = 1; /** The number of classes */ protected int m_NumClasses; /** A threshold for responses (Friedman suggests between 2 and 4) */ protected static final double Z_MAX = 4; /** Dummy dataset with a numeric class */ protected Instances m_NumericClassData; /** The actual class attribute (for getting class names) */ protected Attribute m_ClassAttribute; /** The minimum chunk size used for training */ protected int m_minChunkSize = 500; /** The maimum chunk size used for training */ protected int m_maxChunkSize = 8000; /** The size of the validation set */ protected int m_validationChunkSize = 5000; /** The number of instances consumed */ protected int m_numInstancesConsumed; /** The instances used for validation */ protected Instances m_validationSet; /** The instances currently in memory for training */ protected Instances m_currentSet; /** The current best committee */ protected Committee m_bestCommittee; /** The default scheme used when committees aren't ready */ protected ZeroR m_zeroR = null; /** Whether the validation set has recently been changed */ protected boolean m_validationSetChanged; /** The maximum number of instances required for processing */ protected int m_maxBatchSizeRequired; /** Whether to output debug messages */ protected boolean m_Debug = false; /** The random number generator used */ protected Random m_RandomInstance = null; /* Class representing a committee of LogitBoosted models */ protected class Committee implements Serializable { protected int m_chunkSize; protected int m_instancesConsumed; // number eaten from m_currentSet protected FastVector m_models; protected double m_lastValidationError; protected double m_lastLogLikelihood; protected boolean m_modelHasChanged; protected boolean m_modelHasChangedLL; protected double[][] m_validationFs; protected double[][] m_newValidationFs; /* constructor */ public Committee(int chunkSize) { m_chunkSize = chunkSize; m_instancesConsumed = 0; m_models = new FastVector(); m_lastValidationError = 1.0; m_lastLogLikelihood = Double.MAX_VALUE; m_modelHasChanged = true; m_modelHasChangedLL = true; m_validationFs = new double[m_validationChunkSize][m_NumClasses]; m_newValidationFs = new double[m_validationChunkSize][m_NumClasses]; } /* update the committee */ public boolean update() throws Exception { boolean hasChanged = false; while (m_currentSet.numInstances() - m_instancesConsumed >= m_chunkSize) { Classifier[] newModel = boost(new Instances(m_currentSet, m_instancesConsumed, m_chunkSize)); for (int i=0; i<m_validationSet.numInstances(); i++) { m_newValidationFs[i] = updateFS(m_validationSet.instance(i), newModel, m_validationFs[i]); } m_models.addElement(newModel); m_instancesConsumed += m_chunkSize; hasChanged = true; } if (hasChanged) { m_modelHasChanged = true; m_modelHasChangedLL = true; } return hasChanged; } /* reset consumation counts */ public void resetConsumed() { m_instancesConsumed = 0; } /* remove the last model from the committee */ public void pruneLastModel() { if (m_models.size() > 0) { m_models.removeElementAt(m_models.size()-1); m_modelHasChanged = true; m_modelHasChangedLL = true; } } /* decide to keep the last model in the committee */ public void keepLastModel() throws Exception { m_validationFs = m_newValidationFs; m_newValidationFs = new double[m_validationChunkSize][m_NumClasses]; m_modelHasChanged = true; m_modelHasChangedLL = true; } /* calculate the log likelihood on the validation data */ public double logLikelihood() throws Exception { if (m_modelHasChangedLL) { Instance inst; double llsum = 0.0; for (int i=0; i<m_validationSet.numInstances(); i++) { inst = m_validationSet.instance(i); llsum += (logLikelihood(m_validationFs[i],(int) inst.classValue())); } m_lastLogLikelihood = llsum / (double) m_validationSet.numInstances(); m_modelHasChangedLL = false; } return m_lastLogLikelihood; } /* calculate the log likelihood on the validation data after adding the last model */ public double logLikelihoodAfter() throws Exception { Instance inst; double llsum = 0.0; for (int i=0; i<m_validationSet.numInstances(); i++) { inst = m_validationSet.instance(i); llsum += (logLikelihood(m_newValidationFs[i],(int) inst.classValue())); } return llsum / (double) m_validationSet.numInstances(); } /* calculates the log likelihood of an instance */ private double logLikelihood(double[] Fs, int classIndex) throws Exception { return -Math.log(distributionForInstance(Fs)[classIndex]); } /* calculates the validation error of the committee */ public double validationError() throws Exception { if (m_modelHasChanged) { Instance inst; int numIncorrect = 0; for (int i=0; i<m_validationSet.numInstances(); i++) { inst = m_validationSet.instance(i); if (classifyInstance(m_validationFs[i]) != inst.classValue()) numIncorrect++; } m_lastValidationError = (double) numIncorrect / (double) m_validationSet.numInstances(); m_modelHasChanged = false; } return m_lastValidationError; } /* returns the chunk size used by the committee */ public int chunkSize() { return m_chunkSize; } /* returns the number of models in the committee */ public int committeeSize() { return m_models.size(); } /* classifies an instance (given Fs values) with the committee */ public double classifyInstance(double[] Fs) throws Exception { double [] dist = distributionForInstance(Fs); double max = 0; int maxIndex = 0; for (int i = 0; i < dist.length; i++) { if (dist[i] > max) { maxIndex = i; max = dist[i]; } } if (max > 0) { return maxIndex; } else { return Instance.missingValue(); } } /* classifies an instance with the committee */ public double classifyInstance(Instance instance) throws Exception { double [] dist = distributionForInstance(instance); switch (instance.classAttribute().type()) { case Attribute.NOMINAL: double max = 0; int maxIndex = 0; for (int i = 0; i < dist.length; i++) { if (dist[i] > max) { maxIndex = i; max = dist[i]; } } if (max > 0) { return maxIndex; } else { return Instance.missingValue(); } case Attribute.NUMERIC: return dist[0]; default: return Instance.missingValue(); } } /* returns the distribution the committee generates for an instance (given Fs values) */ public double[] distributionForInstance(double[] Fs) throws Exception { double [] distribution = new double [m_NumClasses]; for (int j = 0; j < m_NumClasses; j++) { distribution[j] = RtoP(Fs, j); } return distribution; } /* updates the Fs values given a new model in the committee */ public double[] updateFS(Instance instance, Classifier[] newModel, double[] Fs) throws Exception { instance = (Instance)instance.copy(); instance.setDataset(m_NumericClassData); double [] Fi = new double [m_NumClasses]; double Fsum = 0; for (int j = 0; j < m_NumClasses; j++) { Fi[j] = newModel[j].classifyInstance(instance); Fsum += Fi[j]; } Fsum /= m_NumClasses; double[] newFs = new double[Fs.length]; for (int j = 0; j < m_NumClasses; j++) { newFs[j] = Fs[j] + ((Fi[j] - Fsum) * (m_NumClasses - 1) / m_NumClasses); } return newFs; } /* returns the distribution the committee generates for an instance */ public double[] distributionForInstance(Instance instance) throws Exception { instance = (Instance)instance.copy(); instance.setDataset(m_NumericClassData); double [] Fs = new double [m_NumClasses]; for (int i = 0; i < m_models.size(); i++) { double [] Fi = new double [m_NumClasses]; double Fsum = 0; Classifier[] model = (Classifier[]) m_models.elementAt(i); for (int j = 0; j < m_NumClasses; j++) { Fi[j] = model[j].classifyInstance(instance); Fsum += Fi[j]; } Fsum /= m_NumClasses; for (int j = 0; j < m_NumClasses; j++) { Fs[j] += (Fi[j] - Fsum) * (m_NumClasses - 1) / m_NumClasses; } } double [] distribution = new double [m_NumClasses]; for (int j = 0; j < m_NumClasses; j++) { distribution[j] = RtoP(Fs, j); } return distribution; } /* performs a boosting iteration, returning a new model for the committee */ protected Classifier[] boost(Instances data) throws Exception { Classifier[] newModel = Classifier.makeCopies(m_Classifier, m_NumClasses); // Create a copy of the data with the class transformed into numeric Instances boostData = new Instances(data); boostData.deleteWithMissingClass(); int numInstances = boostData.numInstances();
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -