/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* SMO.java
* Copyright (C) 1999 Eibe Frank
*
*/
package weka.classifiers.functions;
import weka.classifiers.functions.supportVector.*;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.Logistic;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.Standardize;
import weka.filters.Filter;
import java.util.*;
import java.io.*;
import weka.core.*;
/**
* Implements John C. Platt's sequential minimal optimization
* algorithm for training a support vector classifier using polynomial
* or RBF kernels.
*
* This implementation globally replaces all missing values and
* transforms nominal attributes into binary ones. It also
* normalizes all attributes by default. (Note that the coefficients
* in the output are based on the normalized/standardized data, not the
* original data.)
*
* Multi-class problems are solved using pairwise classification.
*
* To obtain proper probability estimates, use the option that fits
* logistic regression models to the outputs of the support vector
* machine. In the multi-class case the predicted probabilities
* will be coupled using Hastie and Tibshirani's pairwise coupling
* method.
*
* Note: for improved speed normalization should be turned off when
* operating on SparseInstances.<p>
*
* For more information on the SMO algorithm, see<p>
*
* J. Platt (1998). <i>Fast Training of Support Vector
* Machines using Sequential Minimal Optimization</i>. Advances in Kernel
* Methods - Support Vector Learning, B. Schoelkopf, C. Burges, and
* A. Smola, eds., MIT Press. <p>
*
* S.S. Keerthi, S.K. Shevade, C. Bhattacharyya, K.R.K. Murthy,
* <i>Improvements to Platt's SMO Algorithm for SVM Classifier Design</i>.
* Neural Computation, 13(3), pp 637-649, 2001. <p>
*
* Valid options are:<p>
*
* -C num <br>
* The complexity constant C. (default 1)<p>
*
* -E num <br>
* The exponent for the polynomial kernel. (default 1)<p>
*
* -G num <br>
* Gamma for the RBF kernel. (default 0.01)<p>
*
* -N <0|1|2> <br>
* Whether to normalize (0), standardize (1), or do neither (2). (default 0=normalize)<p>
*
* -F <br>
* Feature-space normalization (only for non-linear polynomial kernels). <p>
*
* -O <br>
* Use lower-order terms (only for non-linear polynomial kernels). <p>
*
* -R <br>
* Use the RBF kernel. (default poly)<p>
*
* -A num <br>
* Sets the size of the kernel cache. Should be a prime number.
* (default 250007, use 0 for full cache) <p>
*
* -L num <br>
* Sets the tolerance parameter. (default 1.0e-3)<p>
*
* -P num <br>
* Sets the epsilon for round-off error. (default 1.0e-12)<p>
*
* -M <br>
* Fit logistic models to SVM outputs.<p>
*
* -V num <br>
* Number of folds for cross-validation used to generate data
* for logistic models. (default -1, use training data)
*
* -W num <br>
* Random number seed for cross-validation. (default 1)
*
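* Example usage (a minimal sketch, assuming the standard weka.core API;
* "train" and "test" stand in for the user's own Instances objects):<p>
*
* <pre>
* SMO smo = new SMO();
* smo.setOptions(new String[] {"-C", "1.0", "-E", "2.0"}); // degree-2 polynomial kernel
* smo.buildClassifier(train);                              // train: weka.core.Instances
* double pred = smo.classifyInstance(test.instance(0));    // index of the predicted class
* </pre>
*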
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @author Shane Legg (shane@intelligenesis.net) (sparse vector code)
* @author Stuart Inglis (stuart@reeltwo.com) (sparse vector code)
* @version $Revision: 1.1 $ */
public class SMO extends Classifier implements WeightedInstancesHandler {
/**
* Returns a string describing the classifier.
* @return a description suitable for
* displaying in the explorer/experimenter gui
*/
public String globalInfo() {
return "Implements John Platt's sequential minimal optimization "
+ "algorithm for training a support vector classifier.\n\n"
+ "This implementation globally replaces all missing values and "
+ "transforms nominal attributes into binary ones. It also "
+ "normalizes all attributes by default. (In that case the coefficients "
+ "in the output are based on the normalized data, not the "
+ "original data --- this is important for interpreting the classifier.)\n\n"
+ "Multi-class problems are solved using pairwise classification.\n\n"
+ "To obtain proper probability estimates, use the option that fits "
+ "logistic regression models to the outputs of the support vector "
+ "machine. In the multi-class case the predicted probabilities "
+ "are coupled using Hastie and Tibshirani's pairwise coupling "
+ "method.\n\n"
+ "Note: for improved speed normalization should be turned off when "
+ "operating on SparseInstances.\n\n"
+ "For more information on the SMO algorithm, see\n\n"
+ "J. Platt (1998). \"Fast Training of Support Vector "
+ "Machines using Sequential Minimal Optimization\". Advances in Kernel "
+ "Methods - Support Vector Learning, B. Schoelkopf, C. Burges, and "
+ "A. Smola, eds., MIT Press. \n\n"
+ "S.S. Keerthi, S.K. Shevade, C. Bhattacharyya, K.R.K. Murthy, "
+ "\"Improvements to Platt's SMO Algorithm for SVM Classifier Design\". "
+ "Neural Computation, 13(3), pp 637-649, 2001.";
}
/**
* Class for building a binary support vector machine.
*/
protected class BinarySMO implements Serializable {
/** The Lagrange multipliers. */
protected double[] m_alpha;
/** The thresholds. */
protected double m_b, m_bLow, m_bUp;
/** The indices for m_bLow and m_bUp */
protected int m_iLow, m_iUp;
/** The training data. */
protected Instances m_data;
/** Weight vector for linear machine. */
protected double[] m_weights;
/** Variables to hold weight vector in sparse form.
(To reduce storage requirements.) */
protected double[] m_sparseWeights;
protected int[] m_sparseIndices;
/** The kernel to use. */
protected Kernel m_kernel;
/** The transformed class values. */
protected double[] m_class;
/** The current set of errors for all non-bound examples. */
protected double[] m_errors;
/** The five different sets used by the algorithm. */
protected SMOset m_I0; // {i: 0 < m_alpha[i] < C}
protected SMOset m_I1; // {i: m_class[i] = 1, m_alpha[i] = 0}
protected SMOset m_I2; // {i: m_class[i] = -1, m_alpha[i] = C}
protected SMOset m_I3; // {i: m_class[i] = 1, m_alpha[i] = C}
protected SMOset m_I4; // {i: m_class[i] = -1, m_alpha[i] = 0}
/** The set of support vectors */
protected SMOset m_supportVectors; // {i: 0 < m_alpha[i]}
/** Stores logistic regression model for probability estimate */
protected Logistic m_logistic = null;
/** Stores the sum of the weights of the training instances */
protected double m_sumOfWeights = 0;
/**
* Fits a logistic regression model to the SVM outputs, analogous
* to John Platt's method.
*
* @param insts the set of training instances
* @param cl1 the first class' index
* @param cl2 the second class' index
* @param numFolds the number of folds for cross-validation (<= 0: use training data)
* @param random the random number generator for cross-validation
* @exception Exception if the sigmoid can't be fit successfully
*/
protected void fitLogistic(Instances insts, int cl1, int cl2,
int numFolds, Random random)
throws Exception {
// Create header of instances object
FastVector atts = new FastVector(2);
atts.addElement(new Attribute("pred"));
FastVector attVals = new FastVector(2);
attVals.addElement(insts.classAttribute().value(cl1));
attVals.addElement(insts.classAttribute().value(cl2));
atts.addElement(new Attribute("class", attVals));
Instances data = new Instances("data", atts, insts.numInstances());
data.setClassIndex(1);
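// (Attribute 0 holds the raw SVM output, attribute 1 the binary class
// label; the logistic model trained on this data maps SVM outputs to
// calibrated probabilities.)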
// Collect data for fitting the logistic model
if (numFolds <= 0) {
// Use training data
for (int j = 0; j < insts.numInstances(); j++) {
Instance inst = insts.instance(j);
double[] vals = new double[2];
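// Raw (uncalibrated) SVM output for this instance; the index -1
// indicates it is not addressed by its training-set position, so no
// kernel-cache row is used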
vals[0] = SVMOutput(-1, inst);
if (inst.classValue() == cl2) {
vals[1] = 1;
}
data.add(new Instance(inst.weight(), vals));
}
} else {
// Check whether number of folds too large
if (numFolds > insts.numInstances()) {
numFolds = insts.numInstances();
}
// Make copy of instances because we will shuffle them around
insts = new Instances(insts);
// Perform stratified cross-validation to collect
// unbiased predictions
insts.randomize(random);
insts.stratify(numFolds);
for (int i = 0; i < numFolds; i++) {
Instances train = insts.trainCV(numFolds, i, random);
SerializedObject so = new SerializedObject(this);
BinarySMO smo = (BinarySMO)so.getObject();
smo.buildClassifier(train, cl1, cl2, false, -1, -1);
Instances test = insts.testCV(numFolds, i);
for (int j = 0; j < test.numInstances(); j++) {
double[] vals = new double[2];
vals[0] = smo.SVMOutput(-1, test.instance(j));
if (test.instance(j).classValue() == cl2) {
vals[1] = 1;
}
data.add(new Instance(test.instance(j).weight(), vals));
}
}
}
// Build logistic regression model
m_logistic = new Logistic();
m_logistic.buildClassifier(data);
}
/**
* Method for building the binary classifier.
*
* @param insts the set of training instances
* @param cl1 the first class' index
* @param cl2 the second class' index
* @param fitLogistic true if logistic model is to be fit
* @param numFolds number of folds for internal cross-validation
* @param randomSeed seed for the random number generator used in cross-validation
* @exception Exception if the classifier can't be built successfully
*/
protected void buildClassifier(Instances insts, int cl1, int cl2,
boolean fitLogistic, int numFolds,
int randomSeed) throws Exception {
// Initialize some variables
m_bUp = -1; m_bLow = 1; m_b = 0;
m_alpha = null; m_data = null; m_weights = null; m_errors = null;
m_logistic = null; m_I0 = null; m_I1 = null; m_I2 = null;
m_I3 = null; m_I4 = null; m_sparseWeights = null; m_sparseIndices = null;
// Store the sum of weights
m_sumOfWeights = insts.sumOfWeights();
// Set class values
m_class = new double[insts.numInstances()];
m_iUp = -1; m_iLow = -1;
for (int i = 0; i < m_class.length; i++) {
if ((int) insts.instance(i).classValue() == cl1) {
m_class[i] = -1; m_iLow = i;
} else if ((int) insts.instance(i).classValue() == cl2) {
m_class[i] = 1; m_iUp = i;
} else {
throw new Exception("This should never happen!");
}
}
// Check whether one or both classes are missing
if ((m_iUp == -1) || (m_iLow == -1)) {
if (m_iUp != -1) {
m_b = -1;
} else if (m_iLow != -1) {
m_b = 1;
} else {
m_class = null;
return;
}
if (!m_useRBF && m_exponent == 1.0) {
m_sparseWeights = new double[0];
m_sparseIndices = new int[0];
m_class = null;
} else {
m_supportVectors = new SMOset(0);
m_alpha = new double[0];
m_class = new double[0];
}
// Fit sigmoid if requested
if (fitLogistic) {
fitLogistic(insts, cl1, cl2, numFolds, new Random(randomSeed));
}
return;
}
// Set the reference to the data
m_data = insts;
// If machine is linear, reserve space for weights
if (!m_useRBF && m_exponent == 1.0) {
m_weights = new double[m_data.numAttributes()];
} else {
m_weights = null;
}
// Initialize alpha array to zero
m_alpha = new double[m_data.numInstances()];
// Initialize sets
m_supportVectors = new SMOset(m_data.numInstances());
m_I0 = new SMOset(m_data.numInstances());
m_I1 = new SMOset(m_data.numInstances());
m_I2 = new SMOset(m_data.numInstances());
m_I3 = new SMOset(m_data.numInstances());
m_I4 = new SMOset(m_data.numInstances());
// Clean out some instance variables
m_sparseWeights = null;
m_sparseIndices = null;
// Initialize error cache
m_errors = new double[m_data.numInstances()];
m_errors[m_iLow] = 1; m_errors[m_iUp] = -1;
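// (With all alphas zero the SVM output is 0, so each initial error is
// 0 - y: +1 for the class -1 example at m_iLow, -1 for the class +1
// example at m_iUp.)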
// Initialize kernel
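// (RBF: K(x,y) = exp(-gamma*||x-y||^2); polynomial: K(x,y) = (x.y)^e,
// or (x.y+1)^e with lower-order terms; the normalized variant divides
// by sqrt(K(x,x)*K(y,y)).)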
if(m_useRBF) {
m_kernel = new RBFKernel(m_data, m_cacheSize, m_gamma);
} else {
if (m_featureSpaceNormalization) {
m_kernel = new NormalizedPolyKernel(m_data, m_cacheSize, m_exponent,
m_lowerOrder);
} else {
m_kernel = new PolyKernel(m_data, m_cacheSize, m_exponent, m_lowerOrder);
}
}
// Build up I1 and I4
for (int i = 0; i < m_class.length; i++) {
if (m_class[i] == 1) {
m_I1.insert(i);
} else {
m_I4.insert(i);
}
}
// Loop to find all the support vectors
int numChanged = 0;
boolean examineAll = true;
while ((numChanged > 0) || examineAll) {
numChanged = 0;
if (examineAll) {
for (int i = 0; i < m_alpha.length; i++) {
if (examineExample(i)) {
numChanged++;
}
}
} else {
// This code implements Modification 1 from Keerthi et al.'s paper
for (int i = 0; i < m_alpha.length; i++) {
if ((m_alpha[i] > 0) &&
(m_alpha[i] < m_C * m_data.instance(i).weight())) {
if (examineExample(i)) {
numChanged++;
}
// Is optimality on unbound vectors obtained?
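// (Keerthi et al.'s stopping criterion: the KKT conditions hold
// within tolerance once m_bUp > m_bLow - 2 * m_tol.)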
if (m_bUp > m_bLow - 2 * m_tol) {
numChanged = 0;
break;
}
}
}
// This is the code for Modification 2 from Keerthi et al.'s paper
/*boolean innerLoopSuccess = true;
numChanged = 0;
while ((m_bUp < m_bLow - 2 * m_tol) && (innerLoopSuccess == true)) {
innerLoopSuccess = takeStep(m_iUp, m_iLow, m_errors[m_iLow]);
}*/
}