📄 simplelogistic.java
字号:
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* SimpleLogistic.java
* Copyright (C) 2003 Niels Landwehr
*
*/
package weka.classifiers.functions;
import weka.classifiers.*;
import weka.classifiers.trees.lmt.LogisticBase;
import weka.core.*;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.Filter;
import java.util.*;
/**
* Class for building a logistic regression model using LogitBoost.
* Incorporates attribute selection by fitting simple regression functions in LogitBoost.
* For more information, see the master's thesis "Logistic Model Trees" (Niels Landwehr, 2003).<p>
*
* Valid options are: <p>
*
* -I iterations <br>
* Set fixed number of iterations for LogitBoost (instead of using cross-validation). <p>
* -S <br>
* Select the number of LogitBoost iterations that gives minimal error on the training set
* (instead of using cross-validation). <p>
* -P <br>
* Minimize error on probabilities instead of misclassification error. <p>
* -M iterations <br>
* Set maximum number of iterations for LogitBoost. <p>
* -H iter <br>
* Set the parameter for the heuristic for early stopping of LogitBoost.
* If enabled, the minimum is selected greedily, stopping if the current minimum has not changed
* for iter iterations. By default, the heuristic is enabled with value 50. Set to zero to disable
* the heuristic. <p>
*
* @author Niels Landwehr
* @version $Revision: 1.1 $
*/
public class SimpleLogistic extends Classifier
implements OptionHandler, AdditionalMeasureProducer, WeightedInstancesHandler {
//format of serial: 1**date## (** = algorithm id, ##= version)
//NOTE(review): the explicit serialVersionUID below is disabled; kept for reference only.
//static final long serialVersionUID = 1110506200300L;
/** The actual logistic regression model; created and trained in buildClassifier. */
protected LogisticBase m_boostedModel;
/**
 * Filter for converting nominal attributes to binary ones. Fitted on the
 * training data in buildClassifier and reused in distributionForInstance so
 * that test instances receive the same transformation.
 */
protected NominalToBinary m_NominalToBinary = null;
/**
 * Filter for replacing missing values. Fitted in buildClassifier and reused in
 * distributionForInstance, where it is applied before m_NominalToBinary.
 */
protected ReplaceMissingValues m_ReplaceMissingValues = null;
/**
 * Number of LogitBoost iterations handed to LogisticBase (see option -I).
 * NOTE(review): the original comment said "if non-negative, use this as fixed
 * number", but the no-argument constructor sets it to 0 while enabling
 * cross-validation, so presumably only positive values force a fixed count —
 * confirm against LogisticBase.
 */
protected int m_numBoostingIterations;
/** Maximum number of iterations for LogitBoost (option -M, default 500). */
protected int m_maxBoostingIterations = 500;
/**
 * Parameter for the heuristic for early stopping of LogitBoost (option -H,
 * default 50). According to the class documentation, zero disables the heuristic.
 */
protected int m_heuristicStop = 50;
/** If true, cross-validate the number of LogitBoost iterations (turned off by option -S). */
protected boolean m_useCrossValidation;
/** If true, minimize error on probabilities instead of misclassification error (option -P). */
protected boolean m_errorOnProbabilities;
/**
 * Creates a SimpleLogistic classifier with the standard settings:
 * m_numBoostingIterations = 0, cross-validation of the number of LogitBoost
 * iterations switched on, and misclassification error (not error on
 * probabilities) as the stopping criterion.
 */
public SimpleLogistic() {
  this(0, true, false);
}
/**
 * Creates a SimpleLogistic classifier with explicit LogitBoost settings.
 * The maximum iteration count and the early-stopping heuristic keep their
 * field-initializer values (500 and 50).
 *
 * @param numBoostingIterations number of LogitBoost iterations; per the original
 *        contract a non-negative value is used as a fixed iteration count
 *        (exact semantics are defined by LogisticBase)
 * @param useCrossValidation whether to cross-validate the number of LogitBoost iterations
 * @param errorOnProbabilities whether to minimize error on probabilities instead of
 *        misclassification error
 */
public SimpleLogistic(int numBoostingIterations, boolean useCrossValidation,
    boolean errorOnProbabilities) {
  this.m_errorOnProbabilities = errorOnProbabilities;
  this.m_useCrossValidation = useCrossValidation;
  this.m_numBoostingIterations = numBoostingIterations;
}
/**
 * Trains the classifier. Validates the data, fits the missing-value and
 * nominal-to-binary filters on a private copy, then builds the logistic
 * regression model with LogitBoost on the filtered copy.
 *
 * @param data the training data; it is copied, so the caller's object is not modified
 * @exception Exception if the class is not nominal, string attributes are present,
 *            no instance has a class value, or training fails
 */
public void buildClassifier(Instances data) throws Exception {
  boolean nominalClass = data.classAttribute().type() == Attribute.NOMINAL;
  if (!nominalClass) {
    throw new UnsupportedClassTypeException("Class attribute must be nominal.");
  }
  // String attributes are rejected up front.
  if (data.checkForStringAttributes()) {
    throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
  }

  // Work on a copy so that dropping unlabeled instances does not touch the caller's data.
  Instances train = new Instances(data);
  train.deleteWithMissingClass();
  if (train.numInstances() == 0) {
    throw new Exception("No instances without missing class values in training file!");
  }

  // Step 1: replace missing values; the fitted filter is kept for prediction time.
  m_ReplaceMissingValues = new ReplaceMissingValues();
  m_ReplaceMissingValues.setInputFormat(train);
  train = Filter.useFilter(train, m_ReplaceMissingValues);

  // Step 2: convert nominal attributes to binary ones; also kept for prediction time.
  m_NominalToBinary = new NominalToBinary();
  m_NominalToBinary.setInputFormat(train);
  train = Filter.useFilter(train, m_NominalToBinary);

  // Step 3: configure and fit the LogitBoost-based logistic model.
  LogisticBase model =
      new LogisticBase(m_numBoostingIterations, m_useCrossValidation, m_errorOnProbabilities);
  m_boostedModel = model;
  model.setMaxIterations(m_maxBoostingIterations);
  model.setHeuristicStop(m_heuristicStop);
  model.buildClassifier(train);
}
/**
 * Returns class probabilities for an instance.
 *
 * The instance is first passed through the missing-value and nominal-to-binary
 * filters fitted by buildClassifier, then scored by the boosted logistic model.
 *
 * @param inst the instance to classify
 * @return the class distribution computed by the underlying LogisticBase model
 * @exception IllegalStateException if buildClassifier has not been called yet
 * @exception Exception if distribution can't be computed successfully
 */
public double[] distributionForInstance(Instance inst)
    throws Exception {
  // Without this check an untrained classifier fails with an uninformative
  // NullPointerException on the first filter call below.
  if (m_ReplaceMissingValues == null || m_NominalToBinary == null
      || m_boostedModel == null) {
    throw new IllegalStateException(
        "No model built yet: call buildClassifier() before distributionForInstance().");
  }
  // Apply the same preprocessing as during training, in the same order.
  m_ReplaceMissingValues.input(inst);
  inst = m_ReplaceMissingValues.output();
  m_NominalToBinary.input(inst);
  inst = m_NominalToBinary.output();
  // Obtain the probabilities from the logistic model.
  return m_boostedModel.distributionForInstance(inst);
}
/**
 * Returns an enumeration describing the available options
 * (-I, -S, -P, -M and -H).
 *
 * @return an enumeration of all the available options.
 */
public Enumeration listOptions() {
  Vector newVector = new Vector(5);
  newVector.addElement(new Option("\tSet fixed number of iterations for LogitBoost\n",
      "I", 1, "-I <iterations>"));
  newVector.addElement(new Option("\tUse stopping criterion on training set (instead of cross-validation)\n",
      "S", 0, "-S"));
  newVector.addElement(new Option("\tUse error on probabilities (rmse) instead of misclassification error "
      + "for stopping criterion\n",
      "P", 0, "-P"));
  newVector.addElement(new Option("\tSet maximum number of boosting iterations\n",
      "M", 1, "-M <iterations>"));
  // The -H description is concatenated from several literals; each piece must
  // carry its own separating space, otherwise words run together in the help text.
  newVector.addElement(new Option("\tSet parameter for heuristic for early stopping of LogitBoost. "
      + "If enabled, the minimum is selected greedily, stopping if the current minimum"
      + " has not changed for <iterations> iterations. By default, heuristic is enabled with"
      + " value 50. Set to zero to disable heuristic."
      + "\n",
      "H", 1, "-H <iterations>"));
  return newVector.elements();
}
/**
 * Parses a given list of options.
 *
 * Following the Weka OptionHandler convention, every option is set or reset
 * by this call. An option missing from the list reverts to its default
 * (-I 0, -M 500, -H 50, cross-validation on, misclassification error)
 * instead of silently keeping a value from a previous call.
 *
 * @param options the list of options as an array of strings
 * @exception NumberFormatException if -I, -M or -H is not a valid integer
 * @exception Exception if an option is not supported
 */
public void setOptions(String[] options) throws Exception {
  String optionString = Utils.getOption('I', options);
  // 0 matches the value set by the no-argument constructor.
  setNumBoostingIterations(optionString.length() != 0
      ? Integer.parseInt(optionString) : 0);
  setUseCrossValidation(!Utils.getFlag('S', options));
  setErrorOnProbabilities(Utils.getFlag('P', options));
  optionString = Utils.getOption('M', options);
  // 500 matches the m_maxBoostingIterations field initializer.
  setMaxBoostingIterations(optionString.length() != 0
      ? Integer.parseInt(optionString) : 500);
  // -H is parsed after -M, preserving the original order of setter calls.
  // 50 matches the m_heuristicStop field initializer.
  optionString = Utils.getOption('H', options);
  setHeuristicStop(optionString.length() != 0
      ? Integer.parseInt(optionString) : 50);
  Utils.checkForRemainingOptions(options);
}
/**
* Gets the current settings of the Classifier.
*
* @return an array of strings suitable for passing to setOptions
*/
public String[] getOptions() {
String[] options = new String[9];
int current = 0;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -