📄 SMO.java
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * SMO.java
 * Copyright (C) 1999 Eibe Frank
 *
 */

package weka.classifiers.functions;

import weka.classifiers.Classifier;
import weka.classifiers.DistributionClassifier;
import weka.classifiers.Evaluation;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.Standardize;
import weka.filters.Filter;
import java.util.*;
import java.io.*;
import weka.core.*;
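// Usage sketch (illustrative only; the ARFF file name and option values
// below are assumptions, not part of the original source). The classifier
// is driven through the standard Weka 3.x API:
//
//   Instances data = new Instances(
//       new BufferedReader(new FileReader("train.arff")));  // hypothetical file
//   data.setClassIndex(data.numAttributes() - 1);
//   SMO smo = new SMO();
//   smo.setOptions(new String[] {"-C", "1.0", "-E", "1.0"}); // see options below
//   smo.buildClassifier(data);
//   double[] dist = smo.distributionForInstance(data.instance(0));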
/**
 * Implements John C. Platt's sequential minimal optimization
 * algorithm for training a support vector classifier using polynomial
 * or RBF kernels.
 *
 * This implementation globally replaces all missing values and
 * transforms nominal attributes into binary ones. It also
 * normalizes all attributes by default. (Note that the coefficients
 * in the output are based on the normalized/standardized data, not the
 * original data.)
 *
 * Multi-class problems are solved using pairwise classification.
 *
 * To obtain proper probability estimates, use the option that fits
 * logistic regression models to the outputs of the support vector
 * machine. In the multi-class case the predicted probabilities
 * will be coupled using Hastie and Tibshirani's pairwise coupling
 * method.
 *
 * Note: for improved speed, standardization should be turned off when
 * operating on SparseInstances.<p>
 *
 * For more information on the SMO algorithm, see<p>
 *
 * J. Platt (1998). <i>Fast Training of Support Vector
 * Machines using Sequential Minimal Optimization</i>. Advances in Kernel
 * Methods - Support Vector Learning, B. Schölkopf, C. Burges, and
 * A. Smola, eds., MIT Press. <p>
 *
 * S.S. Keerthi, S.K. Shevade, C. Bhattacharyya, K.R.K. Murthy,
 * <i>Improvements to Platt's SMO Algorithm for SVM Classifier Design</i>.
 * Neural Computation, 13(3), pp 637-649, 2001. <p>
 *
 * Valid options are:<p>
 *
 * -C num <br>
 * The complexity constant C. (default 1)<p>
 *
 * -E num <br>
 * The exponent for the polynomial kernel. (default 1)<p>
 *
 * -G num <br>
 * Gamma for the RBF kernel. (default 0.01)<p>
 *
 * -N <0|1|2> <br>
 * Whether to 0=normalize/1=standardize/2=neither. (default 0=normalize)<p>
 *
 * -F <br>
 * Feature-space normalization (only for non-linear polynomial kernels). <p>
 *
 * -O <br>
 * Use lower-order terms (only for non-linear polynomial kernels). <p>
 *
 * -R <br>
 * Use the RBF kernel. (default poly)<p>
 *
 * -A num <br>
 * Sets the size of the kernel cache. Should be a prime number.
 * (default 1000003) <p>
 *
 * -T num <br>
 * Sets the tolerance parameter. (default 1.0e-3)<p>
 *
 * -P num <br>
 * Sets the epsilon for round-off error. (default 1.0e-12)<p>
 *
 * -M <br>
 * Fit logistic models to SVM outputs.<p>
 *
 * -V num <br>
 * Number of runs for cross-validation used to generate data
 * for logistic models. (default -1, use training data)<p>
 *
 * -W num <br>
 * Random number seed for cross-validation. (default 1)<p>
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author Shane Legg (shane@intelligenesis.net) (sparse vector code)
 * @author Stuart Inglis (stuart@reeltwo.com) (sparse vector code)
 * @author J. Lindgren (jtlindgr{at}cs.helsinki.fi) (RBF kernel)
 * @version $Revision: 1.1.1.1 $
 */
public class SMO extends DistributionClassifier
  implements OptionHandler, WeightedInstancesHandler {

  /**
   * Class for building a binary support vector machine.
   */
  private class BinarySMO implements Serializable {

    /**
     * Calculates a dot product between two instances.
     */
    private double dotProd(Instance inst1, Instance inst2)
      throws Exception {

      double result = 0;

      // We can do a fast dot product
      int n1 = inst1.numValues();
      int n2 = inst2.numValues();
      int classIndex = m_data.classIndex();
      for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) {
        int ind1 = inst1.index(p1);
        int ind2 = inst2.index(p2);
        if (ind1 == ind2) {
          if (ind1 != classIndex) {
            result += inst1.valueSparse(p1) * inst2.valueSparse(p2);
          }
          p1++;
          p2++;
        } else if (ind1 > ind2) {
          p2++;
        } else {
          p1++;
        }
      }
      return result;
    }

    /**
     * Abstract kernel.
     */
    private abstract class Kernel implements Serializable {

      /**
       * Computes the result of the kernel function for two instances.
       *
       * @param id1 the index of the first instance
       * @param id2 the index of the second instance
       * @param inst1 the instance corresponding to id1
       * @return the result of the kernel function
       */
      public abstract double eval(int id1, int id2, Instance inst1)
        throws Exception;
    }

    /**
     * The normalized polynomial kernel.
     */
    private class NormalizedPolyKernel extends Kernel {

      PolyKernel polyKernel = new PolyKernel();

      public double eval(int id1, int id2, Instance inst1)
        throws Exception {

        return polyKernel.eval(id1, id2, inst1)
          / Math.sqrt(polyKernel.eval(id1, id1, inst1)
                      * polyKernel.eval(id2, id2, m_data.instance(id2)));
      }
    }

    /**
     * The polynomial kernel.
     */
    private class PolyKernel extends Kernel {

      public double eval(int id1, int id2, Instance inst1)
        throws Exception {

        double result = 0;
        long key = -1;
        int location = -1;

        // We can only cache if we know the indexes
        if ((id1 >= 0) && (m_keys != null)) {
          if (id1 > id2) {
            key = (long) id1 * m_alpha.length + id2;
          } else {
            key = (long) id2 * m_alpha.length + id1;
          }
          if (key < 0) {
            throw new Exception("Cache overflow detected!");
          }
          location = (int) (key % m_keys.length);
          if (m_keys[location] == (key + 1)) {
            return m_storage[location];
          }
        }

        if (id1 == id2) {
          result = dotProd(inst1, inst1);
        } else {
          result = dotProd(inst1, m_data.instance(id2));
        }

        // Use lower order terms?
        if (m_lowerOrder) {
          result += 1.0;
        }
        if (m_exponent != 1.0) {
          result = Math.pow(result, m_exponent);
        }
        m_kernelEvals++;

        // Store result in cache
        if (key != -1) {
          m_storage[location] = result;
          m_keys[location] = (key + 1);
        }
        return result;
      }
    }

    /**
     * The RBF kernel.
     */
    private class RBFKernel extends Kernel {

      /** The precalculated dot products of <inst_i,inst_i> */
      private double m_kernelPrecalc[];

      /**
       * Constructor. Initializes m_kernelPrecalc[].
       */
      public RBFKernel(Instances data) throws Exception {

        m_kernelPrecalc = new double[data.numInstances()];
        for (int i = 0; i < data.numInstances(); i++) {
          m_kernelPrecalc[i] = dotProd(data.instance(i), data.instance(i));
        }
      }

      public double eval(int id1, int id2, Instance inst1)
        throws Exception {

        double result = 0;
        long key = -1;
        int location = -1;

        // We can only cache if we know the indexes
        if (id1 >= 0) {
          if (id1 > id2) {
            key = (long) id1 * m_alpha.length + id2;
          } else {
            key = (long) id2 * m_alpha.length + id1;
          }
          if (key < 0) {
            throw new Exception("Cache overflow detected!");
          }
          location = (int) (key % m_keys.length);
          if (m_keys[location] == (key + 1)) {
            return m_storage[location];
          }
        }

        Instance inst2 = m_data.instance(id2);
        double precalc1;
        if (id1 == -1) {
          precalc1 = dotProd(inst1, inst1);
        } else {
          precalc1 = m_kernelPrecalc[id1];
        }
        result = Math.exp(m_gamma * (2. * dotProd(inst1, inst2)
                                     - precalc1 - m_kernelPrecalc[id2]));
        m_kernelEvals++;

        // Store result in cache
        if (key != -1) {
          m_storage[location] = result;
          m_keys[location] = (key + 1);
        }
        return result;
      }
    }
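    // Note on the kernel cache shared by both kernels above (a sketch with
    // illustrative numbers, not part of the original source): each instance
    // pair is folded into a single key,
    //
    //   key = max(id1, id2) * m_alpha.length + min(id1, id2)
    //
    // so for m_alpha.length == 1000 the pair (7, 3) gives key 7003, stored
    // at slot key % m_keys.length. The cache records key + 1 rather than
    // key so that an empty slot (0) can never be mistaken for key 0.
    //
    // The RBF evaluation relies on the identity
    //   ||x1 - x2||^2 = <x1,x1> - 2*<x1,x2> + <x2,x2>,
    // so eval() returns exp(-gamma * ||x1 - x2||^2) while computing only
    // one new dot product per call (the self dot products are precalculated
    // in the constructor).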
    /**
     * Stores a set of a given size.
     */
    private class SMOset implements Serializable {

      /** The current number of elements in the set */
      private int m_number;

      /** The first element in the set */
      private int m_first;

      /** Indicators */
      private boolean[] m_indicators;

      /** The next element for each element */
      private int[] m_next;

      /** The previous element for each element */
      private int[] m_previous;

      /**
       * Creates a new set of the given size.
       */
      private SMOset(int size) {

        m_indicators = new boolean[size];
        m_next = new int[size];
        m_previous = new int[size];
        m_number = 0;
        m_first = -1;
      }

      /**
       * Checks whether an element is in the set.
       */
      private boolean contains(int index) {

        return m_indicators[index];
      }

      /**
       * Deletes an element from the set.
       */
      private void delete(int index) {

        if (m_indicators[index]) {
          if (m_first == index) {
            m_first = m_next[index];
          } else {
            m_next[m_previous[index]] = m_next[index];
          }
          if (m_next[index] != -1) {
            m_previous[m_next[index]] = m_previous[index];
          }
          m_indicators[index] = false;
          m_number--;
        }
      }

      /**
       * Inserts an element into the set.
       */
      private void insert(int index) {

        if (!m_indicators[index]) {
          if (m_number == 0) {
            m_first = index;
            m_next[index] = -1;
            m_previous[index] = -1;
          } else {
            m_previous[m_first] = index;
            m_next[index] = m_first;
            m_previous[index] = -1;
            m_first = index;
          }
          m_indicators[index] = true;
          m_number++;
        }
      }

      /**
       * Gets the next element in the set. -1 gets the first one.
       */
      private int getNext(int index) {

        if (index == -1) {
          return m_first;
        } else {
          return m_next[index];
        }
      }

      /**
       * Prints all the current elements in the set.
       */
      private void printElements() {

        for (int i = getNext(-1); i != -1; i = getNext(i)) {
          System.err.print(i + " ");
        }
        System.err.println();
        for (int i = 0; i < m_indicators.length; i++) {
          if (m_indicators[i]) {
            System.err.print(i + " ");
          }
        }
        System.err.println();
        System.err.println(m_number);
      }

      /**
       * Returns the number of elements in the set.
       */
      private int numElements() {

        return m_number;
      }
    }
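    // Usage sketch for SMOset (illustrative only, not part of the original
    // source). Elements are ints in [0, size); insert and delete are O(1)
    // because each element stores its neighbours in a doubly-linked list:
    //
    //   SMOset set = new SMOset(10);
    //   set.insert(3);
    //   set.insert(7);
    //   for (int i = set.getNext(-1); i != -1; i = set.getNext(i)) {
    //     System.out.println(i);  // visits 7, then 3 (most recent first)
    //   }
    //   set.delete(7);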
    /** The Lagrange multipliers. */
    private double[] m_alpha;

    /** The thresholds. */
    private double m_b, m_bLow, m_bUp;

    /** The indices for m_bLow and m_bUp */
    private int m_iLow, m_iUp;

    /** The training data. */
    private Instances m_data;

    /** Weight vector for linear machine. */
    private double[] m_weights;

    /** Variables to hold the weight vector in sparse form.
        (To reduce storage requirements.) */
    private double[] m_sparseWeights;
    private int[] m_sparseIndices;

    /** Kernel to use **/
    private Kernel m_kernel;

    /** Kernel function cache */
    private double[] m_storage;
    private long[] m_keys;

    /** The transformed class values. */
    private double[] m_class;

    /** The current set of errors for all non-bound examples. */
    private double[] m_errors;

    /** The five different sets used by the algorithm. */
    private SMOset m_I0; // {i: 0 < m_alpha[i] < C}
    private SMOset m_I1; // {i: m_class[i] = 1, m_alpha[i] = 0}
    private SMOset m_I2; // {i: m_class[i] = -1, m_alpha[i] = C}
    private SMOset m_I3; // {i: m_class[i] = 1, m_alpha[i] = C}
    private SMOset m_I4; // {i: m_class[i] = -1, m_alpha[i] = 0}

    /** The set of support vectors */
    private SMOset m_supportVectors; // {i: 0 < m_alpha[i]}

    /** Counts the number of kernel evaluations. */
    private int m_kernelEvals;

    /** Stores the logistic regression model for probability estimates */
    private Logistic m_logistic = null;

    /** Stores the weight of the training instances */
    private double m_sumOfWeights = 0;

    /**
     * Fits a logistic regression model to the SVM outputs, analogous
     * to John Platt's method.
     *
     * @param insts the set of training instances
     * @param cl1 the first class' index
     * @param cl2 the second class' index
     * @param numFolds the number of cross-validation folds used to generate
     *                 the data (a value <= 0 means: use the training data)
     * @param random the random number generator for the cross-validation
     * @exception Exception if the sigmoid can't be fit successfully
     */
    private void fitLogistic(Instances insts, int cl1, int cl2,
                             int numFolds, Random random)
      throws Exception {

      // Create header of instances object
      FastVector atts = new FastVector(2);
      atts.addElement(new Attribute("pred"));
      FastVector attVals = new FastVector(2);
      attVals.addElement(insts.classAttribute().value(cl1));
      attVals.addElement(insts.classAttribute().value(cl2));
      atts.addElement(new Attribute("class", attVals));
      Instances data = new Instances("data", atts, insts.numInstances());
      data.setClassIndex(1);

      // Collect data for fitting the logistic model
      if (numFolds <= 0) {

        // Use training data
        for (int j = 0; j < insts.numInstances(); j++) {
          Instance inst = insts.instance(j);
          double[] vals = new double[2];
          vals[0] = SVMOutput(-1, inst);
          if (inst.classValue() == cl2) {
            vals[1] = 1;
          }
          data.add(new Instance(inst.weight(), vals));
        }
      } else {

        // Check whether the number of folds is too large
        if (numFolds > insts.numInstances()) {
          numFolds = insts.numInstances();
        }

        // Make a copy of the instances because we will shuffle them around
        insts = new Instances(insts);

        // Perform cross-validation to collect unbiased predictions
        insts.randomize(random);
        insts.stratify(numFolds);
        for (int i = 0; i < numFolds; i++) {
          Instances train = insts.trainCV(numFolds, i);
          SerializedObject so = new SerializedObject(this);
          BinarySMO smo = (BinarySMO) so.getObject();
          smo.buildClassifier(train, cl1, cl2, false, -1, -1);
          Instances test = insts.testCV(numFolds, i);
          for (int j = 0; j < test.numInstances(); j++) {