📄 smo.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * SMO.java * Copyright (C) 1999 Eibe Frank * */package weka.classifiers.functions;import weka.classifiers.Classifier;import weka.classifiers.Evaluation;import weka.classifiers.functions.supportVector.Kernel;import weka.classifiers.functions.supportVector.PolyKernel;import weka.classifiers.functions.supportVector.SMOset;import weka.core.Attribute;import weka.core.Capabilities;import weka.core.FastVector;import weka.core.Instance;import weka.core.Instances;import weka.core.Option;import weka.core.OptionHandler;import weka.core.SelectedTag;import weka.core.SerializedObject;import weka.core.Tag;import weka.core.TechnicalInformation;import weka.core.TechnicalInformationHandler;import weka.core.Utils;import weka.core.WeightedInstancesHandler;import weka.core.Capabilities.Capability;import weka.core.TechnicalInformation.Field;import weka.core.TechnicalInformation.Type;import weka.filters.Filter;import weka.filters.unsupervised.attribute.NominalToBinary;import weka.filters.unsupervised.attribute.Normalize;import weka.filters.unsupervised.attribute.ReplaceMissingValues;import weka.filters.unsupervised.attribute.Standardize;import java.io.Serializable;import java.util.Enumeration;import java.util.Random;import java.util.Vector;/** <!-- globalinfo-start --> * Implements John Platt's sequential minimal optimization algorithm for training a support vector classifier.<br/> * <br/> * This implementation globally replaces all missing values and transforms nominal attributes into binary ones. It also normalizes all attributes by default. (In that case the coefficients in the output are based on the normalized data, not the original data --- this is important for interpreting the classifier.)<br/> * <br/> * Multi-class problems are solved using pairwise classification (1-vs-1 and if logistic models are built pairwise coupling according to Hastie and Tibshirani, 1998).<br/> * <br/> * To obtain proper probability estimates, use the option that fits logistic regression models to the outputs of the support vector machine. In the multi-class case the predicted probabilities are coupled using Hastie and Tibshirani's pairwise coupling method.<br/> * <br/> * Note: for improved speed normalization should be turned off when operating on SparseInstances.<br/> * <br/> * For more information on the SMO algorithm, see<br/> * <br/> * J. Platt: Machines using Sequential Minimal Optimization. In B. Schoelkopf and C. Burges and A. Smola, editors, Advances in Kernel Methods - Support Vector Learning, 1998.<br/> * <br/> * S.S. Keerthi, S.K. Shevade, C. Bhattacharyya, K.R.K. Murthy (2001). Improvements to Platt's SMO Algorithm for SVM Classifier Design. Neural Computation. 13(3):637-649.<br/> * <br/> * Trevor Hastie, Robert Tibshirani: Classification by Pairwise Coupling. In: Advances in Neural Information Processing Systems, 1998. * <p/> <!-- globalinfo-end --> * <!-- technical-bibtex-start --> * BibTeX: * <pre> * @incollection{Platt1998, * author = {J. Platt}, * booktitle = {Advances in Kernel Methods - Support Vector Learning}, * editor = {B. Schoelkopf and C. Burges and A. Smola}, * publisher = {MIT Press}, * title = {Machines using Sequential Minimal Optimization}, * year = {1998}, * URL = {http://research.microsoft.com/~jplatt/smo.html}, * PS = {http://research.microsoft.com/~jplatt/smo-book.ps.gz}, * PDF = {http://research.microsoft.com/~jplatt/smo-book.pdf} * } * * @article{Keerthi2001, * author = {S.S. Keerthi and S.K. Shevade and C. Bhattacharyya and K.R.K. Murthy}, * journal = {Neural Computation}, * number = {3}, * pages = {637-649}, * title = {Improvements to Platt's SMO Algorithm for SVM Classifier Design}, * volume = {13}, * year = {2001}, * PS = {http://guppy.mpe.nus.edu.sg/~mpessk/svm/smo_mod_nc.ps.gz} * } * * @inproceedings{Hastie1998, * author = {Trevor Hastie and Robert Tibshirani}, * booktitle = {Advances in Neural Information Processing Systems}, * editor = {Michael I. Jordan and Michael J. Kearns and Sara A. Solla}, * publisher = {MIT Press}, * title = {Classification by Pairwise Coupling}, * volume = {10}, * year = {1998}, * PS = {http://www-stat.stanford.edu/~hastie/Papers/2class.ps} * } * </pre> * <p/> <!-- technical-bibtex-end --> * <!-- options-start --> * Valid options are: <p/> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * * <pre> -no-checks * Turns off all checks - use with caution! * Turning them off assumes that data is purely numeric, doesn't * contain any missing values, and has a nominal class. Turning them * off also means that no header information will be stored if the * machine is linear. Finally, it also assumes that no instance has * a weight equal to 0. * (default: checks on)</pre> * * <pre> -C <double> * The complexity constant C. (default 1)</pre> * * <pre> -N * Whether to 0=normalize/1=standardize/2=neither. (default 0=normalize)</pre> * * <pre> -L <double> * The tolerance parameter. (default 1.0e-3)</pre> * * <pre> -P <double> * The epsilon for round-off error. (default 1.0e-12)</pre> * * <pre> -M * Fit logistic models to SVM outputs. </pre> * * <pre> -V <double> * The number of folds for the internal * cross-validation. (default -1, use training data)</pre> * * <pre> -W <double> * The random number seed. (default 1)</pre> * * <pre> -K <classname and parameters> * The Kernel to use. * (default: weka.classifiers.functions.supportVector.PolyKernel)</pre> * * <pre> * Options specific to kernel weka.classifiers.functions.supportVector.PolyKernel: * </pre> * * <pre> -D * Enables debugging output (if available) to be printed. * (default: off)</pre> * * <pre> -no-checks * Turns off all checks - use with caution! * (default: checks on)</pre> * * <pre> -C <num> * The size of the cache (a prime number). * (default: 250007)</pre> * * <pre> -E <num> * The Exponent to use. * (default: 1.0)</pre> * * <pre> -L * Use lower-order terms. * (default: no)</pre> * <!-- options-end --> * * @author Eibe Frank (eibe@cs.waikato.ac.nz) * @author Shane Legg (shane@intelligenesis.net) (sparse vector code) * @author Stuart Inglis (stuart@reeltwo.com) (sparse vector code) * @version $Revision: 1.61 $ */public class SMO extends Classifier implements WeightedInstancesHandler, TechnicalInformationHandler { /** for serialization */ static final long serialVersionUID = -6585883636378691736L; /** * Returns a string describing classifier * @return a description suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Implements John Platt's sequential minimal optimization " + "algorithm for training a support vector classifier.\n\n" + "This implementation globally replaces all missing values and " + "transforms nominal attributes into binary ones. It also " + "normalizes all attributes by default. (In that case the coefficients " + "in the output are based on the normalized data, not the " + "original data --- this is important for interpreting the classifier.)\n\n" + "Multi-class problems are solved using pairwise classification " + "(1-vs-1 and if logistic models are built pairwise coupling " + "according to Hastie and Tibshirani, 1998).\n\n" + "To obtain proper probability estimates, use the option that fits " + "logistic regression models to the outputs of the support vector " + "machine. In the multi-class case the predicted probabilities " + "are coupled using Hastie and Tibshirani's pairwise coupling " + "method.\n\n" + "Note: for improved speed normalization should be turned off when " + "operating on SparseInstances.\n\n" + "For more information on the SMO algorithm, see\n\n" + getTechnicalInformation().toString(); } /** * Returns an instance of a TechnicalInformation object, containing * detailed information about the technical background of this class, * e.g., paper reference or book this class is based on. * * @return the technical information about this class */ public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; TechnicalInformation additional; result = new TechnicalInformation(Type.INCOLLECTION); result.setValue(Field.AUTHOR, "J. Platt"); result.setValue(Field.YEAR, "1998"); result.setValue(Field.TITLE, "Machines using Sequential Minimal Optimization"); result.setValue(Field.BOOKTITLE, "Advances in Kernel Methods - Support Vector Learning"); result.setValue(Field.EDITOR, "B. Schoelkopf and C. Burges and A. Smola"); result.setValue(Field.PUBLISHER, "MIT Press"); result.setValue(Field.URL, "http://research.microsoft.com/~jplatt/smo.html"); result.setValue(Field.PDF, "http://research.microsoft.com/~jplatt/smo-book.pdf"); result.setValue(Field.PS, "http://research.microsoft.com/~jplatt/smo-book.ps.gz"); additional = result.add(Type.ARTICLE); additional.setValue(Field.AUTHOR, "S.S. Keerthi and S.K. Shevade and C. Bhattacharyya and K.R.K. Murthy"); additional.setValue(Field.YEAR, "2001"); additional.setValue(Field.TITLE, "Improvements to Platt's SMO Algorithm for SVM Classifier Design"); additional.setValue(Field.JOURNAL, "Neural Computation"); additional.setValue(Field.VOLUME, "13"); additional.setValue(Field.NUMBER, "3"); additional.setValue(Field.PAGES, "637-649"); additional.setValue(Field.PS, "http://guppy.mpe.nus.edu.sg/~mpessk/svm/smo_mod_nc.ps.gz"); additional = result.add(Type.INPROCEEDINGS); additional.setValue(Field.AUTHOR, "Trevor Hastie and Robert Tibshirani"); additional.setValue(Field.YEAR, "1998"); additional.setValue(Field.TITLE, "Classification by Pairwise Coupling"); additional.setValue(Field.BOOKTITLE, "Advances in Neural Information Processing Systems"); additional.setValue(Field.VOLUME, "10"); additional.setValue(Field.PUBLISHER, "MIT Press"); additional.setValue(Field.EDITOR, "Michael I. Jordan and Michael J. Kearns and Sara A. Solla"); additional.setValue(Field.PS, "http://www-stat.stanford.edu/~hastie/Papers/2class.ps"); return result; } /** * Class for building a binary support vector machine. */ public class BinarySMO implements Serializable { /** for serialization */ static final long serialVersionUID = -8246163625699362456L; /** The Lagrange multipliers. */ protected double[] m_alpha; /** The thresholds. */ protected double m_b, m_bLow, m_bUp; /** The indices for m_bLow and m_bUp */ protected int m_iLow, m_iUp; /** The training data. */ protected Instances m_data; /** Weight vector for linear machine. */ protected double[] m_weights; /** Variables to hold weight vector in sparse form. (To reduce storage requirements.) */ protected double[] m_sparseWeights; protected int[] m_sparseIndices; /** Kernel to use **/ protected Kernel m_kernel; /** The transformed class values. */ protected double[] m_class; /** The current set of errors for all non-bound examples. */ protected double[] m_errors; /* The five different sets used by the algorithm. */ /** {i: 0 < m_alpha[i] < C} */ protected SMOset m_I0; /** {i: m_class[i] = 1, m_alpha[i] = 0} */ protected SMOset m_I1; /** {i: m_class[i] = -1, m_alpha[i] =C} */ protected SMOset m_I2; /** {i: m_class[i] = 1, m_alpha[i] = C} */ protected SMOset m_I3; /** {i: m_class[i] = -1, m_alpha[i] = 0} */ protected SMOset m_I4; /** The set of support vectors */ protected SMOset m_supportVectors; // {i: 0 < m_alpha[i]} /** Stores logistic regression model for probability estimate */ protected Logistic m_logistic = null; /** Stores the weight of the training instances */ protected double m_sumOfWeights = 0; /** * Fits logistic regression model to SVM outputs analogue * to John Platt's method. * * @param insts the set of training instances * @param cl1 the first class' index * @param cl2 the second class' index * @param numFolds the number of folds for cross-validation * @param random for randomizing the data * @throws Exception if the sigmoid can't be fit successfully */ protected void fitLogistic(Instances insts, int cl1, int cl2, int numFolds, Random random) throws Exception { // Create header of instances object FastVector atts = new FastVector(2); atts.addElement(new Attribute("pred")); FastVector attVals = new FastVector(2); attVals.addElement(insts.classAttribute().value(cl1)); attVals.addElement(insts.classAttribute().value(cl2)); atts.addElement(new Attribute("class", attVals)); Instances data = new Instances("data", atts, insts.numInstances()); data.setClassIndex(1); // Collect data for fitting the logistic model if (numFolds <= 0) { // Use training data for (int j = 0; j < insts.numInstances(); j++) { Instance inst = insts.instance(j); double[] vals = new double[2]; vals[0] = SVMOutput(-1, inst); if (inst.classValue() == cl2) { vals[1] = 1; } data.add(new Instance(inst.weight(), vals)); } } else { // Check whether number of folds too large if (numFolds > insts.numInstances()) { numFolds = insts.numInstances(); } // Make copy of instances because we will shuffle them around insts = new Instances(insts); // Perform three-fold cross-validation to collect // unbiased predictions insts.randomize(random); insts.stratify(numFolds); for (int i = 0; i < numFolds; i++) { Instances train = insts.trainCV(numFolds, i, random); SerializedObject so = new SerializedObject(this); BinarySMO smo = (BinarySMO)so.getObject(); smo.buildClassifier(train, cl1, cl2, false, -1, -1); Instances test = insts.testCV(numFolds, i); for (int j = 0; j < test.numInstances(); j++) { double[] vals = new double[2]; vals[0] = smo.SVMOutput(-1, test.instance(j)); if (test.instance(j).classValue() == cl2) { vals[1] = 1; } data.add(new Instance(test.instance(j).weight(), vals)); } } } // Build logistic regression model m_logistic = new Logistic(); m_logistic.buildClassifier(data); } /** * sets the kernel to use * * @param value the kernel to use */ public void setKernel(Kernel value) { m_kernel = value; } /** * Returns the kernel to use * * @return the current kernel */ public Kernel getKernel() { return m_kernel; } /** * Method for building the binary classifier. * * @param insts the set of training instances * @param cl1 the first class' index * @param cl2 the second class' index * @param fitLogistic true if logistic model is to be fit * @param numFolds number of folds for internal cross-validation * @param randomSeed random number generator for cross-validation * @throws Exception if the classifier can't be built successfully */ protected void buildClassifier(Instances insts, int cl1, int cl2, boolean fitLogistic, int numFolds, int randomSeed) throws Exception { // Initialize some variables
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -