smo.java
          double[] vals = new double[2];
          vals[0] = smo.SVMOutput(-1, test.instance(j));
          if (test.instance(j).classValue() == cl2) {
            vals[1] = 1;
          }
          data.add(new Instance(test.instance(j).weight(), vals));
        }
      }
    }

    // Build logistic regression model
    m_logistic = new Logistic();
    m_logistic.buildClassifier(data);
  }

  /**
   * Method for building the binary classifier.
   *
   * @param insts the set of training instances
   * @param cl1 the first class' index
   * @param cl2 the second class' index
   * @param fitLogistic true if logistic model is to be fit
   * @param numFolds number of folds for internal cross-validation
   * @param randomSeed the seed for the random number generator used
   * in the internal cross-validation
   * @exception Exception if the classifier can't be built successfully
   */
  private void buildClassifier(Instances insts, int cl1, int cl2,
                               boolean fitLogistic, int numFolds,
                               int randomSeed) throws Exception {

    // Initialize the number of kernel evaluations
    m_kernelEvals = 0;

    // Initialize thresholds
    m_bUp = -1;
    m_bLow = 1;
    m_b = 0;

    // Store the sum of weights
    m_sumOfWeights = insts.sumOfWeights();

    // Set class values
    m_class = new double[insts.numInstances()];
    m_iUp = -1;
    m_iLow = -1;
    for (int i = 0; i < m_class.length; i++) {
      if ((int) insts.instance(i).classValue() == cl1) {
        m_class[i] = -1;
        m_iLow = i;
      } else if ((int) insts.instance(i).classValue() == cl2) {
        m_class[i] = 1;
        m_iUp = i;
      } else {
        throw new Exception("This should never happen!");
      }
    }

    // Check whether one of the two classes has no instances
    if ((m_iUp == -1) || (m_iLow == -1)) {
      if (m_iUp == -1) {
        m_b = 1;
      } else {
        m_b = -1;
      }
      if (!m_useRBF && m_exponent == 1.0) {
        m_sparseWeights = new double[0];
        m_sparseIndices = new int[0];
      }
      m_class = null;

      // Fit sigmoid if requested
      if (fitLogistic) {
        fitLogistic(insts, cl1, cl2, numFolds, new Random(randomSeed));
      }
      return;
    }

    // Set the reference to the data
    m_data = insts;

    // If machine is linear, reserve space for weights
    if (!m_useRBF && m_exponent == 1.0) {
      m_weights = new double[m_data.numAttributes()];
    } else {
      m_weights = null;
    }

    // Initialize alpha array to zero
    m_alpha = new double[m_data.numInstances()];

    // Initialize sets
    m_supportVectors = new SMOset(m_data.numInstances());
    m_I0 = new SMOset(m_data.numInstances());
    m_I1 = new SMOset(m_data.numInstances());
    m_I2 = new SMOset(m_data.numInstances());
    m_I3 = new SMOset(m_data.numInstances());
    m_I4 = new SMOset(m_data.numInstances());

    // Clean out some instance variables
    m_sparseWeights = null;
    m_sparseIndices = null;

    // Initialize error cache
    m_errors = new double[m_data.numInstances()];
    m_errors[m_iLow] = 1;
    m_errors[m_iUp] = -1;

    // Initialize kernel
    if (m_useRBF) {
      m_kernel = new RBFKernel(m_data);
    } else {
      if (m_featureSpaceNormalization) {
        m_kernel = new NormalizedPolyKernel();
      } else {
        m_kernel = new PolyKernel();
      }
    }

    // The kernel calculations are cached
    m_storage = new double[m_cacheSize];
    m_keys = new long[m_cacheSize];

    // Build up I1 and I4
    for (int i = 0; i < m_class.length; i++) {
      if (m_class[i] == 1) {
        m_I1.insert(i);
      } else {
        m_I4.insert(i);
      }
    }

    // Loop to find all the support vectors
    int numChanged = 0;
    boolean examineAll = true;
    while ((numChanged > 0) || examineAll) {
      numChanged = 0;
      if (examineAll) {
        for (int i = 0; i < m_alpha.length; i++) {
          if (examineExample(i)) {
            numChanged++;
          }
        }
      } else {

        // This code implements Modification 1 from Keerthi et al.'s paper
        for (int i = 0; i < m_alpha.length; i++) {
          if ((m_alpha[i] > 0) &&
              (m_alpha[i] < m_C * m_data.instance(i).weight())) {
            if (examineExample(i)) {
              numChanged++;
            }

            // Is optimality on unbound vectors obtained?
            if (m_bUp > m_bLow - 2 * m_tol) {
              numChanged = 0;
              break;
            }
          }
        }

        // This is the code for Modification 2 from Keerthi et al.'s paper
        /* boolean innerLoopSuccess = true;
           numChanged = 0;
           while ((m_bUp < m_bLow - 2 * m_tol) && (innerLoopSuccess == true)) {
             innerLoopSuccess = takeStep(m_iUp, m_iLow, m_errors[m_iLow]);
           } */
      }
      if (examineAll) {
        examineAll = false;
      } else if (numChanged == 0) {
        examineAll = true;
      }
    }

    // Set threshold
    m_b = (m_bLow + m_bUp) / 2.0;

    // Save memory
    m_storage = null;
    m_keys = null;
    m_errors = null;
    m_I0 = m_I1 = m_I2 = m_I3 = m_I4 = null;

    // If machine is linear, delete training data
    // and store weight vector in sparse format
    if (!m_useRBF && m_exponent == 1.0) {

      // We don't need to store the set of support vectors
      m_supportVectors = null;

      // We don't need to store the class values either
      m_class = null;

      // Clean out training data
      if (!m_checksTurnedOff) {
        m_data = new Instances(m_data, 0);
      } else {
        m_data = null;
      }

      // Convert weight vector
      double[] sparseWeights = new double[m_weights.length];
      int[] sparseIndices = new int[m_weights.length];
      int counter = 0;
      for (int i = 0; i < m_weights.length; i++) {
        if (m_weights[i] != 0.0) {
          sparseWeights[counter] = m_weights[i];
          sparseIndices[counter] = i;
          counter++;
        }
      }
      m_sparseWeights = new double[counter];
      m_sparseIndices = new int[counter];
      System.arraycopy(sparseWeights, 0, m_sparseWeights, 0, counter);
      System.arraycopy(sparseIndices, 0, m_sparseIndices, 0, counter);

      // Clean out weight vector
      m_weights = null;

      // We don't need the alphas in the linear case
      m_alpha = null;
    }

    // Fit sigmoid if requested
    if (fitLogistic) {
      fitLogistic(insts, cl1, cl2, numFolds, new Random(randomSeed));
    }
  }
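  /**
   * Illustrative sketch, not part of the original source: the stopping
   * test applied inside the main loop above (Modification 1 from Keerthi
   * et al.). Training on the unbound vectors is considered optimal once
   * the two threshold estimates have crossed to within twice the
   * tolerance m_tol.
   */
  private boolean optimalityReachedSketch() {
    // bUp tracks the smallest error on I0/I1/I2, bLow the largest on
    // I0/I3/I4; when they meet (up to tolerance), no violating pair remains
    return m_bUp > m_bLow - 2 * m_tol;
  }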
  /**
   * Computes SVM output for given instance.
   *
   * @param index the index of the instance in the training data,
   * or -1 if the instance is not part of the training data
   * @param inst the instance
   * @return the output of the SVM for the given instance
   * @exception Exception if something goes wrong
   */
  private double SVMOutput(int index, Instance inst) throws Exception {

    double result = 0;

    // Is the machine linear?
    if (!m_useRBF && m_exponent == 1.0) {

      // Is weight vector stored in sparse format?
      if (m_sparseWeights == null) {
        int n1 = inst.numValues();
        for (int p = 0; p < n1; p++) {
          if (inst.index(p) != m_classIndex) {
            result += m_weights[inst.index(p)] * inst.valueSparse(p);
          }
        }
      } else {
        int n1 = inst.numValues();
        int n2 = m_sparseWeights.length;
        for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) {
          int ind1 = inst.index(p1);
          int ind2 = m_sparseIndices[p2];
          if (ind1 == ind2) {
            if (ind1 != m_classIndex) {
              result += inst.valueSparse(p1) * m_sparseWeights[p2];
            }
            p1++;
            p2++;
          } else if (ind1 > ind2) {
            p2++;
          } else {
            p1++;
          }
        }
      }
    } else {
      for (int i = m_supportVectors.getNext(-1); i != -1;
           i = m_supportVectors.getNext(i)) {
        result += m_class[i] * m_alpha[i] * m_kernel.eval(index, i, inst);
      }
    }
    result -= m_b;

    return result;
  }
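  /**
   * Illustrative sketch, not part of the original source: the two-pointer
   * merge that SVMOutput() above uses to form the dot product of a sparse
   * instance with the sparse weight vector. Both index arrays are assumed
   * to be sorted in increasing order, as WEKA's sparse instances are.
   */
  private static double sparseDotSketch(int[] indA, double[] valA,
                                        int[] indB, double[] valB) {
    double result = 0;
    int p1 = 0, p2 = 0;
    while (p1 < indA.length && p2 < indB.length) {
      if (indA[p1] == indB[p2]) {
        result += valA[p1] * valB[p2];  // indices match: accumulate product
        p1++;
        p2++;
      } else if (indA[p1] > indB[p2]) {
        p2++;  // B's index is smaller: advance its pointer
      } else {
        p1++;  // A's index is smaller: advance its pointer
      }
    }
    return result;
  }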
  /**
   * Prints out the classifier.
   *
   * @return a description of the classifier as a string
   */
  public String toString() {

    StringBuffer text = new StringBuffer();
    int printed = 0;

    if ((m_alpha == null) && (m_sparseWeights == null)) {
      return "BinarySMO: No model built yet.";
    }
    try {
      text.append("BinarySMO\n\n");

      // If machine linear, print weight vector
      if (!m_useRBF && m_exponent == 1.0) {
        text.append("Machine linear: showing attribute weights, ");
        text.append("not support vectors.\n\n");

        // We can assume that the weight vector is stored in sparse
        // format because the classifier has been built
        for (int i = 0; i < m_sparseWeights.length; i++) {
          if (i != (int) m_classIndex) {
            if (printed > 0) {
              text.append(" + ");
            } else {
              text.append("   ");
            }
            text.append(Utils.doubleToString(m_sparseWeights[i], 12, 4) +
                        " * ");
            if (m_filterType == FILTER_STANDARDIZE) {
              text.append("(standardized) ");
            } else if (m_filterType == FILTER_NORMALIZE) {
              text.append("(normalized) ");
            }
            if (!m_checksTurnedOff) {
              text.append(m_data.attribute(m_sparseIndices[i]).name() + "\n");
            } else {
              text.append("attribute with index " + m_sparseIndices[i] + "\n");
            }
            printed++;
          }
        }
      } else {
        for (int i = 0; i < m_alpha.length; i++) {
          if (m_supportVectors.contains(i)) {
            double val = m_alpha[i];
            if (m_class[i] == 1) {
              if (printed > 0) {
                text.append(" + ");
              }
            } else {
              text.append(" - ");
            }
            text.append(Utils.doubleToString(val, 12, 4) +
                        " * K[X(" + i + ") * X]\n");
            printed++;
          }
        }
      }
      if (m_b > 0) {
        text.append(" - " + Utils.doubleToString(m_b, 12, 4));
      } else {
        text.append(" + " + Utils.doubleToString(-m_b, 12, 4));
      }
      if (m_useRBF || m_exponent != 1.0) {
        text.append("\n\nNumber of support vectors: " +
                    m_supportVectors.numElements());
      }
      text.append("\n\nNumber of kernel evaluations: " + m_kernelEvals);
    } catch (Exception e) {
      return "Can't print BinarySMO classifier.";
    }

    return text.toString();
  }

  /**
   * Examines instance.
   *
   * @param i2 index of instance to examine
   * @return true if examination was successful
   * @exception Exception if something goes wrong
   */
  private boolean examineExample(int i2) throws Exception {

    double y2, alph2, F2;
    int i1 = -1;

    y2 = m_class[i2];
    alph2 = m_alpha[i2];
    if (m_I0.contains(i2)) {
      F2 = m_errors[i2];
    } else {
      F2 = SVMOutput(i2, m_data.instance(i2)) + m_b - y2;
      m_errors[i2] = F2;

      // Update thresholds
      if ((m_I1.contains(i2) || m_I2.contains(i2)) && (F2 < m_bUp)) {
        m_bUp = F2;
        m_iUp = i2;
      } else if ((m_I3.contains(i2) || m_I4.contains(i2)) && (F2 > m_bLow)) {
        m_bLow = F2;
        m_iLow = i2;
      }
    }

    // Check optimality using current bLow and bUp and, if
    // violated, find an index i1 to do joint optimization
    // with i2...
    boolean optimal = true;
    if (m_I0.contains(i2) || m_I1.contains(i2) || m_I2.contains(i2)) {
      if (m_bLow - F2 > 2 * m_tol) {
        optimal = false;
        i1 = m_iLow;
      }
    }
    if (m_I0.contains(i2) || m_I3.contains(i2) || m_I4.contains(i2)) {
      if (F2 - m_bUp > 2 * m_tol) {
        optimal = false;
        i1 = m_iUp;
      }
    }
    if (optimal) {
      return false;
    }

    // For i2 unbound choose the better i1...
    if (m_I0.contains(i2)) {
      if (m_bLow - F2 > F2 - m_bUp) {
        i1 = m_iLow;
      } else {
        i1 = m_iUp;
      }
    }
    if (i1 == -1) {
      throw new Exception("This should never happen!");
    }

    return takeStep(i1, i2, F2);
  }
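  /**
   * Illustrative sketch, not part of the original source: the box
   * constraints [L, H] that takeStep() below places on the second
   * Lagrange multiplier. Along the joint optimization step the quantity
   * alpha1 + s * alpha2 stays constant (s = y1 * y2), so clipping a2 to
   * [L, H] keeps both multipliers inside their boxes [0, C1] and [0, C2].
   */
  private static double[] multiplierBoundsSketch(double alph1, double alph2,
                                                 double y1, double y2,
                                                 double C1, double C2) {
    double L, H;
    if (y1 != y2) {
      // Labels differ: alpha2 - alpha1 is invariant along the step
      L = Math.max(0, alph2 - alph1);
      H = Math.min(C2, C1 + alph2 - alph1);
    } else {
      // Labels equal: alpha1 + alpha2 is invariant along the step
      L = Math.max(0, alph1 + alph2 - C1);
      H = Math.min(C2, alph1 + alph2);
    }
    return new double[] {L, H};
  }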
  /**
   * Method solving for the Lagrange multipliers for
   * two instances.
   *
   * @param i1 index of the first instance
   * @param i2 index of the second instance
   * @param F2 the error value for the second instance
   * @return true if multipliers could be found
   * @exception Exception if something goes wrong
   */
  private boolean takeStep(int i1, int i2, double F2) throws Exception {

    double alph1, alph2, y1, y2, F1, s, L, H, k11, k12, k22, eta,
           a1, a2, f1, f2, v1, v2, Lobj, Hobj, b1, b2, bOld;
    double C1 = m_C * m_data.instance(i1).weight();
    double C2 = m_C * m_data.instance(i2).weight();

    // Don't do anything if the two instances are the same
    if (i1 == i2) {
      return false;
    }

    // Initialize variables
    alph1 = m_alpha[i1];
    alph2 = m_alpha[i2];
    y1 = m_class[i1];
    y2 = m_class[i2];
    F1 = m_errors[i1];
    s = y1 * y2;

    // Find the constraints on a2
    if (y1 != y2) {
      L = Math.max(0, alph2 - alph1);
      H = Math.min(C2, C1 + alph2 - alph1);
    } else {
      L = Math.max(0, alph1 + alph2 - C1);
      H = Math.min(C2, alph1 + alph2);
    }
    if (L >= H) {
      return false;
    }

    // Compute second derivative of objective function
    k11 = m_kernel.eval(i1, i1, m_data.instance(i1));
    k12 = m_kernel.eval(i1, i2, m_data.instance(i1));
    k22 = m_kernel.eval(i2, i2, m_data.instance(i2));
    eta = 2 * k12 - k11 - k22;

    // Check if second derivative is negative
    if (eta < 0) {

      // Compute unconstrained maximum
      a2 = alph2 - y2 * (F1 - F2) / eta;

      // Compute constrained maximum
      if (a2 < L) {
        a2 = L;
      } else if (a2 > H) {
        a2 = H;
      }
    } else {

      // Look at endpoints of diagonal
      f1 = SVMOutput(i1, m_data.instance(i1));
      f2 = SVMOutput(i2, m_data.instance(i2));
      v1 = f1 + m_b - y1 * alph1 * k11 - y2 * alph2 * k12;
      v2 = f2 + m_b - y1 * alph1 * k12 - y2 * alph2 * k22;
      double gamma = alph1 + s * alph2;
      Lobj = (gamma - s * L) + L -
        0.5 * k11 * (gamma - s * L) * (gamma - s * L) -
        0.5 * k22 * L * L -
        s * k12 * (gamma - s * L) * L -
        y1 * (gamma - s * L) * v1 -
        y2 * L * v2;
      Hobj = (gamma - s * H) + H -
        0.5 * k11 * (gamma - s * H) * (gamma - s * H) -
        0.5 * k22 * H * H -
        s * k12 * (gamma - s * H) * H -
        y1 * (gamma - s * H) * v1 -
        y2 * H * v2;
      if (Lobj > Hobj + m_eps) {
        a2 = L;
      } else if (Lobj < Hobj - m_eps) {
        a2 = H;
      } else {
        a2 = alph2;
      }
    }
    if (Math.abs(a2 - alph2) < m_eps * (a2 + alph2 + m_eps)) {
      return false;
    }

    // To prevent precision problems
    if (a2 > C2 - m_Del * C2) {
      a2 = C2;
    } else if (a2 <= m_Del * C2) {
      a2 = 0;
    }

    // Recompute a1
    a1 = alph1 + s * (alph2 - a2);

    // To prevent precision problems
    if (a1 > C1 - m_Del * C1) {
      a1 = C1;
    } else if (a1 <= m_Del * C1) {
      a1 = 0;
    }

    // Update sets
    if (a1 > 0) {
      m_supportVectors.insert(i1);
    } else {
      m_supportVectors.delete(i1);
    }
    if ((a1 > 0) && (a1 < C1)) {
      m_I0.insert(i1);
    } else {
      m_I0.delete(i1);
    }
    if ((y1 == 1) && (a1 == 0)) {
      m_I1.insert(i1);
    } else {
      m_I1.delete(i1);
    }
    if ((y1 == -1) && (a1 == C1)) {
      m_I2.insert(i1);
    } else {
      m_I2.delete(i1);
    }
    if ((y1 == 1) && (a1 == C1)) {
      m_I3.insert(i1);