📄 smo.java
字号:
} else { m_I3.delete(i1); } if ((y1 == -1) && (a1 == 0)) { m_I4.insert(i1); } else { m_I4.delete(i1); } if (a2 > 0) { m_supportVectors.insert(i2); } else { m_supportVectors.delete(i2); } if ((a2 > 0) && (a2 < C2)) { m_I0.insert(i2); } else { m_I0.delete(i2); } if ((y2 == 1) && (a2 == 0)) { m_I1.insert(i2); } else { m_I1.delete(i2); } if ((y2 == -1) && (a2 == C2)) { m_I2.insert(i2); } else { m_I2.delete(i2); } if ((y2 == 1) && (a2 == C2)) { m_I3.insert(i2); } else { m_I3.delete(i2); } if ((y2 == -1) && (a2 == 0)) { m_I4.insert(i2); } else { m_I4.delete(i2); } // Update weight vector to reflect change a1 and a2, if linear SVM if (!m_useRBF && m_exponent == 1.0) { Instance inst1 = m_data.instance(i1); for (int p1 = 0; p1 < inst1.numValues(); p1++) { if (inst1.index(p1) != m_data.classIndex()) { m_weights[inst1.index(p1)] += y1 * (a1 - alph1) * inst1.valueSparse(p1); } } Instance inst2 = m_data.instance(i2); for (int p2 = 0; p2 < inst2.numValues(); p2++) { if (inst2.index(p2) != m_data.classIndex()) { m_weights[inst2.index(p2)] += y2 * (a2 - alph2) * inst2.valueSparse(p2); } } } // Update error cache using new Lagrange multipliers for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) { if ((j != i1) && (j != i2)) { m_errors[j] += y1 * (a1 - alph1) * m_kernel.eval(i1, j, m_data.instance(i1)) + y2 * (a2 - alph2) * m_kernel.eval(i2, j, m_data.instance(i2)); } } // Update error cache for i1 and i2 m_errors[i1] += y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12; m_errors[i2] += y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22; // Update array with Lagrange multipliers m_alpha[i1] = a1; m_alpha[i2] = a2; // Update thresholds m_bLow = -Double.MAX_VALUE; m_bUp = Double.MAX_VALUE; m_iLow = -1; m_iUp = -1; for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) { if (m_errors[j] < m_bUp) { m_bUp = m_errors[j]; m_iUp = j; } if (m_errors[j] > m_bLow) { m_bLow = m_errors[j]; m_iLow = j; } } if (!m_I0.contains(i1)) { if (m_I3.contains(i1) || m_I4.contains(i1)) { if (m_errors[i1] > m_bLow) { m_bLow = m_errors[i1]; m_iLow = i1; } } else { if (m_errors[i1] < m_bUp) { m_bUp = m_errors[i1]; m_iUp = i1; } } } if (!m_I0.contains(i2)) { if (m_I3.contains(i2) || m_I4.contains(i2)) { if (m_errors[i2] > m_bLow) { m_bLow = m_errors[i2]; m_iLow = i2; } } else { if (m_errors[i2] < m_bUp) { m_bUp = m_errors[i2]; m_iUp = i2; } } } if ((m_iLow == -1) || (m_iUp == -1)) { throw new Exception("This should never happen!"); } // Made some progress. return true; } /** * Quick and dirty check whether the quadratic programming problem is solved. */ private void checkClassifier() throws Exception { double sum = 0; for (int i = 0; i < m_alpha.length; i++) { if (m_alpha[i] > 0) { sum += m_class[i] * m_alpha[i]; } } System.err.println("Sum of y(i) * alpha(i): " + sum); for (int i = 0; i < m_alpha.length; i++) { double output = SVMOutput(i, m_data.instance(i)); if (Utils.eq(m_alpha[i], 0)) { if (Utils.sm(m_class[i] * output, 1)) { System.err.println("KKT condition 1 violated: " + m_class[i] * output); } } if (Utils.gr(m_alpha[i], 0) && Utils.sm(m_alpha[i], m_C * m_data.instance(i).weight())) { if (!Utils.eq(m_class[i] * output, 1)) { System.err.println("KKT condition 2 violated: " + m_class[i] * output); } } if (Utils.eq(m_alpha[i], m_C * m_data.instance(i).weight())) { if (Utils.gr(m_class[i] * output, 1)) { System.err.println("KKT condition 3 violated: " + m_class[i] * output); } } } } } /** The filter to apply to the training data */ public static final int FILTER_NORMALIZE = 0; public static final int FILTER_STANDARDIZE = 1; public static final int FILTER_NONE = 2; public static final Tag [] TAGS_FILTER = { new Tag(FILTER_NORMALIZE, "Normalize training data"), new Tag(FILTER_STANDARDIZE, "Standardize training data"), new Tag(FILTER_NONE, "No normalization/standardization"), }; /** The binary classifier(s) */ private BinarySMO[][] m_classifiers = null; /** The exponent for the polynomial kernel. */ private double m_exponent = 1.0; /** Gamma for the RBF kernel. */ private double m_gamma = 0.01; /** The complexity parameter. */ private double m_C = 1.0; /** Epsilon for rounding. */ private double m_eps = 1.0e-12; /** Tolerance for accuracy of result. */ private double m_tol = 1.0e-3; /** Whether to normalize/standardize/neither */ private int m_filterType = FILTER_NORMALIZE; /** Feature-space normalization? */ private boolean m_featureSpaceNormalization = false; /** Use lower-order terms? */ private boolean m_lowerOrder = false; /** Use RBF kernel? (default: poly) */ private boolean m_useRBF = false; /** The size of the cache (a prime number) */ private int m_cacheSize = 1000003; /** The filter used to make attributes numeric. */ private NominalToBinary m_NominalToBinary; /** The filter used to standardize/normalize all values. */ private Filter m_Filter = null; /** The filter used to get rid of missing values. */ private ReplaceMissingValues m_Missing; /** Only numeric attributes in the dataset? */ private boolean m_onlyNumeric; /** The class index from the training data */ private int m_classIndex = -1; /** The class attribute */ private Attribute m_classAttribute; /** Turn off all checks and conversions? Turning them off assumes that data is purely numeric, doesn't contain any missing values, and has a nominal class. Turning them off also means that no header information will be stored if the machine is linear. */ private boolean m_checksTurnedOff; /** Precision constant for updating sets */ private static double m_Del = 1000 * Double.MIN_VALUE; /** Whether logistic models are to be fit */ private boolean m_fitLogisticModels = false; /** The number of folds for the internal cross-validation */ private int m_numFolds = -1; /** The random number seed for the internal cross-validation */ private int m_randomSeed = 1; /** * Turns off checks for missing values, etc. Use with caution. */ public void turnChecksOff() { m_checksTurnedOff = true; } /** * Turns on checks for missing values, etc. */ public void turnChecksOn() { m_checksTurnedOff = false; } /** * Method for building the classifier. Implements a one-against-one * wrapper for multi-class problems. * * @param insts the set of training instances * @exception Exception if the classifier can't be built successfully */ public void buildClassifier(Instances insts) throws Exception { if (!m_checksTurnedOff) { if (insts.checkForStringAttributes()) { throw new UnsupportedAttributeTypeException("Cannot handle string attributes!"); } if (insts.classAttribute().isNumeric()) { throw new UnsupportedClassTypeException("SMO can't handle a numeric class!"); } insts = new Instances(insts); insts.deleteWithMissingClass(); if (insts.numInstances() == 0) { throw new Exception("No training instances without a missing class!"); } } m_onlyNumeric = true; if (!m_checksTurnedOff) { for (int i = 0; i < insts.numAttributes(); i++) { if (i != insts.classIndex()) { if (!insts.attribute(i).isNumeric()) { m_onlyNumeric = false; break; } } } } if (!m_checksTurnedOff) { m_Missing = new ReplaceMissingValues(); m_Missing.setInputFormat(insts); insts = Filter.useFilter(insts, m_Missing); } else { m_Missing = null; } if (!m_onlyNumeric) { m_NominalToBinary = new NominalToBinary(); m_NominalToBinary.setInputFormat(insts); insts = Filter.useFilter(insts, m_NominalToBinary); } else { m_NominalToBinary = null; } if (m_filterType == FILTER_STANDARDIZE) { m_Filter = new Standardize(); m_Filter.setInputFormat(insts); insts = Filter.useFilter(insts, m_Filter); } else if (m_filterType == FILTER_NORMALIZE) { m_Filter = new Normalize(); m_Filter.setInputFormat(insts); insts = Filter.useFilter(insts, m_Filter); } else { m_Filter = null; } m_classIndex = insts.classIndex(); m_classAttribute = insts.classAttribute(); // Generate subsets representing each class Instances[] subsets = new Instances[insts.numClasses()]; for (int i = 0; i < insts.numClasses(); i++) { subsets[i] = new Instances(insts, insts.numInstances()); } for (int j = 0; j < insts.numInstances(); j++) { Instance inst = insts.instance(j); subsets[(int)inst.classValue()].add(inst); } for (int i = 0; i < insts.numClasses(); i++) { subsets[i].compactify(); } // Build the binary classifiers m_classifiers = new BinarySMO[insts.numClasses()][insts.numClasses()]; for (int i = 0; i < insts.numClasses(); i++) { for (int j = i + 1; j < insts.numClasses(); j++) { m_classifiers[i][j] = new BinarySMO(); Instances data = new Instances(insts, insts.numInstances()); for (int k = 0; k < subsets[i].numInstances(); k++) { data.add(subsets[i].instance(k)); } for (int k = 0; k < subsets[j].numInstances(); k++) { data.add(subsets[j].instance(k)); } data.compactify(); m_classifiers[i][j].buildClassifier(data, i, j, m_fitLogisticModels, m_numFolds, m_randomSeed); } } } /** * Estimates class probabilities for given instance. */ public double[] distributionForInstance(Instance inst) throws Exception { // Filter instance if (!m_checksTurnedOff) { m_Missing.input(inst); m_Missing.batchFinished(); inst = m_Missing.output(); } if (!m_onlyNumeric) { m_NominalToBinary.input(inst); m_NominalToBinary.batchFinished(); inst = m_NominalToBinary.output(); } if (m_Filter != null) { m_Filter.input(inst); m_Filter.batchFinished(); inst = m_Filter.output(); } if (!m_fitLogisticModels) { double[] result = new double[inst.numClasses()]; for (int i = 0; i < inst.numClasses(); i++) { for (int j = i + 1; j < inst.numClasses(); j++) { double output = m_classifiers[i][j].SVMOutput(-1, inst); if (output > 0) { result[j] += 1; } else { result[i] += 1; } } } Utils.normalize(result); return result; } else { // We only need to do pairwise coupling if there are more // then two classes. if (inst.numClasses() == 2) { double[] newInst = new double[2]; newInst[0] = m_classifiers[0][1].SVMOutput(-1, inst); newInst[1] = Instance.missingValue(); return m_classifiers[0][1].m_logistic. distributionForInstance(new Instance(1, newInst)); } double[][] r = new double[inst.numClasses()][inst.numClasses()]; double[][] n = new double[inst.numClasses()][inst.numClasses()]; for (int i = 0; i < inst.numClasses(); i++) { for (int j = i + 1; j < inst.numClasses(); j++) { double[] newInst = new double[2]; newInst[0] = m_classifiers[i][j].SVMOutput(-1, inst); newInst[1] = Instance.missingValue(); r[i][j] = m_classifiers[i][j].m_logistic. distributionForInstance(new Instance(1, newInst))[0]; n[i][j] = m_classifiers[i][j].m_sumOfWeights; } } return pairwiseCoupling(n, r); } } /** * Implements pairwise coupling. * * @param n the sum of weights used to train each model * @param r the probability estimate from each model * @return the coupled estimates */ public double[] pairwiseCoupling(double[][] n, double[][] r) { // Initialize p and u array double[] p = new double[r.length]; for (int i =0; i < p.length; i++) { p[i] = 1.0 / (double)p.length; } double[][] u = new double[r.length][r.length]; for (int i = 0; i < r.length; i++) { for (int j = i + 1; j < r.length; j++) { u[i][j] = 0.5; } } // firstSum doesn't change double[] firstSum = new double[p.length]; for (int i = 0; i < p.length; i++) { for (int j = i + 1; j < p.length; j++) { firstSum[i] += n[i][j] * r[i][j]; firstSum[j] += n[i][j] * (1 - r[i][j]); } } // Iterate until convergence boolean changed; do { changed = false; double[] secondSum = new double[p.length]; for (int i = 0; i < p.length; i++) { for (int j = i + 1; j < p.length; j++) { secondSum[i] += n[i][j] * u[i][j]; secondSum[j] += n[i][j] * (1 - u[i][j]); } } for (int i = 0; i < p.length; i++) { if ((firstSum[i] == 0) || (secondSum[i] == 0)) { if (p[i] > 0) { changed = true; } p[i] = 0; } else { double factor = firstSum[i] / secondSum[i]; double pOld = p[i]; p[i] *= factor; if (Math.abs(pOld - p[i]) > 1.0e-3) { changed = true; } } } Utils.normalize(p); for (int i = 0; i < r.length; i++) { for (int j = i + 1; j < r.length; j++) { u[i][j] = p[i] / (p[i] + p[j]); } } } while (changed); return p; } /** * Returns an array of votes for the given instance. * @param inst the instance * @return array of votex * @exception Exception if something goes wrong */ public int[] obtainVotes(Instance inst) throws Exception { // Filter instance if (!m_checksTurnedOff) { m_Missing.input(inst); m_Missing.batchFinished(); inst = m_Missing.output(); } if (!m_onlyNumeric) { m_NominalToBinary.input(inst); m_NominalToBinary.batchFinished(); inst = m_NominalToBinary.output(); } if (m_Filter != null) { m_Filter.input(inst); m_Filter.batchFinished(); inst = m_Filter.output(); } int[] votes = new int[inst.numClasses()]; for (int i = 0; i < inst.numClasses(); i++) { for (int j = i + 1; j < inst.numClasses(); j++) { double output = m_classifiers[i][j].SVMOutput(-1, inst); if (output > 0) { votes[j] += 1; } else { votes[i] += 1; } } } return votes; } /** * Returns the coefficients in sparse format. Throws an exception * if there is more than one machine or if the machine is not * linear. */ public FastVector weights() throws Exception { if (m_classifiers.length > 2) { throw new Exception("More than one machine has been built."); } if (m_classifiers[0][1].m_sparseWeights == null) { throw new Exception("No weight vector available."); } FastVector vec = new FastVector(2); vec.addElement(m_classifiers[0][1].m_sparseWeights); vec.addElement(m_classifiers[0][1].m_sparseIndices);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -