
📄 SMO.java

📁 wekaUT is a collection of semi-supervised learning classifiers based on weka, developed at the University of Texas at Austin
💻 JAVA
📖 Page 1 of 4
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    SMO.java
 *    Copyright (C) 1999 Eibe Frank
 *
 */

package weka.classifiers.functions;

import weka.classifiers.Classifier;
import weka.classifiers.DistributionClassifier;
import weka.classifiers.Evaluation;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.Standardize;
import weka.filters.Filter;
import java.util.*;
import java.io.*;
import weka.core.*;

/**
 * Implements John C. Platt's sequential minimal optimization
 * algorithm for training a support vector classifier using polynomial
 * or RBF kernels.
 *
 * This implementation globally replaces all missing values and
 * transforms nominal attributes into binary ones. It also
 * normalizes all attributes by default. (Note that the coefficients
 * in the output are based on the normalized/standardized data, not the
 * original data.)
 *
 * Multi-class problems are solved using pairwise classification.
 *
 * To obtain proper probability estimates, use the option that fits
 * logistic regression models to the outputs of the support vector
 * machine. In the multi-class case the predicted probabilities
 * will be coupled using Hastie and Tibshirani's pairwise coupling
 * method.
 *
 * Note: for improved speed standardization should be turned off when
 * operating on SparseInstances.<p>
 *
 * For more information on the SMO algorithm, see<p>
 *
 * J. Platt (1998). <i>Fast Training of Support Vector
 * Machines using Sequential Minimal Optimization</i>. Advances in Kernel
 * Methods - Support Vector Learning, B. Schölkopf, C. Burges, and
 * A. Smola, eds., MIT Press. <p>
 *
 * S.S. Keerthi, S.K. Shevade, C. Bhattacharyya, K.R.K. Murthy,
 * <i>Improvements to Platt's SMO Algorithm for SVM Classifier Design</i>.
 * Neural Computation, 13(3), pp 637-649, 2001. <p>
 *
 * Valid options are:<p>
 *
 * -C num <br>
 * The complexity constant C. (default 1)<p>
 *
 * -E num <br>
 * The exponent for the polynomial kernel. (default 1)<p>
 *
 * -G num <br>
 * Gamma for the RBF kernel. (default 0.01)<p>
 *
 * -N <0|1|2> <br>
 * Whether to 0=normalize/1=standardize/2=neither. (default 0=normalize)<p>
 *
 * -F <br>
 * Feature-space normalization (only for non-linear polynomial kernels). <p>
 *
 * -O <br>
 * Use lower-order terms (only for non-linear polynomial kernels). <p>
 *
 * -R <br>
 * Use the RBF kernel. (default poly)<p>
 *
 * -A num <br>
 * Sets the size of the kernel cache. Should be a prime number.
 * (default 1000003) <p>
 *
 * -T num <br>
 * Sets the tolerance parameter. (default 1.0e-3)<p>
 *
 * -P num <br>
 * Sets the epsilon for round-off error. (default 1.0e-12)<p>
 *
 * -M <br>
 * Fit logistic models to SVM outputs.<p>
 *
 * -V num <br>
 * Number of folds for cross-validation used to generate data
 * for logistic models. (default -1, use training data)<p>
 *
 * -W num <br>
 * Random number seed for cross-validation. (default 1)<p>
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author Shane Legg (shane@intelligenesis.net) (sparse vector code)
 * @author Stuart Inglis (stuart@reeltwo.com) (sparse vector code)
 * @author J. Lindgren (jtlindgr{at}cs.helsinki.fi) (RBF kernel)
 * @version $Revision: 1.1.1.1 $
 */
public class SMO extends DistributionClassifier
  implements OptionHandler, WeightedInstancesHandler {

  /**
   * Class for building a binary support vector machine.
   */
  private class BinarySMO implements Serializable {

    /**
     * Calculates a dot product between two instances.
     */
    private double dotProd(Instance inst1, Instance inst2)
      throws Exception {

      double result = 0;

      // we can do a fast dot product
      int n1 = inst1.numValues();
      int n2 = inst2.numValues();
      int classIndex = m_data.classIndex();
      for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) {
        int ind1 = inst1.index(p1);
        int ind2 = inst2.index(p2);
        if (ind1 == ind2) {
          if (ind1 != classIndex) {
            result += inst1.valueSparse(p1) * inst2.valueSparse(p2);
          }
          p1++; p2++;
        } else if (ind1 > ind2) {
          p2++;
        } else {
          p1++;
        }
      }
      return result;
    }
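    // Added note (not in the original wekaUT source): dotProd() merges the
    // two instances' sparse index arrays in one pass, advancing whichever
    // pointer currently sits at the smaller attribute index. Only indices
    // present in both instances contribute a product, so the loop costs
    // O(n1 + n2) accesses rather than O(numAttributes), which is what makes
    // kernel evaluation on SparseInstances cheap.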
    /**
     * Abstract kernel.
     */
    private abstract class Kernel implements Serializable {

      /**
       * Computes the result of the kernel function for two instances.
       *
       * @param id1 the index of the first instance
       * @param id2 the index of the second instance
       * @param inst1 the instance corresponding to id1
       * @return the result of the kernel function
       */
      public abstract double eval(int id1, int id2, Instance inst1)
        throws Exception;
    }

    /**
     * The normalized polynomial kernel.
     */
    private class NormalizedPolyKernel extends Kernel {

      PolyKernel polyKernel = new PolyKernel();

      public double eval(int id1, int id2, Instance inst1)
        throws Exception {

        return polyKernel.eval(id1, id2, inst1) /
          Math.sqrt(polyKernel.eval(id1, id1, inst1) *
                    polyKernel.eval(id2, id2, m_data.instance(id2)));
      }
    }

    /**
     * The polynomial kernel.
     */
    private class PolyKernel extends Kernel {

      public double eval(int id1, int id2, Instance inst1)
        throws Exception {

        double result = 0;
        long key = -1;
        int location = -1;

        // we can only cache if we know the indexes
        if ((id1 >= 0) && (m_keys != null)) {
          if (id1 > id2) {
            key = (long) id1 * m_alpha.length + id2;
          } else {
            key = (long) id2 * m_alpha.length + id1;
          }
          if (key < 0) {
            throw new Exception("Cache overflow detected!");
          }
          location = (int) (key % m_keys.length);
          if (m_keys[location] == (key + 1)) {
            return m_storage[location];
          }
        }

        if (id1 == id2) {
          result = dotProd(inst1, inst1);
        } else {
          result = dotProd(inst1, m_data.instance(id2));
        }

        // Use lower order terms?
        if (m_lowerOrder) {
          result += 1.0;
        }

        if (m_exponent != 1.0) {
          result = Math.pow(result, m_exponent);
        }
        m_kernelEvals++;

        // store result in cache
        if (key != -1) {
          m_storage[location] = result;
          m_keys[location] = (key + 1);
        }
        return result;
      }
    }
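    // Added note (not in the original wekaUT source): the kernels above and
    // below share one open-addressed cache of size m_keys.length. The
    // unordered pair (id1, id2) is folded into a single long key,
    //   key = max(id1, id2) * m_alpha.length + min(id1, id2),
    // and hashed to location = key % m_keys.length. The slot stores
    // (key + 1) rather than key so that 0 can mean "empty", and a colliding
    // pair simply overwrites the slot, so no collision chains are needed.
    // This is also why the -A cache size should be prime: a prime modulus
    // spreads the regularly structured pair keys more evenly over the slots.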
    /**
     * The RBF kernel.
     */
    private class RBFKernel extends Kernel {

      /** The precalculated dotproducts of <inst_i,inst_i> */
      private double m_kernelPrecalc[];

      /**
       * Constructor. Initializes m_kernelPrecalc[].
       */
      public RBFKernel(Instances data) throws Exception {

        m_kernelPrecalc = new double[data.numInstances()];
        for (int i = 0; i < data.numInstances(); i++)
          m_kernelPrecalc[i] = dotProd(data.instance(i), data.instance(i));
      }

      public double eval(int id1, int id2, Instance inst1)
        throws Exception {

        double result = 0;
        long key = -1;
        int location = -1;

        // we can only cache if we know the indexes
        if (id1 >= 0) {
          if (id1 > id2) {
            key = (long) id1 * m_alpha.length + id2;
          } else {
            key = (long) id2 * m_alpha.length + id1;
          }
          if (key < 0) {
            throw new Exception("Cache overflow detected!");
          }
          location = (int) (key % m_keys.length);
          if (m_keys[location] == (key + 1)) {
            return m_storage[location];
          }
        }

        Instance inst2 = m_data.instance(id2);
        double precalc1;
        if (id1 == -1)
          precalc1 = dotProd(inst1, inst1);
        else
          precalc1 = m_kernelPrecalc[id1];

        result = Math.exp(m_gamma * (2. * dotProd(inst1, inst2) -
                                     precalc1 - m_kernelPrecalc[id2]));
        m_kernelEvals++;

        // store result in cache
        if (key != -1) {
          m_storage[location] = result;
          m_keys[location] = (key + 1);
        }
        return result;
      }
    }

    /**
     * Stores a set of a given size.
     */
    private class SMOset implements Serializable {

      /** The current number of elements in the set */
      private int m_number;

      /** The first element in the set */
      private int m_first;

      /** Indicators */
      private boolean[] m_indicators;

      /** The next element for each element */
      private int[] m_next;

      /** The previous element for each element */
      private int[] m_previous;

      /**
       * Creates a new set of the given size.
       */
      private SMOset(int size) {

        m_indicators = new boolean[size];
        m_next = new int[size];
        m_previous = new int[size];
        m_number = 0;
        m_first = -1;
      }

      /**
       * Checks whether an element is in the set.
       */
      private boolean contains(int index) {

        return m_indicators[index];
      }

      /**
       * Deletes an element from the set.
       */
      private void delete(int index) {

        if (m_indicators[index]) {
          if (m_first == index) {
            m_first = m_next[index];
          } else {
            m_next[m_previous[index]] = m_next[index];
          }
          if (m_next[index] != -1) {
            m_previous[m_next[index]] = m_previous[index];
          }
          m_indicators[index] = false;
          m_number--;
        }
      }

      /**
       * Inserts an element into the set.
       */
      private void insert(int index) {

        if (!m_indicators[index]) {
          if (m_number == 0) {
            m_first = index;
            m_next[index] = -1;
            m_previous[index] = -1;
          } else {
            m_previous[m_first] = index;
            m_next[index] = m_first;
            m_previous[index] = -1;
            m_first = index;
          }
          m_indicators[index] = true;
          m_number++;
        }
      }

      /**
       * Gets the next element in the set. -1 gets the first one.
       */
      private int getNext(int index) {

        if (index == -1) {
          return m_first;
        } else {
          return m_next[index];
        }
      }

      /**
       * Prints all the current elements in the set.
       */
      private void printElements() {

        for (int i = getNext(-1); i != -1; i = getNext(i)) {
          System.err.print(i + " ");
        }
        System.err.println();
        for (int i = 0; i < m_indicators.length; i++) {
          if (m_indicators[i]) {
            System.err.print(i + " ");
          }
        }
        System.err.println();
        System.err.println(m_number);
      }

      /**
       * Returns the number of elements in the set.
       */
      private int numElements() {

        return m_number;
      }
    }
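    // Added note (not in the original wekaUT source): SMOset represents a
    // subset of the instance indices 0..size-1 as a doubly-linked list
    // threaded through the m_next/m_previous arrays, with m_indicators
    // giving O(1) membership tests. insert() and delete() are O(1), and
    // getNext() walks only the members actually present, which keeps the
    // bookkeeping for the index sets I0-I4 and the support-vector set cheap
    // inside the SMO optimization loop.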
    /** The Lagrange multipliers. */
    private double[] m_alpha;

    /** The thresholds. */
    private double m_b, m_bLow, m_bUp;

    /** The indices for m_bLow and m_bUp */
    private int m_iLow, m_iUp;

    /** The training data. */
    private Instances m_data;

    /** Weight vector for linear machine. */
    private double[] m_weights;

    /** Variables to hold weight vector in sparse form.
        (To reduce storage requirements.) */
    private double[] m_sparseWeights;
    private int[] m_sparseIndices;

    /** Kernel to use */
    private Kernel m_kernel;

    /** Kernel function cache */
    private double[] m_storage;
    private long[] m_keys;

    /** The transformed class values. */
    private double[] m_class;

    /** The current set of errors for all non-bound examples. */
    private double[] m_errors;

    /** The five different sets used by the algorithm. */
    private SMOset m_I0; // {i: 0 < m_alpha[i] < C}
    private SMOset m_I1; // {i: m_class[i] = 1, m_alpha[i] = 0}
    private SMOset m_I2; // {i: m_class[i] = -1, m_alpha[i] = C}
    private SMOset m_I3; // {i: m_class[i] = 1, m_alpha[i] = C}
    private SMOset m_I4; // {i: m_class[i] = -1, m_alpha[i] = 0}

    /** The set of support vectors */
    private SMOset m_supportVectors; // {i: 0 < m_alpha[i]}

    /** Counts the number of kernel evaluations. */
    private int m_kernelEvals;

    /** Stores logistic regression model for probability estimate */
    private Logistic m_logistic = null;

    /** Stores the weight of the training instances */
    private double m_sumOfWeights = 0;

    /**
     * Fits logistic regression model to SVM outputs, analogous
     * to John Platt's method.
     *
     * @param insts the set of training instances
     * @param cl1 the first class' index
     * @param cl2 the second class' index
     * @param numFolds the number of folds for cross-validation
     * @param random the random number generator for cross-validation
     * @exception Exception if the sigmoid can't be fit successfully
     */
    private void fitLogistic(Instances insts, int cl1, int cl2,
                             int numFolds, Random random)
      throws Exception {

      // Create header of instances object
      FastVector atts = new FastVector(2);
      atts.addElement(new Attribute("pred"));
      FastVector attVals = new FastVector(2);
      attVals.addElement(insts.classAttribute().value(cl1));
      attVals.addElement(insts.classAttribute().value(cl2));
      atts.addElement(new Attribute("class", attVals));
      Instances data = new Instances("data", atts, insts.numInstances());
      data.setClassIndex(1);

      // Collect data for fitting the logistic model
      if (numFolds <= 0) {

        // Use training data
        for (int j = 0; j < insts.numInstances(); j++) {
          Instance inst = insts.instance(j);
          double[] vals = new double[2];
          vals[0] = SVMOutput(-1, inst);
          if (inst.classValue() == cl2) {
            vals[1] = 1;
          }
          data.add(new Instance(inst.weight(), vals));
        }
      } else {

        // Check whether number of folds too large
        if (numFolds > insts.numInstances()) {
          numFolds = insts.numInstances();
        }

        // Make copy of instances because we will shuffle them around
        insts = new Instances(insts);

        // Perform numFolds-fold cross-validation to collect
        // unbiased predictions
        insts.randomize(random);
        insts.stratify(numFolds);
        for (int i = 0; i < numFolds; i++) {
          Instances train = insts.trainCV(numFolds, i);
          SerializedObject so = new SerializedObject(this);
          BinarySMO smo = (BinarySMO) so.getObject();
          smo.buildClassifier(train, cl1, cl2, false, -1, -1);
          Instances test = insts.testCV(numFolds, i);
          for (int j = 0; j < test.numInstances(); j++) {
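
The listing above is page 1 of 4 and breaks off inside fitLogistic(); the remaining pages continue the file. As a minimal sketch of how a classifier like this is normally driven (assuming the weka 3.2-era API that wekaUT builds on, and a placeholder iris.arff path), training SMO and querying its class distribution could look like this:

import java.io.BufferedReader;
import java.io.FileReader;
import weka.classifiers.functions.SMO;
import weka.core.Instances;
import weka.core.Utils;

public class SMOExample {

  public static void main(String[] args) throws Exception {
    // Load a dataset; the file path is a placeholder.
    Instances data = new Instances(
        new BufferedReader(new FileReader("iris.arff")));
    data.setClassIndex(data.numAttributes() - 1);

    // -C: complexity constant, -E: polynomial exponent
    // (see the option list in the class Javadoc above).
    SMO smo = new SMO();
    smo.setOptions(Utils.splitOptions("-C 1.0 -E 1.0"));
    smo.buildClassifier(data);

    // SMO extends DistributionClassifier, so we can ask for
    // per-class estimates for an instance.
    double[] dist = smo.distributionForInstance(data.instance(0));
    for (int i = 0; i < dist.length; i++) {
      System.out.println(data.classAttribute().value(i) + ": " + dist[i]);
    }
  }
}

As the class Javadoc notes, proper probability estimates require the -M option; without it the returned distribution reflects pairwise voting rather than calibrated probabilities.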

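For experimenting with the kernel cache in isolation, here is a self-contained toy version of the same idea: an illustrative sketch following the key and sentinel conventions visible in PolyKernel/RBFKernel above, not code from wekaUT.

// Toy version of the single-slot kernel cache: one long key per unordered
// pair, value 0 in keys[] means "empty", collisions simply overwrite.
public class PairCacheDemo {

  private final double[] storage;
  private final long[] keys;
  private final int numInstances;

  public PairCacheDemo(int cacheSize, int numInstances) {
    storage = new double[cacheSize];
    keys = new long[cacheSize];
    this.numInstances = numInstances;
  }

  private long pairKey(int id1, int id2) {
    // Order the pair so that (a, b) and (b, a) share one key.
    return (id1 > id2) ? (long) id1 * numInstances + id2
                       : (long) id2 * numInstances + id1;
  }

  public double getOrCompute(int id1, int id2) {
    long key = pairKey(id1, id2);
    int location = (int) (key % keys.length);
    if (keys[location] == key + 1) {   // cache hit
      return storage[location];
    }
    double result = expensiveKernel(id1, id2);
    storage[location] = result;        // cache miss: overwrite the slot
    keys[location] = key + 1;          // +1 so that 0 still means "empty"
    return result;
  }

  private double expensiveKernel(int id1, int id2) {
    System.out.println("computing (" + id1 + ", " + id2 + ")");
    return id1 * 0.5 + id2;            // stand-in for a real kernel function
  }

  public static void main(String[] args) {
    PairCacheDemo cache = new PairCacheDemo(1000003, 100);
    cache.getOrCompute(7, 42);  // computes and stores
    cache.getOrCompute(42, 7);  // served from the cache: same key
  }
}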