📄 racedincrementallogitboost.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * RacedIncrementalLogitBoost.java * Copyright (C) 2002 Richard Kirkby, Eibe Frank * */package weka.classifiers.meta;import weka.classifiers.Classifier;import weka.classifiers.RandomizableSingleClassifierEnhancer;import weka.classifiers.UpdateableClassifier;import weka.classifiers.rules.ZeroR;import weka.core.Attribute;import weka.core.Capabilities;import weka.core.FastVector;import weka.core.Instance;import weka.core.Instances;import weka.core.Option;import weka.core.SelectedTag;import weka.core.Tag;import weka.core.Utils;import weka.core.WeightedInstancesHandler;import weka.core.Capabilities.Capability;import java.io.Serializable;import java.util.Enumeration;import java.util.Random;import java.util.Vector;/** <!-- globalinfo-start --> * Classifier for incremental learning of large datasets by way of racing logit-boosted committees. * <p/> <!-- globalinfo-end --> * <!-- options-start --> * Valid options are: <p/> * * <pre> -C <num> * Minimum size of chunks. * (default 500)</pre> * * <pre> -M <num> * Maximum size of chunks. * (default 2000)</pre> * * <pre> -V <num> * Size of validation set. * (default 1000)</pre> * * <pre> -P <pruning type> * Committee pruning to perform. * 0=none, 1=log likelihood (default)</pre> * * <pre> -Q * Use resampling for boosting.</pre> * * <pre> -S <num> * Random number seed. * (default 1)</pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * * <pre> -W * Full name of base classifier. * (default: weka.classifiers.trees.DecisionStump)</pre> * * <pre> * Options specific to classifier weka.classifiers.trees.DecisionStump: * </pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * <!-- options-end --> * * Options after -- are passed to the designated learner.<p> * * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) * @author Eibe Frank (eibe@cs.waikato.ac.nz) * @version $Revision: 1.10 $ */public class RacedIncrementalLogitBoost extends RandomizableSingleClassifierEnhancer implements UpdateableClassifier { /** for serialization */ static final long serialVersionUID = 908598343772170052L; /** no pruning */ public static final int PRUNETYPE_NONE = 0; /** log likelihood pruning */ public static final int PRUNETYPE_LOGLIKELIHOOD = 1; /** The pruning types */ public static final Tag [] TAGS_PRUNETYPE = { new Tag(PRUNETYPE_NONE, "No pruning"), new Tag(PRUNETYPE_LOGLIKELIHOOD, "Log likelihood pruning") }; /** The committees */ protected FastVector m_committees; /** The pruning type used */ protected int m_PruningType = PRUNETYPE_LOGLIKELIHOOD; /** Whether to use resampling */ protected boolean m_UseResampling = false; /** The number of classes */ protected int m_NumClasses; /** A threshold for responses (Friedman suggests between 2 and 4) */ protected static final double Z_MAX = 4; /** Dummy dataset with a numeric class */ protected Instances m_NumericClassData; /** The actual class attribute (for getting class names) */ protected Attribute m_ClassAttribute; /** The minimum chunk size used for training */ protected int m_minChunkSize = 500; /** The maimum chunk size used for training */ protected int m_maxChunkSize = 2000; /** The size of the validation set */ protected int m_validationChunkSize = 1000; /** The number of instances consumed */ protected int m_numInstancesConsumed; /** The instances used for validation */ protected Instances m_validationSet; /** The instances currently in memory for training */ protected Instances m_currentSet; /** The current best committee */ protected Committee m_bestCommittee; /** The default scheme used when committees aren't ready */ protected ZeroR m_zeroR = null; /** Whether the validation set has recently been changed */ protected boolean m_validationSetChanged; /** The maximum number of instances required for processing */ protected int m_maxBatchSizeRequired; /** The random number generator used */ protected Random m_RandomInstance = null; /** * Constructor. */ public RacedIncrementalLogitBoost() { m_Classifier = new weka.classifiers.trees.DecisionStump(); } /** * String describing default classifier. * * @return the default classifier classname */ protected String defaultClassifierString() { return "weka.classifiers.trees.DecisionStump"; } /** * Class representing a committee of LogitBoosted models */ protected class Committee implements Serializable { /** for serialization */ static final long serialVersionUID = 5559880306684082199L; protected int m_chunkSize; /** number eaten from m_currentSet */ protected int m_instancesConsumed; protected FastVector m_models; protected double m_lastValidationError; protected double m_lastLogLikelihood; protected boolean m_modelHasChanged; protected boolean m_modelHasChangedLL; protected double[][] m_validationFs; protected double[][] m_newValidationFs; /** * constructor * * @param chunkSize the size of the chunk */ public Committee(int chunkSize) { m_chunkSize = chunkSize; m_instancesConsumed = 0; m_models = new FastVector(); m_lastValidationError = 1.0; m_lastLogLikelihood = Double.MAX_VALUE; m_modelHasChanged = true; m_modelHasChangedLL = true; m_validationFs = new double[m_validationChunkSize][m_NumClasses]; m_newValidationFs = new double[m_validationChunkSize][m_NumClasses]; } /** * update the committee * * @return true if the committee has changed * @throws Exception if anything goes wrong */ public boolean update() throws Exception { boolean hasChanged = false; while (m_currentSet.numInstances() - m_instancesConsumed >= m_chunkSize) { Classifier[] newModel = boost(new Instances(m_currentSet, m_instancesConsumed, m_chunkSize)); for (int i=0; i<m_validationSet.numInstances(); i++) { m_newValidationFs[i] = updateFS(m_validationSet.instance(i), newModel, m_validationFs[i]); } m_models.addElement(newModel); m_instancesConsumed += m_chunkSize; hasChanged = true; } if (hasChanged) { m_modelHasChanged = true; m_modelHasChangedLL = true; } return hasChanged; } /** reset consumation counts */ public void resetConsumed() { m_instancesConsumed = 0; } /** remove the last model from the committee */ public void pruneLastModel() { if (m_models.size() > 0) { m_models.removeElementAt(m_models.size()-1); m_modelHasChanged = true; m_modelHasChangedLL = true; } } /** * decide to keep the last model in the committee * @throws Exception if anything goes wrong */ public void keepLastModel() throws Exception { m_validationFs = m_newValidationFs; m_newValidationFs = new double[m_validationChunkSize][m_NumClasses]; m_modelHasChanged = true; m_modelHasChangedLL = true; } /** * calculate the log likelihood on the validation data * @return the log likelihood * @throws Exception if computation fails */ public double logLikelihood() throws Exception { if (m_modelHasChangedLL) { Instance inst; double llsum = 0.0; for (int i=0; i<m_validationSet.numInstances(); i++) { inst = m_validationSet.instance(i); llsum += (logLikelihood(m_validationFs[i],(int) inst.classValue())); } m_lastLogLikelihood = llsum / (double) m_validationSet.numInstances(); m_modelHasChangedLL = false; } return m_lastLogLikelihood; } /** * calculate the log likelihood on the validation data after adding the last model * @return the log likelihood * @throws Exception if computation fails */ public double logLikelihoodAfter() throws Exception { Instance inst; double llsum = 0.0; for (int i=0; i<m_validationSet.numInstances(); i++) { inst = m_validationSet.instance(i); llsum += (logLikelihood(m_newValidationFs[i],(int) inst.classValue())); } return llsum / (double) m_validationSet.numInstances(); } /** * calculates the log likelihood of an instance * @param Fs the Fs values * @param classIndex the class index * @return the log likelihood * @throws Exception if computation fails */ private double logLikelihood(double[] Fs, int classIndex) throws Exception { return -Math.log(distributionForInstance(Fs)[classIndex]); } /** * calculates the validation error of the committee * @return the validation error * @throws Exception if computation fails */ public double validationError() throws Exception { if (m_modelHasChanged) { Instance inst; int numIncorrect = 0; for (int i=0; i<m_validationSet.numInstances(); i++) { inst = m_validationSet.instance(i); if (classifyInstance(m_validationFs[i]) != inst.classValue()) numIncorrect++; } m_lastValidationError = (double) numIncorrect / (double) m_validationSet.numInstances(); m_modelHasChanged = false; } return m_lastValidationError; } /** * returns the chunk size used by the committee * * @return the chunk size */ public int chunkSize() { return m_chunkSize; } /** * returns the number of models in the committee * * @return the committee size */ public int committeeSize() { return m_models.size(); } /** * classifies an instance (given Fs values) with the committee * * @param Fs the Fs values * @return the classification * @throws Exception if anything goes wrong */ public double classifyInstance(double[] Fs) throws Exception { double [] dist = distributionForInstance(Fs); double max = 0; int maxIndex = 0; for (int i = 0; i < dist.length; i++) { if (dist[i] > max) { maxIndex = i; max = dist[i]; } } if (max > 0) { return maxIndex; } else { return Instance.missingValue(); } } /** * classifies an instance with the committee * * @param instance the instance to classify * @return the classification * @throws Exception if anything goes wrong */ public double classifyInstance(Instance instance) throws Exception { double [] dist = distributionForInstance(instance); switch (instance.classAttribute().type()) { case Attribute.NOMINAL: double max = 0; int maxIndex = 0; for (int i = 0; i < dist.length; i++) { if (dist[i] > max) { maxIndex = i; max = dist[i];
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -