// File: smo.java (from MacroWeka, an extension of the well-known data mining tool Weka)
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    SMO.java
 *    Copyright (C) 1999 Eibe Frank
 *
 */

package weka.classifiers.functions;

import weka.classifiers.functions.supportVector.*;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.Logistic;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.Standardize;
import weka.filters.Filter;
import java.util.*;
import java.io.*;
import weka.core.*;

/**
 * Implements John C. Platt's sequential minimal optimization
 * algorithm for training a support vector classifier using polynomial
 * or RBF kernels. 
 *
 * This implementation globally replaces all missing values and
 * transforms nominal attributes into binary ones. It also
 * normalizes all attributes by default. (Note that the coefficients
 * in the output are based on the normalized/standardized data, not the
 * original data.)
 *
 * Multi-class problems are solved using pairwise classification.
 *
 * To obtain proper probability estimates, use the option that fits
 * logistic regression models to the outputs of the support vector
 * machine. In the multi-class case the predicted probabilities
 * will be coupled using Hastie and Tibshirani's pairwise coupling
 * method.
 *
 * Note: for improved speed, standardization should be turned off when
 * operating on SparseInstances.<p>
 *
 * For more information on the SMO algorithm, see<p>
 *
 * J. Platt (1998). <i>Fast Training of Support Vector
 * Machines using Sequential Minimal Optimization</i>. Advances in Kernel
 * Methods - Support Vector Learning, B. Schoelkopf, C. Burges, and
 * A. Smola, eds., MIT Press. <p>
 *
 * S.S. Keerthi, S.K. Shevade, C. Bhattacharyya, K.R.K. Murthy, 
 * <i>Improvements to Platt's SMO Algorithm for SVM Classifier Design</i>. 
 * Neural Computation, 13(3), pp 637-649, 2001. <p>
 *
 * Valid options are:<p>
 *
 * -C num <br>
 * The complexity constant C. (default 1)<p>
 *
 * -E num <br>
 * The exponent for the polynomial kernel. (default 1)<p>
 *
 * -G num <br>
 * Gamma for the RBF kernel. (default 0.01)<p>
 *
 * -N &lt;0|1|2&gt; <br>
 * Whether to normalize (0), standardize (1) or do neither (2). (default 0 = normalize)<p>
 *
 * -F <br>
 * Feature-space normalization (only for non-linear polynomial kernels). <p>
 *
 * -O <br>
 * Use lower-order terms (only for non-linear polynomial kernels). <p>
 *
 * -R <br>
 * Use the RBF kernel. (default poly)<p>
 *
 * -A num <br>
 * Sets the size of the kernel cache. Should be a prime number. 
 * (default 250007, use 0 for full cache) <p>
 *
 * -L num <br>
 * Sets the tolerance parameter. (default 1.0e-3)<p>
 *
 * -P num <br>
 * Sets the epsilon for round-off error. (default 1.0e-12)<p>
 *
 * -M <br>
 * Fit logistic models to SVM outputs.<p>
 *
 * -V num <br>
 * Number of folds for cross-validation used to generate data
 * for logistic models. (default -1, use training data)<p>
 *
 * -W num <br>
 * Random number seed for cross-validation. (default 1)<p>
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author Shane Legg (shane@intelligenesis.net) (sparse vector code)
 * @author Stuart Inglis (stuart@reeltwo.com) (sparse vector code)
 * @version $Revision: 1.1 $ */
public class SMO extends Classifier implements WeightedInstancesHandler {
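  /*
    Sketch (hypothetical helper, not part of Weka): with pairwise
    classification, k classes need k(k-1)/2 binary machines, one per
    class pair i < j. Without logistic models, one simple way to combine
    them is one vote per machine. The sketch assumes a positive output for
    pair (i, j) is a vote for class j, mirroring how
    BinarySMO.buildClassifier maps cl2 to +1; the class name, method names
    and the BiFunction callback are illustrative only.

```java
import java.util.function.BiFunction;

// Hypothetical helper illustrating one-vs-one (pairwise) voting.
final class PairwiseVotingSketch {

  // Number of binary machines for k classes: k * (k - 1) / 2.
  static int numMachines(int numClasses) {
    return numClasses * (numClasses - 1) / 2;
  }

  // output.apply(i, j) is the real-valued output of the machine for
  // classes i < j; a positive value counts as a vote for class j.
  static double[] vote(int numClasses, BiFunction<Integer, Integer, Double> output) {
    double[] votes = new double[numClasses];
    for (int i = 0; i < numClasses; i++) {
      for (int j = i + 1; j < numClasses; j++) {
        if (output.apply(i, j) > 0) {
          votes[j] += 1;
        } else {
          votes[i] += 1;
        }
      }
    }
    // Turn vote counts into fractions of all machines.
    int total = numMachines(numClasses);
    if (total > 0) {
      for (int c = 0; c < numClasses; c++) {
        votes[c] /= total;
      }
    }
    return votes;
  }
}
```

    For example, with three classes and every machine returning a
    positive value, the votes come out as 0, 1/3 and 2/3.
  */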

  /**
   * Returns a string describing the classifier.
   * @return a description suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {

    return  "Implements John Platt's sequential minimal optimization "
      + "algorithm for training a support vector classifier.\n\n"
      + "This implementation globally replaces all missing values and "
      + "transforms nominal attributes into binary ones. It also "
      + "normalizes all attributes by default. (In that case the coefficients "
      + "in the output are based on the normalized data, not the "
      + "original data --- this is important for interpreting the classifier.)\n\n"
      + "Multi-class problems are solved using pairwise classification.\n\n"
      + "To obtain proper probability estimates, use the option that fits "
      + "logistic regression models to the outputs of the support vector "
      + "machine. In the multi-class case the predicted probabilities "
      + "are coupled using Hastie and Tibshirani's pairwise coupling "
      + "method.\n\n"
      + "Note: for improved speed, normalization should be turned off when "
      + "operating on SparseInstances.\n\n"
      + "For more information on the SMO algorithm, see\n\n"
      + "J. Platt (1998). \"Fast Training of Support Vector "
      + "Machines using Sequential Minimal Optimization\". Advances in Kernel "
      + "Methods - Support Vector Learning, B. Schoelkopf, C. Burges, and "
      + "A. Smola, eds., MIT Press. \n\n"
      + "S.S. Keerthi, S.K. Shevade, C. Bhattacharyya, K.R.K. Murthy, "
      + "\"Improvements to Platt's SMO Algorithm for SVM Classifier Design\". "
      + "Neural Computation, 13(3), pp 637-649, 2001.";
  }
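  /*
    Sketch (hypothetical, not Weka's own code): Hastie and Tibshirani's
    pairwise coupling turns pairwise estimates r[i][j], the probability of
    class i given that the class is i or j, into one distribution p. Each
    p_i is rescaled by sum_j n_ij r_ij divided by sum_j n_ij mu_ij, where
    mu_ij = p_i / (p_i + p_j), and p is then renormalized. This sketch uses
    a simultaneous-update variant; the pair weights n[i][j] (for example the
    number of training instances in classes i and j), the uniform start and
    the stopping rule are assumptions.

```java
// Hypothetical sketch of Hastie and Tibshirani's pairwise coupling
// (simultaneous-update variant).
final class PairwiseCouplingSketch {

  // r[i][j] (i < j) estimates P(class i | class i or j); n[i][j] weights the pair.
  // Iterates until no probability moves by more than tol, or maxIter is reached.
  static double[] couple(double[][] n, double[][] r, int maxIter, double tol) {
    int k = r.length;
    double[] p = new double[k];
    java.util.Arrays.fill(p, 1.0 / k);

    // Observed side sum_j n_ij r_ij, using r_ji = 1 - r_ij; it never changes.
    double[] observed = new double[k];
    for (int i = 0; i < k; i++) {
      for (int j = i + 1; j < k; j++) {
        observed[i] += n[i][j] * r[i][j];
        observed[j] += n[i][j] * (1 - r[i][j]);
      }
    }

    for (int iter = 0; iter < maxIter; iter++) {
      // Fitted side sum_j n_ij mu_ij, with mu_ij = p_i / (p_i + p_j).
      double[] fitted = new double[k];
      for (int i = 0; i < k; i++) {
        for (int j = i + 1; j < k; j++) {
          double denom = p[i] + p[j];
          double mu = denom > 0 ? p[i] / denom : 0.5;
          fitted[i] += n[i][j] * mu;
          fitted[j] += n[i][j] * (1 - mu);
        }
      }
      // Rescale each p_i by observed / fitted, then renormalize to sum to one.
      double[] next = new double[k];
      double sum = 0;
      for (int i = 0; i < k; i++) {
        next[i] = fitted[i] > 0 ? p[i] * observed[i] / fitted[i] : 0;
        sum += next[i];
      }
      if (sum <= 0) {
        break;                       // degenerate input; keep the last estimate
      }
      double maxChange = 0;
      for (int i = 0; i < k; i++) {
        next[i] /= sum;
        maxChange = Math.max(maxChange, Math.abs(next[i] - p[i]));
      }
      p = next;
      if (maxChange < tol) {
        break;
      }
    }
    return p;
  }
}
```

    With consistent inputs, i.e. r[i][j] = q_i / (q_i + q_j) for some
    distribution q, the iteration recovers q.
  */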

  /**
   * Class for building a binary support vector machine.
   */
  protected class BinarySMO implements Serializable {
    
    /** The Lagrange multipliers. */
    protected double[] m_alpha;

    /** The thresholds. */
    protected double m_b, m_bLow, m_bUp;

    /** The indices for m_bLow and m_bUp. */
    protected int m_iLow, m_iUp;

    /** The training data. */
    protected Instances m_data;

    /** Weight vector for linear machine. */
    protected double[] m_weights;

    /** Variables to hold weight vector in sparse form.
	(To reduce storage requirements.) */
    protected double[] m_sparseWeights;
    protected int[] m_sparseIndices;

    /** Kernel to use **/
    protected Kernel m_kernel;

    /** The transformed class values. */
    protected double[] m_class;

    /** The current set of errors for all non-bound examples. */
    protected double[] m_errors;

    /** The five different sets used by the algorithm. */
    protected SMOset m_I0; // {i: 0 < m_alpha[i] < C}
    protected SMOset m_I1; // {i: m_class[i] = 1, m_alpha[i] = 0}
    protected SMOset m_I2; // {i: m_class[i] = -1, m_alpha[i] = C}
    protected SMOset m_I3; // {i: m_class[i] = 1, m_alpha[i] = C}
    protected SMOset m_I4; // {i: m_class[i] = -1, m_alpha[i] = 0}

    /** The set of support vectors */
    protected SMOset m_supportVectors; // {i: 0 < m_alpha[i]}
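    /*
      Sketch (hypothetical helper, not part of this class): how an
      example's label y and multiplier alpha place it in one of Keerthi et
      al.'s five sets. The upper bound c stands for C times the instance
      weight, as in the Modification 1 loop of buildClassifier. In their
      paper, I0, I1 and I2 together are the indices whose y_i * alpha_i can
      still increase, and I0, I3 and I4 those whose y_i * alpha_i can still
      decrease.

```java
// Hypothetical helper mirroring the set definitions in the field comments.
final class SmoSetsSketch {

  // Returns 0..4 for I0..I4, given label y (+1 or -1), multiplier alpha and
  // upper bound c (C times the instance weight in this implementation).
  static int setOf(double y, double alpha, double c) {
    if (alpha > 0 && alpha < c) {
      return 0;                      // unbound: 0 < alpha < c
    }
    if (y == 1) {
      return alpha <= 0 ? 1 : 3;     // y = +1: alpha = 0 gives I1, alpha = c gives I3
    }
    return alpha <= 0 ? 4 : 2;       // y = -1: alpha = 0 gives I4, alpha = c gives I2
  }
}
```

      Exact comparisons follow the set comments above; a production
      version might compare against 0 and c with a small epsilon.
    */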

    /** Stores logistic regression model for probability estimate */
    protected Logistic m_logistic = null;

    /** Stores the weight of the training instances */
    protected double m_sumOfWeights = 0;

    /**
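     * Fits a logistic regression model to the SVM outputs, analogous
     * to John Platt's method.
     *
     * @param insts the set of training instances
     * @param cl1 the first class' index
     * @param cl2 the second class' index
     * @param numFolds the number of folds for internal cross-validation
     * (zero or less means the training data itself is used)
     * @param random the random number generator for cross-validation
     * @exception Exception if the sigmoid can't be fit successfully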
     */
    protected void fitLogistic(Instances insts, int cl1, int cl2,
			     int numFolds, Random random) 
      throws Exception {

      // Create header of instances object
      FastVector atts = new FastVector(2);
      atts.addElement(new Attribute("pred"));
      FastVector attVals = new FastVector(2);
      attVals.addElement(insts.classAttribute().value(cl1));
      attVals.addElement(insts.classAttribute().value(cl2));
      atts.addElement(new Attribute("class", attVals));
      Instances data = new Instances("data", atts, insts.numInstances());
      data.setClassIndex(1);

      // Collect data for fitting the logistic model
      if (numFolds <= 0) {

	// Use training data
	for (int j = 0; j < insts.numInstances(); j++) {
	  Instance inst = insts.instance(j);
	  double[] vals = new double[2];
	  vals[0] = SVMOutput(-1, inst);
	  if (inst.classValue() == cl2) {
	    vals[1] = 1;
	  }
	  data.add(new Instance(inst.weight(), vals));
	}
      } else {

	// Check whether number of folds too large
	if (numFolds > insts.numInstances()) {
	  numFolds = insts.numInstances();
	}

	// Make copy of instances because we will shuffle them around
	insts = new Instances(insts);
	
	// Perform numFolds-fold cross-validation to collect
	// unbiased predictions
	insts.randomize(random);
	insts.stratify(numFolds);
	for (int i = 0; i < numFolds; i++) {
	  Instances train = insts.trainCV(numFolds, i, random);
	  SerializedObject so = new SerializedObject(this);
	  BinarySMO smo = (BinarySMO)so.getObject();
	  smo.buildClassifier(train, cl1, cl2, false, -1, -1);
	  Instances test = insts.testCV(numFolds, i);
	  for (int j = 0; j < test.numInstances(); j++) {
	    double[] vals = new double[2];
	    vals[0] = smo.SVMOutput(-1, test.instance(j));
	    if (test.instance(j).classValue() == cl2) {
	      vals[1] = 1;
	    }
	    data.add(new Instance(test.instance(j).weight(), vals));
	  }
	}
      }
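      /*
        Sketch (hypothetical; an assumed fold layout, not a description of
        Instances.trainCV or testCV): after shuffling and stratifying, each
        fold's test set can be taken as a contiguous block, with the first
        n mod k folds one instance larger. Every instance then lands in
        exactly one test fold, so it contributes exactly one out-of-fold SVM
        output to the logistic model's data. With k > n some blocks would be
        empty, which the cap on numFolds above avoids.

```java
// Hypothetical sketch of contiguous k-fold test blocks (assumed layout).
final class FoldLayoutSketch {

  // Returns {first, size} of the test block for fold f (0-based) of k folds
  // over n instances; the first (n % k) folds get one extra instance.
  static int[] testBlock(int n, int k, int f) {
    int size = n / k;
    int offset;
    if (f < n % k) {
      size++;
      offset = f;
    } else {
      offset = n % k;
    }
    return new int[] {f * (n / k) + offset, size};
  }
}
```

        With n = 10 and k = 3 the blocks start at 0, 4 and 7 with sizes
        4, 3 and 3.
      */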

      // Build logistic regression model
      m_logistic = new Logistic();
      m_logistic.buildClassifier(data);
    }
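    /*
      Sketch (hypothetical, a swapped-in technique): fitLogistic delegates
      to Weka's Logistic class. The self-contained stand-in below instead
      fits P(second class | f) = 1 / (1 + exp(-(a * f + b))) to weighted SVM
      outputs f with plain Newton-Raphson; the tiny ridge on a is an
      assumption that only keeps the 2x2 Hessian invertible. Label 1 means
      the second class (cl2), as in fitLogistic.

```java
// Hypothetical stand-in for fitting a sigmoid to SVM outputs (not Weka's Logistic).
final class SigmoidFitSketch {

  final double a;
  final double b;

  private SigmoidFitSketch(double a, double b) {
    this.a = a;
    this.b = b;
  }

  double probability(double f) {
    return 1.0 / (1.0 + Math.exp(-(a * f + b)));
  }

  // f: SVM outputs, y: 1 for the second class else 0, w: instance weights.
  static SigmoidFitSketch fit(double[] f, double[] y, double[] w, double ridge) {
    double a = 0, b = 0;
    for (int iter = 0; iter < 100; iter++) {
      // Gradient of the penalized log-likelihood and Hessian of its negative.
      double ga = -ridge * a, gb = 0, haa = ridge, hab = 0, hbb = 0;
      for (int i = 0; i < f.length; i++) {
        double p = 1.0 / (1.0 + Math.exp(-(a * f[i] + b)));
        double r = w[i] * (y[i] - p);
        double s = w[i] * p * (1 - p);
        ga += r * f[i];
        gb += r;
        haa += s * f[i] * f[i];
        hab += s * f[i];
        hbb += s;
      }
      double det = haa * hbb - hab * hab;
      if (det <= 0) {
        break;                       // Hessian not positive definite; stop
      }
      // Newton step: solve H * delta = g for the 2x2 system.
      double da = (hbb * ga - hab * gb) / det;
      double db = (haa * gb - hab * ga) / det;
      a += da;
      b += db;
      if (Math.abs(da) + Math.abs(db) < 1e-12) {
        break;
      }
    }
    return new SigmoidFitSketch(a, b);
  }
}
```

      Platt writes the sigmoid as 1 / (1 + exp(A f + B)), so A = -a and
      B = -b. For outputs -1 and 1 whose weighted labels give empirical
      rates 0.25 and 0.75, the fit ends with a close to ln 3 and b close
      to 0.
    */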

    /**
     * Method for building the binary classifier.
     *
     * @param insts the set of training instances
     * @param cl1 the first class' index
     * @param cl2 the second class' index
     * @param fitLogistic true if logistic model is to be fit
     * @param numFolds number of folds for internal cross-validation
     * @param randomSeed the random number seed for cross-validation
     * @exception Exception if the classifier can't be built successfully
     */
    protected void buildClassifier(Instances insts, int cl1, int cl2,
				 boolean fitLogistic, int numFolds,
				 int randomSeed) throws Exception {
      
      // Initialize some variables
      m_bUp = -1; m_bLow = 1; m_b = 0; 
      m_alpha = null; m_data = null; m_weights = null; m_errors = null;
      m_logistic = null; m_I0 = null; m_I1 = null; m_I2 = null;
      m_I3 = null; m_I4 = null; m_sparseWeights = null; m_sparseIndices = null;

      // Store the sum of weights
      m_sumOfWeights = insts.sumOfWeights();
      
      // Set class values
      m_class = new double[insts.numInstances()];
      m_iUp = -1; m_iLow = -1;
      for (int i = 0; i < m_class.length; i++) {
	if ((int) insts.instance(i).classValue() == cl1) {
	  m_class[i] = -1; m_iLow = i;
	} else if ((int) insts.instance(i).classValue() == cl2) {
	  m_class[i] = 1; m_iUp = i;
	} else {
	  throw new Exception ("This should never happen!");
	}
      }

      // Check whether one or both classes are missing
      if ((m_iUp == -1) || (m_iLow == -1)) {
	if (m_iUp != -1) {
	  m_b = -1;
	} else if (m_iLow != -1) {
	  m_b = 1;
	} else {
	  m_class = null;
	  return;
	}
	if (!m_useRBF && m_exponent == 1.0) {
	  m_sparseWeights = new double[0];
	  m_sparseIndices = new int[0];
	  m_class = null;
	} else {
	  m_supportVectors = new SMOset(0);
	  m_alpha = new double[0];
	  m_class = new double[0];
	}

	// Fit sigmoid if requested
	if (fitLogistic) {
	  fitLogistic(insts, cl1, cl2, numFolds, new Random(randomSeed));
	}
	return;
      }
      
      // Set the reference to the data
      m_data = insts;

      // If machine is linear, reserve space for weights
      if (!m_useRBF && m_exponent == 1.0) {
	m_weights = new double[m_data.numAttributes()];
      } else {
	m_weights = null;
      }
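      /*
        Sketch (hypothetical helper): for the linear kernel (no RBF,
        exponent 1) the kernel expansion collapses into one weight vector
        w = sum_i alpha_i y_i x_i, which is why only m_weights needs to be
        kept. The sketch assumes the output convention f(x) = w . x - b,
        consistent with a positive output being read as cl2 (+1); dense
        double[] vectors stand in for Weka Instances.

```java
// Hypothetical helper: collapse a linear SVM into an explicit weight vector.
final class LinearWeightsSketch {

  static double[] weights(double[][] x, double[] y, double[] alpha) {
    double[] w = new double[x[0].length];
    for (int i = 0; i < x.length; i++) {
      if (alpha[i] == 0) {
        continue;                    // only support vectors contribute
      }
      for (int d = 0; d < w.length; d++) {
        w[d] += alpha[i] * y[i] * x[i][d];
      }
    }
    return w;
  }

  // Output with the assumed convention f(x) = w . x - b.
  static double output(double[] w, double b, double[] point) {
    double sum = 0;
    for (int d = 0; d < w.length; d++) {
      sum += w[d] * point[d];
    }
    return sum - b;
  }

  // Same output via the expansion sum_i alpha_i y_i (x_i . point) - b.
  static double kernelOutput(double[][] x, double[] y, double[] alpha, double b, double[] point) {
    double sum = 0;
    for (int i = 0; i < x.length; i++) {
      double dot = 0;
      for (int d = 0; d < point.length; d++) {
        dot += x[i][d] * point[d];
      }
      sum += alpha[i] * y[i] * dot;
    }
    return sum - b;
  }
}
```

        Both outputs agree, but the weight-vector form costs one dot
        product per prediction instead of one per support vector.
      */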
      
      // Initialize alpha array to zero
      m_alpha = new double[m_data.numInstances()];
      
      // Initialize sets
      m_supportVectors = new SMOset(m_data.numInstances());
      m_I0 = new SMOset(m_data.numInstances());
      m_I1 = new SMOset(m_data.numInstances());
      m_I2 = new SMOset(m_data.numInstances());
      m_I3 = new SMOset(m_data.numInstances());
      m_I4 = new SMOset(m_data.numInstances());

      // Clean out some instance variables
      m_sparseWeights = null;
      m_sparseIndices = null;
      
      // Initialize error cache
      m_errors = new double[m_data.numInstances()];
      m_errors[m_iLow] = 1; m_errors[m_iUp] = -1;
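      /*
        Sketch (hypothetical helper): in Keerthi et al.'s formulation the
        cached value for example i is
        F_i = sum_j alpha_j y_j K(x_i, x_j) - y_i.
        With every alpha at zero this reduces to F_i = -y_i, which is why
        the two seed entries are 1 for the cl1 example (y = -1) and -1 for
        the cl2 example (y = +1). A dense kernel matrix is assumed, and F
        is computed for every example rather than only non-bound ones.

```java
// Hypothetical helper: the quantities F_i held in an SMO error cache.
final class ErrorCacheSketch {

  // kernel[i][j] = K(x_i, x_j); y holds +1 / -1 labels.
  static double[] errors(double[][] kernel, double[] y, double[] alpha) {
    int n = y.length;
    double[] f = new double[n];
    for (int i = 0; i < n; i++) {
      double sum = 0;
      for (int j = 0; j < n; j++) {
        sum += alpha[j] * y[j] * kernel[i][j];
      }
      f[i] = sum - y[i];
    }
    return f;
  }
}
```
      */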
     
      // Initialize kernel
      if (m_useRBF) {
	m_kernel = new RBFKernel(m_data, m_cacheSize, m_gamma);
      } else {
	if (m_featureSpaceNormalization) {
	  m_kernel = new NormalizedPolyKernel(m_data, m_cacheSize, m_exponent, 
					      m_lowerOrder);
	} else {
	  m_kernel = new PolyKernel(m_data, m_cacheSize, m_exponent, m_lowerOrder);
	}
      }
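      /*
        Sketch (hypothetical; textbook formulas assumed rather than copied
        from the supportVector kernel classes): the polynomial kernel
        (x . z)^p, or (x . z + 1)^p when lower-order terms are used; its
        feature-space normalized form K(x, z) / sqrt(K(x, x) K(z, z)); and
        the RBF kernel exp(-gamma ||x - z||^2). Dense vectors stand in for
        Weka Instances.

```java
// Hypothetical textbook kernels on dense vectors.
final class KernelSketch {

  static double dot(double[] x, double[] z) {
    double s = 0;
    for (int d = 0; d < x.length; d++) {
      s += x[d] * z[d];
    }
    return s;
  }

  static double poly(double[] x, double[] z, double exponent, boolean lowerOrder) {
    double base = lowerOrder ? dot(x, z) + 1 : dot(x, z);
    return Math.pow(base, exponent);
  }

  // Normalization in feature space: every point gets K(x, x) = 1.
  static double normalizedPoly(double[] x, double[] z, double exponent, boolean lowerOrder) {
    return poly(x, z, exponent, lowerOrder)
        / Math.sqrt(poly(x, x, exponent, lowerOrder) * poly(z, z, exponent, lowerOrder));
  }

  static double rbf(double[] x, double[] z, double gamma) {
    double sq = 0;
    for (int d = 0; d < x.length; d++) {
      double diff = x[d] - z[d];
      sq += diff * diff;
    }
    return Math.exp(-gamma * sq);
  }
}
```

        For x = (1, 2) and z = (3, 0), x . z = 3, so the quadratic kernel
        gives 9 without and 16 with lower-order terms.
      */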
      
      // Build up I1 and I4
      for (int i = 0; i < m_class.length; i++) {
	if (m_class[i] == 1) {
	  m_I1.insert(i);
	} else {
	  m_I4.insert(i);
	}
      }
      
      // Loop to find all the support vectors
      int numChanged = 0;
      boolean examineAll = true;
      while ((numChanged > 0) || examineAll) {
	numChanged = 0;
	if (examineAll) {
	  for (int i = 0; i < m_alpha.length; i++) {
	    if (examineExample(i)) {
	      numChanged++;
	    }
	  }
	} else {
	  
	  // This code implements Modification 1 from Keerthi et al.'s paper
	  for (int i = 0; i < m_alpha.length; i++) {
	    if ((m_alpha[i] > 0) &&
		(m_alpha[i] < m_C * m_data.instance(i).weight())) {
	      if (examineExample(i)) {
		numChanged++;
	      }
	      
	      // Is optimality on unbound vectors obtained?
	      if (m_bUp > m_bLow - 2 * m_tol) {
		numChanged = 0;
		break;
	      }
	    }
	  }
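	  /*
	    Sketch (hypothetical helper): the break test above is Keerthi et
	    al.'s approximate optimality condition. b_up is the smallest F_i
	    over I0, I1 and I2, b_low the largest F_i over I0, I3 and I4, and
	    the multipliers are optimal within tol when b_low <= b_up + 2 * tol;
	    the loop checks the strict form m_bUp > m_bLow - 2 * m_tol. Set
	    indices 0..4 stand for I0..I4.

```java
// Hypothetical helper: Keerthi et al.'s b_up / b_low optimality check.
final class OptimalitySketch {

  // f[i]: cached error F_i; set[i]: 0..4 for I0..I4. Returns {bUp, bLow}.
  static double[] thresholds(double[] f, int[] set) {
    double bUp = Double.POSITIVE_INFINITY;
    double bLow = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < f.length; i++) {
      int s = set[i];
      if (s == 0 || s == 1 || s == 2) {    // I_up = I0, I1, I2
        bUp = Math.min(bUp, f[i]);
      }
      if (s == 0 || s == 3 || s == 4) {    // I_low = I0, I3, I4
        bLow = Math.max(bLow, f[i]);
      }
    }
    return new double[] {bUp, bLow};
  }

  // Same strict form as the check in the loop above.
  static boolean isOptimal(double bUp, double bLow, double tol) {
    return bUp > bLow - 2 * tol;
  }
}
```

	    At the start (all alphas zero, F_i = -y_i) this yields b_up = -1
	    and b_low = 1, the values m_bUp and m_bLow are initialized to.
	  */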
	  
	  // This is the code for Modification 2 from Keerthi et al.'s paper
	  /*boolean innerLoopSuccess = true; 
	    numChanged = 0;
	    while ((m_bUp < m_bLow - 2 * m_tol) && (innerLoopSuccess == true)) {
	    innerLoopSuccess = takeStep(m_iUp, m_iLow, m_errors[m_iLow]);
	    }*/
	}
	
