📄 smo.java
字号:
L = Math.max(0, alph1 + alph2 - C1); H = Math.min(C2, alph1 + alph2); } if (L >= H) { return false; } // Compute second derivative of objective function k11 = m_kernel.eval(i1, i1, m_data.instance(i1)); k12 = m_kernel.eval(i1, i2, m_data.instance(i1)); k22 = m_kernel.eval(i2, i2, m_data.instance(i2)); eta = 2 * k12 - k11 - k22; // Check if second derivative is negative if (eta < 0) { // Compute unconstrained maximum a2 = alph2 - y2 * (F1 - F2) / eta; // Compute constrained maximum if (a2 < L) { a2 = L; } else if (a2 > H) { a2 = H; } } else { // Look at endpoints of diagonal f1 = SVMOutput(i1, m_data.instance(i1)); f2 = SVMOutput(i2, m_data.instance(i2)); v1 = f1 + m_b - y1 * alph1 * k11 - y2 * alph2 * k12; v2 = f2 + m_b - y1 * alph1 * k12 - y2 * alph2 * k22; double gamma = alph1 + s * alph2; Lobj = (gamma - s * L) + L - 0.5 * k11 * (gamma - s * L) * (gamma - s * L) - 0.5 * k22 * L * L - s * k12 * (gamma - s * L) * L - y1 * (gamma - s * L) * v1 - y2 * L * v2; Hobj = (gamma - s * H) + H - 0.5 * k11 * (gamma - s * H) * (gamma - s * H) - 0.5 * k22 * H * H - s * k12 * (gamma - s * H) * H - y1 * (gamma - s * H) * v1 - y2 * H * v2; if (Lobj > Hobj + m_eps) { a2 = L; } else if (Lobj < Hobj - m_eps) { a2 = H; } else { a2 = alph2; } } if (Math.abs(a2 - alph2) < m_eps * (a2 + alph2 + m_eps)) { return false; } // To prevent precision problems if (a2 > C2 - m_Del * C2) { a2 = C2; } else if (a2 <= m_Del * C2) { a2 = 0; } // Recompute a1 a1 = alph1 + s * (alph2 - a2); // To prevent precision problems if (a1 > C1 - m_Del * C1) { a1 = C1; } else if (a1 <= m_Del * C1) { a1 = 0; } // Update sets if (a1 > 0) { m_supportVectors.insert(i1); } else { m_supportVectors.delete(i1); } if ((a1 > 0) && (a1 < C1)) { m_I0.insert(i1); } else { m_I0.delete(i1); } if ((y1 == 1) && (a1 == 0)) { m_I1.insert(i1); } else { m_I1.delete(i1); } if ((y1 == -1) && (a1 == C1)) { m_I2.insert(i1); } else { m_I2.delete(i1); } if ((y1 == 1) && (a1 == C1)) { m_I3.insert(i1); } else { m_I3.delete(i1); } if ((y1 == 
-1) && (a1 == 0)) { m_I4.insert(i1); } else { m_I4.delete(i1); } if (a2 > 0) { m_supportVectors.insert(i2); } else { m_supportVectors.delete(i2); } if ((a2 > 0) && (a2 < C2)) { m_I0.insert(i2); } else { m_I0.delete(i2); } if ((y2 == 1) && (a2 == 0)) { m_I1.insert(i2); } else { m_I1.delete(i2); } if ((y2 == -1) && (a2 == C2)) { m_I2.insert(i2); } else { m_I2.delete(i2); } if ((y2 == 1) && (a2 == C2)) { m_I3.insert(i2); } else { m_I3.delete(i2); } if ((y2 == -1) && (a2 == 0)) { m_I4.insert(i2); } else { m_I4.delete(i2); } // Update weight vector to reflect change a1 and a2, if linear SVM if (m_KernelIsLinear) { Instance inst1 = m_data.instance(i1); for (int p1 = 0; p1 < inst1.numValues(); p1++) { if (inst1.index(p1) != m_data.classIndex()) { m_weights[inst1.index(p1)] += y1 * (a1 - alph1) * inst1.valueSparse(p1); } } Instance inst2 = m_data.instance(i2); for (int p2 = 0; p2 < inst2.numValues(); p2++) { if (inst2.index(p2) != m_data.classIndex()) { m_weights[inst2.index(p2)] += y2 * (a2 - alph2) * inst2.valueSparse(p2); } } } // Update error cache using new Lagrange multipliers for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) { if ((j != i1) && (j != i2)) { m_errors[j] += y1 * (a1 - alph1) * m_kernel.eval(i1, j, m_data.instance(i1)) + y2 * (a2 - alph2) * m_kernel.eval(i2, j, m_data.instance(i2)); } } // Update error cache for i1 and i2 m_errors[i1] += y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12; m_errors[i2] += y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22; // Update array with Lagrange multipliers m_alpha[i1] = a1; m_alpha[i2] = a2; // Update thresholds m_bLow = -Double.MAX_VALUE; m_bUp = Double.MAX_VALUE; m_iLow = -1; m_iUp = -1; for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) { if (m_errors[j] < m_bUp) { m_bUp = m_errors[j]; m_iUp = j; } if (m_errors[j] > m_bLow) { m_bLow = m_errors[j]; m_iLow = j; } } if (!m_I0.contains(i1)) { if (m_I3.contains(i1) || m_I4.contains(i1)) { if (m_errors[i1] > m_bLow) { m_bLow = m_errors[i1]; m_iLow 
= i1; } } else { if (m_errors[i1] < m_bUp) { m_bUp = m_errors[i1]; m_iUp = i1; } } } if (!m_I0.contains(i2)) { if (m_I3.contains(i2) || m_I4.contains(i2)) { if (m_errors[i2] > m_bLow) { m_bLow = m_errors[i2]; m_iLow = i2; } } else { if (m_errors[i2] < m_bUp) { m_bUp = m_errors[i2]; m_iUp = i2; } } } if ((m_iLow == -1) || (m_iUp == -1)) { throw new Exception("This should never happen!"); } // Made some progress. return true; } /** * Quick and dirty check whether the quadratic programming problem is solved. * * @throws Exception if checking fails */ protected void checkClassifier() throws Exception { double sum = 0; for (int i = 0; i < m_alpha.length; i++) { if (m_alpha[i] > 0) { sum += m_class[i] * m_alpha[i]; } } System.err.println("Sum of y(i) * alpha(i): " + sum); for (int i = 0; i < m_alpha.length; i++) { double output = SVMOutput(i, m_data.instance(i)); if (Utils.eq(m_alpha[i], 0)) { if (Utils.sm(m_class[i] * output, 1)) { System.err.println("KKT condition 1 violated: " + m_class[i] * output); } } if (Utils.gr(m_alpha[i], 0) && Utils.sm(m_alpha[i], m_C * m_data.instance(i).weight())) { if (!Utils.eq(m_class[i] * output, 1)) { System.err.println("KKT condition 2 violated: " + m_class[i] * output); } } if (Utils.eq(m_alpha[i], m_C * m_data.instance(i).weight())) { if (Utils.gr(m_class[i] * output, 1)) { System.err.println("KKT condition 3 violated: " + m_class[i] * output); } } } } } /** filter: Normalize training data */ public static final int FILTER_NORMALIZE = 0; /** filter: Standardize training data */ public static final int FILTER_STANDARDIZE = 1; /** filter: No normalization/standardization */ public static final int FILTER_NONE = 2; /** The filter to apply to the training data */ public static final Tag [] TAGS_FILTER = { new Tag(FILTER_NORMALIZE, "Normalize training data"), new Tag(FILTER_STANDARDIZE, "Standardize training data"), new Tag(FILTER_NONE, "No normalization/standardization"), }; /** The binary classifier(s) */ protected BinarySMO[][] 
m_classifiers = null;

  /**
   * The complexity parameter. checkClassifier() compares each multiplier
   * against {@code m_C * instance.weight()}.
   */
  protected double m_C = 1.0;

  /** Epsilon for rounding. */
  protected double m_eps = 1.0e-12;

  /** Tolerance for accuracy of result. */
  protected double m_tol = 1.0e-3;

  /**
   * Whether to normalize/standardize/neither. Holds one of
   * FILTER_NORMALIZE (the default), FILTER_STANDARDIZE or FILTER_NONE;
   * buildClassifier() picks the value filter accordingly.
   */
  protected int m_filterType = FILTER_NORMALIZE;

  /**
   * The filter used to make attributes numeric. buildClassifier() sets it
   * to null when no nominal-to-binary conversion is applied.
   */
  protected NominalToBinary m_NominalToBinary;

  /**
   * The filter used to standardize/normalize all values; null when
   * m_filterType selects neither.
   */
  protected Filter m_Filter = null;

  /**
   * The filter used to get rid of missing values. buildClassifier() sets it
   * to null when checks are turned off.
   */
  protected ReplaceMissingValues m_Missing;

  /** The class index from the training data */
  protected int m_classIndex = -1;

  /** The class attribute */
  protected Attribute m_classAttribute;

  /**
   * Whether the kernel is a linear one. buildClassifier() sets this to true
   * only when m_kernel is a PolyKernel whose exponent equals 1.0.
   */
  protected boolean m_KernelIsLinear = false;

  /** Turn off all checks and conversions? Turning them off assumes
      that data is purely numeric, doesn't contain any missing values,
      and has a nominal class. Turning them off also means that
      no header information will be stored if the machine is linear.
      Finally, it also assumes that no instance has a weight equal to 0.*/
  protected boolean m_checksTurnedOff;

  /**
   * Precision constant for updating sets.
   * NOTE(review): takeStep() uses a field of this name to snap multipliers
   * that are within {@code m_Del * C} of 0 or C onto the exact bound;
   * confirm it resolves to this declaration.
   */
  protected static double m_Del = 1000 * Double.MIN_VALUE;

  /** Whether logistic models are to be fit */
  protected boolean m_fitLogisticModels = false;

  /** The number of folds for the internal cross-validation */
  protected int m_numFolds = -1;

  /** The random number seed */
  protected int m_randomSeed = 1;

  /** The kernel to use; defaults to a new PolyKernel. */
  protected Kernel m_kernel = new PolyKernel();

  /**
   * Turns off checks for missing values, etc. Use with caution.
   *
   * <p>Sets m_checksTurnedOff to true. With the flag set, buildClassifier()
   * skips the capabilities test, the removal of instances with a missing
   * class or a weight of 0, and the replacement of missing values.
   */
  public void turnChecksOff() {

    m_checksTurnedOff = true;
  }

  /**
   * Turns on checks for missing values, etc.
   *
   * <p>Sets m_checksTurnedOff back to false.
   */
  public void turnChecksOn() {

    m_checksTurnedOff = false;
  }

  /**
   * Returns default capabilities of the classifier.
* * @return the capabilities of this classifier */ public Capabilities getCapabilities() { Capabilities result = getKernel().getCapabilities(); result.setOwner(this); // attribute result.enableAllAttributeDependencies(); // with NominalToBinary we can also handle nominal attributes, but only // if the kernel can handle numeric attributes if (result.handles(Capability.NUMERIC_ATTRIBUTES)) result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); // class result.disableAllClasses(); result.disableAllClassDependencies(); result.enable(Capability.NOMINAL_CLASS); result.enable(Capability.MISSING_CLASS_VALUES); return result; } /** * Method for building the classifier. Implements a one-against-one * wrapper for multi-class problems. * * @param insts the set of training instances * @throws Exception if the classifier can't be built successfully */ public void buildClassifier(Instances insts) throws Exception { if (!m_checksTurnedOff) { // can classifier handle the data? getCapabilities().testWithFail(insts); // remove instances with missing class insts = new Instances(insts); insts.deleteWithMissingClass(); /* Removes all the instances with weight equal to 0. MUST be done since condition (8) of Keerthi's paper is made with the assertion Ci > 0 (See equation (3a). 
*/ Instances data = new Instances(insts, insts.numInstances()); for(int i = 0; i < insts.numInstances(); i++){ if(insts.instance(i).weight() > 0) data.add(insts.instance(i)); } if (data.numInstances() == 0) { throw new Exception("No training instances left after removing " + "instances with weight 0!"); } insts = data; } if (!m_checksTurnedOff) { m_Missing = new ReplaceMissingValues(); m_Missing.setInputFormat(insts); insts = Filter.useFilter(insts, m_Missing); } else { m_Missing = null; } if (getCapabilities().handles(Capability.NUMERIC_ATTRIBUTES)) { boolean onlyNumeric = true; if (!m_checksTurnedOff) { for (int i = 0; i < insts.numAttributes(); i++) { if (i != insts.classIndex()) { if (!insts.attribute(i).isNumeric()) { onlyNumeric = false; break; } } } } if (!onlyNumeric) { m_NominalToBinary = new NominalToBinary(); m_NominalToBinary.setInputFormat(insts); insts = Filter.useFilter(insts, m_NominalToBinary); } else { m_NominalToBinary = null; } } else { m_NominalToBinary = null; } if (m_filterType == FILTER_STANDARDIZE) { m_Filter = new Standardize(); m_Filter.setInputFormat(insts); insts = Filter.useFilter(insts, m_Filter); } else if (m_filterType == FILTER_NORMALIZE) { m_Filter = new Normalize(); m_Filter.setInputFormat(insts); insts = Filter.useFilter(insts, m_Filter); } else { m_Filter = null; } m_classIndex = insts.classIndex(); m_classAttribute = insts.classAttribute(); m_KernelIsLinear = (m_kernel instanceof PolyKernel) && (((PolyKernel) m_kernel).getExponent() == 1.0); // Generate subsets representing each class Instances[] subsets = new Instances[insts.numClasses()];
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -