📄 svmlight.java
字号:
/** Launch an SVM-light process assuming that the training data has been dumped */ protected void trainSVMlight() { try { String command = new String(m_binPath + "svm_learn"); // append all options for SVM-light command = command + " -v " + m_verbosityLevel; switch (m_mode) { case SVM_MODE_CLASSIFICATION: command = command + " -z c"; break; case SVM_MODE_REGRESSION: command = command + " -z r"; command = command + " -w " + m_width; break; case SVM_MODE_PREFERENCE_RANKING: command = command + " -z p"; break; default: throw new Exception("Unknown mode: " + m_mode); } command = command + " -c " + m_C; command = command + " -j " + m_costFactor; command = command + " -b " + (m_biased ? 1 : 0); command = command + " -i " + (m_removeInconsistentExamples ? 1 : 0); switch(m_kernelType) { case KERNEL_LINEAR: command = command + " -t 0"; break; case KERNEL_POLYNOMIAL: command = command + " -t 1"; command = command + " -d " + m_d; command = command + " -s " + m_s; command = command + " -r " + m_c1; break; case KERNEL_RBF: command = command + " -t 2"; command = command + " -g " + m_gamma; break; case KERNEL_SIGMOID_TANH: command = command + " -t 3"; command = command + " -s " + m_s; command = command + " -r " + m_c1; break; default: throw new Exception("Unknown kernel type: " + m_kernelType); } // create the model file File modelFile = File.createTempFile(m_modelFilenameBase, ".dat", m_tempDirFile); if (!m_debug) { modelFile.deleteOnExit(); } m_modelFilename = modelFile.getPath(); command = command + " " + m_trainFilename + " " + m_modelFilename; if (m_debug) { System.out.println("Executing SVMlight: \n\t" + command); } Process proc = Runtime.getRuntime().exec(command); // read the training output if (proc != null){ BufferedReader procOutput = new BufferedReader(new InputStreamReader(proc.getInputStream())); try { String line; while ((line = procOutput.readLine()) != null){ if (m_debug) { System.out.println("SVM: " + line); } } } catch (Exception e) { System.err.println("Problems trapping output in debug mode:"); e.printStackTrace(); System.out.println(e); } } int exitValue = proc.waitFor(); if (exitValue != 0) { throw new Exception("Problems training SVM-light: process returned value " + exitValue); } // delete the training file File trainFile = new File(m_trainFilename); // trainFile.delete(); m_svmTrained = true; } catch (Exception e) { System.out.println("Problem training: "); e.printStackTrace(); System.err.println(e); } } /** Launch an SVM-light process and classify a given instance * @param instance an instance that must be classified */ protected double classifySVMlight(Instance instance) { double prediction = Double.MIN_VALUE; String lineIn = null; StringBuffer instanceString = new StringBuffer(); try { if (m_bufferedMode) { // if this is the first time classify() is called, initialize the classifier process if (m_procWriter == null) { String command = new String(m_binPath + "svm_classify_std -v " + m_verbosityLevel + " " + m_modelFilename); if (m_debug) { System.out.println("Executing \"" + command + "\""); } Process proc = Runtime.getRuntime().exec(command); m_procReader = new BufferedReader(new InputStreamReader(proc.getInputStream())); m_procWriter = new BufferedWriter(new OutputStreamWriter(proc.getOutputStream())); System.out.println(m_procReader.readLine()); System.out.println(m_procReader.readLine()); } // pass the instance to SVMlight process // output a bogus class value instanceString.append(Integer.MAX_VALUE + " "); // output the attributes; iterating using numValues skips 'missing' values for SparseInstances int classIdx = instance.classIndex(); for (int j = 0; j < instance.numValues(); j++) { Attribute attribute = instance.attributeSparse(j); int attrIdx = attribute.index(); if (attrIdx != classIdx) { instanceString.append((attrIdx+1) + ":" + instance.value(attrIdx) + " "); } } instanceString.append("\n"); if (m_debug) { System.out.println("Sending " + instance); System.out.flush(); } m_procWriter.write(instanceString.toString()); m_procWriter.flush(); lineIn = m_procReader.readLine(); if (lineIn == null) { throw new Exception("Got null prediction from SVMlight!"); } prediction = Double.parseDouble(lineIn); if (m_debug) { System.out.println("Got " + prediction); } } else { // Non-buffered IO, a temporary test file is used for the test instance // create a temporary file where the test instance is dumped File testFile = File.createTempFile(m_testFilenameBase, ".dat", m_tempDirFile); if (!m_debug) { testFile.deleteOnExit(); } m_testFilename = testFile.getPath(); dumpInstance(instance, testFile); // create a temporary file where the SVMlight output (prediction) will be stored File predictionFile = File.createTempFile(m_predictionFilenameBase, ".dat", m_tempDirFile); if (!m_debug) { predictionFile.deleteOnExit(); } m_predictionFilename = predictionFile.getPath(); // run svm_classify String command = new String(m_binPath + "svm_classify -v " + m_verbosityLevel + " " + m_testFilename + " " + m_modelFilename + " " + m_predictionFilename ); Process proc = Runtime.getRuntime().exec(command); int exitValue = proc.waitFor(); if (exitValue != 0) { throw new Exception("Problems running SVM-light: process returned value " + exitValue); } prediction = readPrediction(predictionFile); testFile.delete(); predictionFile.delete(); } } catch (Exception e) { System.out.println("Got from SVM-light: " + lineIn); System.err.println(e); e.printStackTrace(); } return prediction; } /** Read the prediction of SVM-light * @param file file where the prediction is stored */ protected double readPrediction(File file) { double result = Double.MIN_VALUE; try { BufferedReader r = new BufferedReader(new FileReader(file)); String line = r.readLine(); if (line == null) { throw new Exception("Empty prediction file " + file.getPath()); } result = Double.parseDouble(line); } catch (Exception e) { System.err.println("Error reading the prediction file: " + e); } return result; } /** * Calculates the class membership probabilities for the given test instance. * * @param instance the instance to be classified * @return predicted class probability distribution * @exception Exception if an error occurred during the prediction */ public double [] distributionForInstance(Instance instance) throws Exception{ if (!m_svmTrained) { throw new Exception("SVM has not been trained!"); } // compute prediction double margin = classifySVMlight(instance); double[] predictions = new double[2]; predictions[0] = 1 - (margin - m_maxMargin)/(m_minMargin - m_maxMargin); if (predictions[0] > 1) { // System.out.println("overflow: " + predictions[0]); predictions[0] = 1; } if (predictions[0] < 0) { // System.out.println("underflow: " + predictions[0]); predictions[0] = 0; } predictions[1] = 1- predictions[0]; if (m_debug) { System.out.println("\t\tMargin: " + margin + "\tDistribution: {" + predictions[0] + ",\t" + predictions[1] + "}"); } return predictions; } /** Check whether the SVM has been trained * @return true if the SVM has been train and is ready to classify instances */ public boolean trained() { return m_svmTrained; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ public Enumeration listOptions() { Vector newVector = new Vector(2); newVector.addElement(new Option( "\tOutput debug information", "D", 0, "-D")); return newVector.elements(); } /** * Parses a given list of options. Valid options are:<p> * * -D <br> * output debugging information <p> * * * @param options the list of options as an array of strings * @exception Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { setDebug(Utils.getFlag('D', options)); String verbosityString = Utils.getOption('v', options); if (verbosityString.length() != 0) { setVerbosityLevel(Integer.parseInt(verbosityString)); } if (Utils.getFlag('A', options)) { setAutoBounds(true); } else { String minMarginString = Utils.getOption('n', options); if (minMarginString.length() != 0) { setMinMargin(Double.parseDouble(minMarginString)); } String maxMarginString = Utils.getOption('m', options); if (maxMarginString.length() != 0) { setMaxMargin(Double.parseDouble(maxMarginString)); } } if (Utils.getFlag('C', options)) { setMode(new SelectedTag(SVM_MODE_CLASSIFICATION, TAGS_SVM_MODE)); } else if (Utils.getFlag('R', options)) { setMode(new SelectedTag(SVM_MODE_REGRESSION, TAGS_SVM_MODE)); String widthString = Utils.getOption('w', options); if (widthString.length() != 0) { setWidth(Double.parseDouble(widthString)); } } else if (Utils.getFlag('P', options)) { setMode(new SelectedTag(SVM_MODE_PREFERENCE_RANKING, TAGS_SVM_MODE)); } String cString = Utils.getOption('c', options); if (cString.length() != 0) { setC(Double.parseDouble(cString)); } String costFactorString = Utils.getOption('j', options); if (costFactorString.length() != 0) { setCostFactor(Double.parseDouble(costFactorString)); } setBiased(Utils.getFlag('b', options)); setRemoveInconsistentExamples(Utils.getFlag('i', options)); // kernel-type related options if (Utils.getFlag('L', options)) { setKernelType(new SelectedTag(KERNEL_LINEAR, TAGS_KERNEL_TYPE)); } else if (Utils.getFlag('O', options)) { setKernelType(new SelectedTag(KERNEL_POLYNOMIAL, TAGS_KERNEL_TYPE)); String dString = Utils.getOption('d', options); if (dString.length() != 0) { setD(Integer.parseInt(dString)); } String sString = Utils.getOption('s', options); if (sString.length() != 0) { setS(Double.parseDouble(sString)); } String c1String = Utils.getOption('r', options); if (c1String.length() != 0) { setC1(Double.parseDouble(c1String)); } } else if (Utils.getFlag('B', options)) { setKernelType(new SelectedTag(KERNEL_RBF, TAGS_KERNEL_TYPE)); String gammaString = Utils.getOption('g', options); if (gammaString.length() != 0) { setC1(Double.parseDouble(gammaString)); } } else if (Utils.getFlag('S', options)) { setKernelType(new SelectedTag(KERNEL_SIGMOID_TANH, TAGS_KERNEL_TYPE)); String sString = Utils.getOption('s', options); if (sString.length() != 0) { setS(Double.parseDouble(sString)); } String c1String = Utils.getOption('r', options); if (c1String.length() != 0) { setC1(Double.parseDouble(c1String)); } } String binPathString = Utils.getOption('p', options); if (binPathString.length() != 0) { setBinPath(binPathString); } Utils.checkForRemainingOptions(options);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -