📄 svmlight.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * SVMlight.java * Copyright (C) 2002-2003 Mikhail Bilenko * * TODO: * - implement UpdateableClassifier * - implement the remaining options for SVM-light * - implement WeightedInstancesHandler * - proper conversion from margin to distribution (see Zadrozny & Elkan, Wahba, Platt...) */package weka.classifiers.sparse;import weka.classifiers.Classifier;import weka.classifiers.DistributionClassifier;import weka.classifiers.Evaluation;import weka.classifiers.UpdateableClassifier;import java.io.*;import java.util.*;import weka.core.*;/** * <i> A wrapper for SVMlight package by Thorsten Joachims * For more information, see <p> * * http://www.cs.cornell.edu/People/tj/svm_light * * Valid options are:<p> * * * @author Mikhail Bilenko (mbilenko@cs.utexas.edu) * @version $Revision: 1.9 $ */public class SVMlight extends DistributionClassifier implements OptionHandler { /** The training instances used for classification. */ protected Instances m_train; /** Has the SVM been trained */ protected boolean m_svmTrained = false; /** Output debugging information */ protected boolean m_debug = false; /** Path to the directory where SVM-light executables are located */ protected String m_binPath = new String("/u/ml/software/svm_light/"); /** Path to the directory where temporary files will be stored */ protected String m_tempDirPath = new String("/var/local/tmp/"); protected File m_tempDirFile = null; /** Name of the temporary file where training data will be dumped temporarily */ protected String m_trainFilenameBase = new String("SVMtrain"); protected String m_trainFilename = null; /** Name of the temporary file where a test instance is dumped if buffered IO is not used */ protected String m_testFilenameBase = new String("SVMtest"); protected String m_testFilename = null; /** Name of the file where a model will be temporarily created*/ protected String m_modelFilenameBase = new String("SVMmodel"); protected String m_modelFilename = null; /** Name of the file where predictions will be temporarily stored unless buffered IO is used*/ protected String m_predictionFilenameBase = new String("SVMpredict"); protected String m_predictionFilename = null; /** SVM-light predictions are positive or negative margins; to convert * to a distribution we need min/max margin values... */ protected double m_maxMargin = -45; protected double m_minMargin = 45; protected boolean m_autoBounds = false; /** Is classification done via temporary files or via a buffer? */ protected boolean m_bufferedMode = true; protected BufferedReader m_procReader = null; protected BufferedWriter m_procWriter = null; /**********************/ /** SVM-light options */ /** verbosity level */ protected int m_verbosityLevel = 1; /** SVM-light can work in classification, regression and preference ranking modes */ public static final int SVM_MODE_CLASSIFICATION = 1; public static final int SVM_MODE_REGRESSION = 2; public static final int SVM_MODE_PREFERENCE_RANKING = 4; public static final Tag[] TAGS_SVM_MODE = { new Tag(SVM_MODE_CLASSIFICATION, "Classification"), new Tag(SVM_MODE_REGRESSION, "Regression"), new Tag(SVM_MODE_PREFERENCE_RANKING, "Preference ranking") }; protected int m_mode = SVM_MODE_CLASSIFICATION; /** trade-off between training error and margin (default 0 corresponds to [avg. x*x]^-1) */ protected double m_C = 0; /** Epsilon width of tube for regression */ protected double m_width = 0.1; /** Cost: cost-factor, by which training errors on positive examples outweight errors on negative examples */ protected double m_costFactor = 1; /** Use biased hyperplane (i.e. x*w+b>0) instead of unbiased hyperplane (i.e. x*w>0) */ protected boolean m_biased = true; /** remove inconsistent training examples and retrain */ protected boolean m_removeInconsistentExamples = false; /** Kernel type */ public static final int KERNEL_LINEAR = 1; public static final int KERNEL_POLYNOMIAL = 2; public static final int KERNEL_RBF = 4; public static final int KERNEL_SIGMOID_TANH = 8; public static final Tag[] TAGS_KERNEL_TYPE = { new Tag(KERNEL_LINEAR, "Linear"), new Tag(KERNEL_POLYNOMIAL, "Polynomial (s a*b+c)^d"), new Tag(KERNEL_RBF, "Radial basis function exp(-gamma ||a-b||^2)"), new Tag(KERNEL_SIGMOID_TANH, "Sigmoid tanh(s a*b + c)") }; protected int m_kernelType = KERNEL_RBF; /** Parameter d in polynomial kernel */ protected int m_d = 3; /** Parameter gamma in rbf kernel */ protected double m_gamma = 1; /** Parameter s in sigmoid/polynomial kernel */ protected double m_s = 1; /** parameter c in sigmoid/poly kernel */ protected double m_c1 = 1; /** A default constructor */ public SVMlight() { } /** Take care of closing the SVM-light process before the object is destroyed */ protected void finalize() { cleanupIO(); } /** The buffered version of SVM-light needs to release some I/O resources * before exiting */ protected void cleanupIO() { try { // kill the svm_classify_std process if (m_procWriter != null) { m_procWriter.close(); } if (m_procReader != null) { m_procReader.close(); } m_procReader = null; m_procWriter = null; // delete the model file if (!m_debug && (m_modelFilename != null)) { File modelFile = new File(m_modelFilename); modelFile.delete(); } } catch (Exception e) { System.out.println("Problems when cleaning up IO"); e.printStackTrace(); } } /** * Generates the classifier. * * @param instances set of instances serving as training data * @exception Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances instances) throws Exception { if (instances.classIndex() < 0) { throw new Exception ("No class attribute assigned to instances."); } if (instances.checkForStringAttributes()) { throw new UnsupportedAttributeTypeException("Cannot handle string attributes."); } int numClasses = instances.numClasses(); if (numClasses != 2) { throw new Exception("Training data should have two classes; has " + numClasses + " classes"); } // if a classifier has been built, clean up the IO if (m_bufferedMode && m_procWriter != null) { cleanupIO(); } // create a working copy of training data m_tempDirFile = new File(m_tempDirPath); m_train = new Instances(instances, 0, instances.numInstances()); // Unlike most Weka classifiers, we are *not* throwing away training // instances with missing class, since they may be used for transduction. // If it is desired to avoid transduction and throw out unlabeled data, // uncomment the following line: // m_train.deleteWithMissingClass(); // Convert training instances into SVMlight format and dump into a training file dumpTrainingData(m_train); // Train the model trainSVMlight(); // set min and max margin if desired if (m_autoBounds) { setBounds(instances); } } /** Set the bounds using "extreme" training examples - TODO!*/ protected void setBounds(Instances data) { try { // get the minimum margin double[] values = new double[data.numAttributes()]; Instance zeroInstance = new Instance(1.0, values); zeroInstance.setDataset(data); if (!m_bufferedMode) { File testFile = File.createTempFile(m_testFilenameBase, ".dat", m_tempDirFile); if (!m_debug) { testFile.deleteOnExit(); } dumpInstance(zeroInstance, testFile); } double minMargin = classifySVMlight(zeroInstance); setMinMargin(minMargin); // get the maximum margin double maxMargin = 0; for (int i = 0; i < data.numInstances(); i++) { Instance instance = data.instance(i); // we only care about positive examples if (instance.classValue() == 0) { if (!m_bufferedMode) { File testFile = File.createTempFile(m_testFilenameBase, ".dat", m_tempDirFile); if (!m_debug) { testFile.deleteOnExit(); } dumpInstance(zeroInstance, testFile); } double margin = classifySVMlight(instance); if (margin < maxMargin) { maxMargin = margin; } } setMaxMargin(maxMargin); } System.out.println("xxxxx MINMARGIN=" + minMargin + "\tMAX_MARGIN=" + maxMargin); } catch (Exception e) { System.err.println("Problems obtaining automatic margins: " + e); e.printStackTrace(); } } /** * Dump training instances into a file in SVM-light format * @param instances the training instances * @param filename name of the file where instance will be dumped */ protected void dumpTrainingData(Instances instances) { try { File trainFile = File.createTempFile(m_trainFilenameBase, ".dat", m_tempDirFile); if (!m_debug) { trainFile.deleteOnExit(); } m_trainFilename = trainFile.getPath(); PrintWriter writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(trainFile))); int classIdx = instances.classIndex(); // Go through all instances Enumeration enum = instances.enumerateInstances(); while (enum.hasMoreElements()) { Instance trainInstance = (Instance) enum.nextElement(); // output the class value double classValue = 0; if (!trainInstance.classIsMissing()) { classValue = trainInstance.classValue(); if (classValue == 0) { classValue = -1; } else { classValue = 1; } } writer.print((int)classValue + " "); // output the attributes; iterating using numValues() skips 'missing' values for SparseInstances for (int j = 0; j < trainInstance.numValues(); j++) { Attribute attribute = trainInstance.attributeSparse(j); // Attribute index must be greater than 0 int attrIdx = attribute.index(); if (attrIdx != classIdx) { writer.print((attrIdx+1) + ":" + trainInstance.value(attrIdx) + " "); } } writer.println(); } writer.close(); } catch (Exception e) { System.err.println("Error when dumping training instances: " + e); e.printStackTrace(); } } /** * Dump a single instance into a file in SVM-light format * @param instance an instance * @param file the file where instance will be dumped */ protected void dumpInstance(Instance instance, File file) { try { PrintWriter writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(file))); // output a dummy class value int classIdx = instance.classIndex(); writer.print(Integer.MAX_VALUE + " "); // output the attributes; iterating using numValues skips 'missing' values for SparseInstances for (int j = 0; j < instance.numValues(); j++) { Attribute attribute = instance.attributeSparse(j); int attrIdx = attribute.index(); if (attrIdx != classIdx) { writer.print((attrIdx+1) + ":" + instance.value(attrIdx) + " "); } } writer.println(); writer.close(); } catch (Exception e) { System.err.println("Error when dumping instance: " + e); e.printStackTrace(); } }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -