📄 lmt.java
字号:
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* LMT.java
* Copyright (C) 2003 Niels Landwehr
*
*/
package weka.classifiers.trees;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.trees.j48.C45ModelSelection;
import weka.classifiers.trees.j48.ModelSelection;
import weka.classifiers.trees.lmt.LMTNode;
import weka.classifiers.trees.lmt.ResidualModelSelection;
import weka.core.AdditionalMeasureProducer;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.UnsupportedClassTypeException;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
/**
* Class for "logistic model tree" classifier.
* For more information, see master thesis "Logistic Model Trees" (Niels Landwehr, 2003)<p>
*
* Valid options are: <p>
*
* -B <br>
* Binary splits (convert nominal attributes to binary ones).<p>
*
* -R <br>
* Split on residuals instead of class values <p>
*
* -C <br>
* Use cross-validation for boosting at all nodes (i.e., disable heuristic) <p>
*
* -P <br>
* Use error on probabilities instead of misclassification error for
* stopping criterion of LogitBoost. <p>
*
* -I iterations <br>
* Set fixed number of iterations for LogitBoost (instead of using cross-validation). <p>
*
* -M numInstances <br>
* Set minimum number of instances at which a node can be split (default 15). <p>
*
*
* @author Niels Landwehr
* @version $Revision$
*/
public class LMT extends Classifier implements OptionHandler, AdditionalMeasureProducer,
Drawable{
//format of serial: 1**date## (** = algorithm id, ##= version)
//static final long serialVersionUID = 1010506200300L;
/** Filter used to replace missing values; created in buildClassifier and reused at prediction time. */
protected ReplaceMissingValues m_replaceMissing;
/** Filter that converts nominal attributes to binary ones; only created when m_convertNominal is set. */
protected NominalToBinary m_nominalToBinary;
/** Root of the logistic model tree; null until buildClassifier has run. */
protected LMTNode m_tree;
/** Use the heuristic that determines the number of LogitBoost iterations only once, in the beginning? */
protected boolean m_fastRegression;
/** Convert nominal attributes to binary ones? */
protected boolean m_convertNominal;
/** Split on residuals (instead of on class values)? */
protected boolean m_splitOnResiduals;
/** Use error on probabilities instead of misclassification error for the stopping criterion of LogitBoost? */
protected boolean m_errorOnProbabilities;
/** Minimum number of instances at which a node is considered for splitting. */
protected int m_minNumInstances;
/**
 * If non-zero, use a fixed number of iterations for LogitBoost.
 * NOTE(review): the constructor initializes this to -1, so presumably only
 * positive values select a fixed count -- confirm against LMTNode.
 */
protected int m_numBoostingIterations;
/**
 * Creates an LMT instance configured with the standard option values:
 * nodes need at least 15 instances to be considered for splitting, the
 * number of boosting iterations is set to -1, and the fast-regression
 * heuristic is switched on. All other flags keep their Java defaults.
 */
public LMT() {
  m_minNumInstances = 15;
  m_numBoostingIterations = -1;
  m_fastRegression = true;
}
/**
 * Builds the classifier.
 *
 * The training data is copied, instances with a missing class value are
 * removed from the copy, missing attribute values are replaced and (if
 * requested) nominal attributes are converted to binary ones. The logistic
 * model tree is then grown on the filtered copy.
 *
 * @param data the training instances; all filtering is applied to a copy
 * created with <code>new Instances(data)</code>
 * @exception UnsupportedClassTypeException if the class attribute is not nominal
 * @exception Exception if no instance with a class value remains, or if the
 * classifier can't be built successfully
 */
public void buildClassifier(Instances data) throws Exception{
  // Check for non-nominal classes
  if (!data.classAttribute().isNominal()) {
    throw new UnsupportedClassTypeException("Nominal class, please.");
  }

  Instances filteredData = new Instances(data);
  filteredData.deleteWithMissingClass();
  // The check has to look at the filtered copy: 'data' still contains the
  // instances whose class value is missing, so testing it would let an
  // effectively empty training set slip through.
  if (filteredData.numInstances() == 0) {
    throw new Exception("No instances without missing class values in training file!");
  }

  //replace missing values
  m_replaceMissing = new ReplaceMissingValues();
  m_replaceMissing.setInputFormat(filteredData);
  filteredData = Filter.useFilter(filteredData, m_replaceMissing);

  //possibly convert nominal attributes globally
  if (m_convertNominal) {
    m_nominalToBinary = new NominalToBinary();
    m_nominalToBinary.setInputFormat(filteredData);
    filteredData = Filter.useFilter(filteredData, m_nominalToBinary);
  }

  // Handed to the ModelSelection constructors below; this is a different
  // setting from m_minNumInstances, which is passed to the tree node.
  int minNumInstances = 2;

  //create ModelSelection object, either for splits on the residuals or for splits on the class value
  ModelSelection modSelection;
  if (m_splitOnResiduals) {
    modSelection = new ResidualModelSelection(minNumInstances);
  } else {
    modSelection = new C45ModelSelection(minNumInstances, filteredData);
  }

  //create tree root
  m_tree = new LMTNode(modSelection, m_numBoostingIterations, m_fastRegression,
                       m_errorOnProbabilities, m_minNumInstances);
  //build tree
  m_tree.buildClassifier(filteredData);

  // C45ModelSelection was constructed with the training data; call its cleanup once the tree is built
  if (modSelection instanceof C45ModelSelection) {
    ((C45ModelSelection)modSelection).cleanup();
  }
}
/**
 * Returns class probabilities for an instance.
 *
 * The instance is first passed through the missing-value filter and, when
 * nominal-to-binary conversion is enabled, through that filter as well; the
 * resulting instance is handed to the tree.
 *
 * @param instance the instance to compute the distribution for
 * @return the class probabilities computed by the tree
 * @exception Exception if distribution can't be computed successfully
 */
public double [] distributionForInstance(Instance instance) throws Exception {
  // step 1: replace missing values
  m_replaceMissing.input(instance);
  Instance prepared = m_replaceMissing.output();

  // step 2 (optional): convert nominal attributes
  if (m_convertNominal) {
    m_nominalToBinary.input(prepared);
    prepared = m_nominalToBinary.output();
  }

  return m_tree.distributionForInstance(prepared);
}
/**
 * Classifies an instance by picking the class with the highest probability.
 *
 * Ties are resolved in favour of the lowest class index, because a later
 * class only wins when its probability is greater (per Utils.gr) than the
 * best one seen so far.
 *
 * @param instance the instance to classify
 * @return the index of the predicted class, as a double
 * @exception Exception if instance can't be classified successfully
 */
public double classifyInstance(Instance instance) throws Exception {
  double[] distribution = distributionForInstance(instance);

  int bestClass = 0;
  double bestProb = -1;
  for (int c = 0; c < instance.numClasses(); c++) {
    double candidate = distribution[c];
    if (!Utils.gr(candidate, bestProb)) {
      continue;
    }
    bestProb = candidate;
    bestClass = c;
  }
  return bestClass;
}
/**
 * Returns a description of the classifier.
 *
 * @return a header followed by the textual form of the tree, or a short
 * placeholder message when no tree is available yet
 */
public String toString() {
  if (m_tree == null) {
    return "No tree build";
  }
  String header = "Logistic model tree \n------------------\n";
  return header + m_tree.toString();
}
/**
 * Returns an enumeration describing the available options
 * (-B, -R, -C, -P, -I and -M, in that order).
 *
 * @return an enumeration of all the available options.
 */
public Enumeration listOptions() {
  Option binarySplits = new Option(
      "\tBinary splits (convert nominal attributes to binary ones)\n",
      "B", 0, "-B");
  Option residualSplits = new Option(
      "\tSplit on residuals instead of class values\n",
      "R", 0, "-R");
  Option crossValidateEverywhere = new Option(
      "\tUse cross-validation for boosting at all nodes (i.e., disable heuristic)\n",
      "C", 0, "-C");
  Option probabilityError = new Option(
      "\tUse error on probabilities instead of misclassification "
      + "error for stopping criterion of LogitBoost.\n",
      "P", 0, "-P");
  Option fixedIterations = new Option(
      "\tSet fixed number of iterations for LogitBoost "
      + "(instead of using cross-validation)\n",
      "I", 1, "-I <numIterations>");
  Option minInstances = new Option(
      "\tSet minimum number of instances at which a node can be split (default 15)\n",
      "M", 1, "-M <numInstances>");

  Vector optionList = new Vector(8);
  optionList.addElement(binarySplits);
  optionList.addElement(residualSplits);
  optionList.addElement(crossValidateEverywhere);
  optionList.addElement(probabilityError);
  optionList.addElement(fixedIterations);
  optionList.addElement(minInstances);
  return optionList.elements();
}
/**
 * Parses a given list of options.
 *
 * The flags -B, -R, -C and -P are always applied (an absent flag sets the
 * corresponding property to its "flag not given" value; note that -C is
 * inverted and switches fast regression off). The valued options -I and -M
 * are only applied when a value was supplied; otherwise the current setting
 * is left untouched.
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported
 */
public void setOptions(String[] options) throws Exception {
  // simple on/off flags
  setConvertNominal(Utils.getFlag('B', options));
  setSplitOnResiduals(Utils.getFlag('R', options));
  setFastRegression(!Utils.getFlag('C', options));
  setErrorOnProbabilities(Utils.getFlag('P', options));

  // -I <numIterations>
  String iterationsArg = Utils.getOption('I', options);
  if (iterationsArg.length() > 0) {
    setNumBoostingIterations(Integer.parseInt(iterationsArg));
  }

  // -M <numInstances>
  String minInstancesArg = Utils.getOption('M', options);
  if (minInstancesArg.length() > 0) {
    setMinNumInstances(Integer.parseInt(minInstancesArg));
  }

  // anything still left in the array is reported as an error
  Utils.checkForRemainingOptions(options);
}
/**
* Gets the current settings of the Classifier.
*
* @return an array of strings suitable for passing to setOptions
*/
public String[] getOptions() {
String[] options = new String[8];
int current = 0;
if (getConvertNominal()) {
options[current++] = "-B";
}
if (getSplitOnResiduals()) {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -