📄 smo.java
字号:
} else if (numChanged == 0) { examineAll = true; } } // Set threshold m_b = (m_bLow + m_bUp) / 2.0; // Save memory m_storage = null; m_keys = null; m_errors = null; m_I0 = m_I1 = m_I2 = m_I3 = m_I4 = null; // If machine is linear, delete training data if (m_exponent == 1.0) { m_data = new Instances(m_data, 0); } } /** * Computes SVM output for given instance. * * @param index the instance for which output is to be computed * @param inst the instance * @return the output of the SVM for the given instance */ private double SVMOutput(int index, Instance inst) throws Exception { double result = 0; // Is the machine linear? if (m_exponent == 1.0) { int n1 = inst.numValues(); int classIndex = m_data.classIndex(); for (int p = 0; p < n1; p++) { if (inst.index(p) != classIndex) { result += m_weights[inst.index(p)] * inst.valueSparse(p); } } } else { for (int i = m_supportVectors.getNext(-1); i != -1; i = m_supportVectors.getNext(i)) { result += m_class[i] * m_alpha[i] * kernel(index, i, inst); } } result -= m_b; return result; } /** * Outputs the distribution for the given output. * * Pipes output of SVM through sigmoid function. 
* @param inst the instance for which distribution is to be computed * @return the distribution * @exception Exception if something goes wrong */ public double[] distributionForInstance(Instance inst) throws Exception { // Filter instance m_Missing.input(inst); m_Missing.batchFinished(); inst = m_Missing.output(); if (m_Normalize) { m_Normalization.input(inst); m_Normalization.batchFinished(); inst = m_Normalization.output(); } if (!m_onlyNumeric) { m_NominalToBinary.input(inst); m_NominalToBinary.batchFinished(); inst = m_NominalToBinary.output(); } // Get probabilities double output = SVMOutput(-1, inst); double[] result = new double[2]; result[1] = 1.0 / (1.0 + Math.exp(-output)); result[0] = 1.0 - result[1]; return result; } /** * Returns an enumeration describing the available options * * @return an enumeration of all the available options */ public Enumeration listOptions() { Vector newVector = new Vector(8); newVector.addElement(new Option("\tThe complexity constant C. (default 1)", "C", 1, "-C <double>")); newVector.addElement(new Option("\tThe exponent for the " + "polynomial kernel. (default 1)", "E", 1, "-E <double>")); newVector.addElement(new Option("\tDon't normalize the data.", "N", 0, "-N")); newVector.addElement(new Option("\tRescale the kernel.", "L", 0, "-L")); newVector.addElement(new Option("\tUse lower-order terms.", "O", 0, "-O")); newVector.addElement(new Option("\tThe size of the kernel cache. " + "(default 1000003)", "A", 1, "-A <int>")); newVector.addElement(new Option("\tThe tolerance parameter. " + "(default 1.0e-3)", "T", 1, "-T <double>")); newVector.addElement(new Option("\tThe epsilon for round-off error. " + "(default 1.0e-12)", "P", 1, "-P <double>")); return newVector.elements(); } /** * Parses a given list of options. Valid options are:<p> * * -C num <br> * The complexity constant C. (default 1)<p> * * -E num <br> * The exponent for the polynomial kernel. (default 1) <p> * * -N <br> * Don't normalize the training instances. 
<p> * * -L <br> * Rescale kernel. <p> * * -O <br> * Use lower-order terms. <p> * * -A num <br> * Sets the size of the kernel cache. Should be a prime number. (default 1000003) <p> * * -T num <br> * Sets the tolerance parameter. (default 1.0e-3)<p> * * -P num <br> * Sets the epsilon for round-off error. (default 1.0e-12)<p> * * @param options the list of options as an array of strings * @exception Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { String complexityString = Utils.getOption('C', options); if (complexityString.length() != 0) { m_C = (new Double(complexityString)).doubleValue(); } else { m_C = 1.0; } String exponentsString = Utils.getOption('E', options); if (exponentsString.length() != 0) { m_exponent = (new Double(exponentsString)).doubleValue(); } else { m_exponent = 1.0; } String cacheString = Utils.getOption('A', options); if (cacheString.length() != 0) { m_cacheSize = Integer.parseInt(cacheString); } else { m_cacheSize = 1000003; } String toleranceString = Utils.getOption('T', options); if (toleranceString.length() != 0) { m_tol = (new Double(toleranceString)).doubleValue(); } else { m_tol = 1.0e-3; } String epsilonString = Utils.getOption('P', options); if (epsilonString.length() != 0) { m_eps = (new Double(epsilonString)).doubleValue(); } else { m_eps = 1.0e-12; } m_Normalize = !Utils.getFlag('N', options); m_rescale = Utils.getFlag('L', options); if ((m_exponent == 1.0) && (m_rescale)) { throw new Exception("Can't use rescaling with linear machine."); } m_lowerOrder = Utils.getFlag('O', options); if ((m_exponent == 1.0) && (m_lowerOrder)) { throw new Exception("Can't use lower-order terms with linear machine."); } } /** * Gets the current settings of the classifier. 
* * @return an array of strings suitable for passing to setOptions */ public String [] getOptions() { String [] options = new String [13]; int current = 0; options[current++] = "-C"; options[current++] = "" + m_C; options[current++] = "-E"; options[current++] = "" + m_exponent; options[current++] = "-A"; options[current++] = "" + m_cacheSize; options[current++] = "-T"; options[current++] = "" + m_tol; options[current++] = "-P"; options[current++] = "" + m_eps; if (!m_Normalize) { options[current++] = "-N"; } if (m_rescale) { options[current++] = "-L"; } if (m_lowerOrder) { options[current++] = "-O"; } while (current < options.length) { options[current++] = ""; } return options; } /** * Prints out the classifier. * * @return a description of the classifier as a string */ public String toString() { StringBuffer text = new StringBuffer(); int printed = 0; if (m_alpha == null) { return "SMO: No model built yet."; } try { text.append("SMO\n\n"); // If machine linear, print weight vector if (m_exponent == 1.0) { text.append("Machine linear: showing attribute weights, "); text.append("not support vectors.\n\n"); for (int i = 0; i < m_weights.length; i++) { if (i != (int)m_data.classIndex()) { if (printed > 0) { text.append(" + "); } else { text.append(" "); } text.append(m_weights[i] + " * " + m_data.attribute(i).name()+"\n"); printed++; } } } else { for (int i = 0; i < m_alpha.length; i++) { if (m_supportVectors.contains(i)) { if (printed > 0) { text.append(" + "); } else { text.append(" "); } text.append(((int)m_class[i]) + " * " + m_alpha[i] + " * K[X(" + i + ") * X]\n"); printed++; } } } text.append(" - " + m_b); text.append("\n\nNumber of support vectors: " + m_supportVectors.numElements()); text.append("\n\nNumber of kernel evaluations: " + m_kernelEvals); } catch (Exception e) { return "Can't print SMO classifier."; } return text.toString(); } /** * Get the value of exponent. * * @return Value of exponent. 
*/
  public double getExponent() {
    return m_exponent;
  }

  /**
   * Set the value of exponent. If linear kernel
   * is used, rescaling and lower-order terms are
   * turned off.
   *
   * @param v Value to assign to exponent.
   */
  public void setExponent(double v) {
    // A linear machine supports neither rescaling nor lower-order terms.
    if (v == 1.0) {
      m_rescale = false;
      m_lowerOrder = false;
    }
    m_exponent = v;
  }

  /**
   * Get the value of C.
   *
   * @return Value of C.
   */
  public double getC() {
    return m_C;
  }

  /**
   * Set the value of C.
   *
   * @param v Value to assign to C.
   */
  public void setC(double v) {
    m_C = v;
  }

  /**
   * Get the value of tolerance parameter.
   *
   * @return Value of tolerance parameter.
   */
  public double getToleranceParameter() {
    return m_tol;
  }

  /**
   * Set the value of tolerance parameter.
   *
   * @param v Value to assign to tolerance parameter.
   */
  public void setToleranceParameter(double v) {
    m_tol = v;
  }

  /**
   * Get the value of epsilon.
   *
   * @return Value of epsilon.
   */
  public double getEpsilon() {
    return m_eps;
  }

  /**
   * Set the value of epsilon.
   *
   * @param v Value to assign to epsilon.
   */
  public void setEpsilon(double v) {
    m_eps = v;
  }

  /**
   * Get the size of the kernel cache
   *
   * @return Size of kernel cache.
   */
  public int getCacheSize() {
    return m_cacheSize;
  }

  /**
   * Set the value of the kernel cache.
   *
   * @param v Size of kernel cache.
   */
  public void setCacheSize(int v) {
    m_cacheSize = v;
  }

  /**
   * Check whether data is to be normalized.
   *
   * @return true if data is to be normalized
   */
  public boolean getNormalizeData() {
    return m_Normalize;
  }

  /**
   * Set whether data is to be normalized.
   *
   * @param v true if data is to be normalized
   */
  public void setNormalizeData(boolean v) {
    m_Normalize = v;
  }

  /**
   * Check whether kernel is being rescaled.
   *
   * @return Value of rescale.
   */
  public boolean getRescaleKernel() throws Exception {
    return m_rescale;
  }

  /**
   * Set whether kernel is to be rescaled. Defaults
   * to false if a linear machine is built.
   *
   * @param v Value to assign to rescale.
*/
  public void setRescaleKernel(boolean v) throws Exception {
    // Rescaling is meaningless for a linear machine, so force it off.
    if (m_exponent == 1.0) {
      m_rescale = false;
    } else {
      m_rescale = v;
    }
  }

  /**
   * Check whether lower-order terms are being used.
   *
   * @return Value of lowerOrder.
   */
  public boolean getLowerOrderTerms() {
    return m_lowerOrder;
  }

  /**
   * Set whether lower-order terms are to be used. Defaults
   * to false if a linear machine is built.
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -