📄 smoreg.java
字号:
/** * Get the value of epsilon. * @return Value of epsilon. */ public double getEpsilon() { return m_epsilon; } /** * Set the value of epsilon. * @param v Value to assign to epsilon. */ public void setEpsilon(double v) { m_epsilon = v; } /** * Turns off checks for missing values, etc. Use with caution. */ public void turnChecksOff() { m_checksTurnedOff = true; } /** * Turns on checks for missing values, etc. */ public void turnChecksOn() { m_checksTurnedOff = false; } /** * Prints out the classifier. * * @return a description of the classifier as a string */ public String toString() { StringBuffer text = new StringBuffer(); int printed = 0; if ((m_alpha == null) && (m_sparseWeights == null)) { return "SMOreg : No model built yet."; } try { text.append("SMOreg\n\n"); text.append("Kernel used:\n " + m_kernel.toString() + "\n\n"); // display the linear transformation String trans = ""; if (m_filterType == FILTER_STANDARDIZE) { //text.append("LINEAR TRANSFORMATION APPLIED : \n"); trans = "(standardized) "; //text.append(trans + m_data.classAttribute().name() + " = " + // m_Alin + " * " + m_data.classAttribute().name() + " + " + m_Blin + "\n\n"); } else if (m_filterType == FILTER_NORMALIZE) { //text.append("LINEAR TRANSFORMATION APPLIED : \n"); trans = "(normalized) "; //text.append(trans + m_data.classAttribute().name() + " = " + // m_Alin + " * " + m_data.classAttribute().name() + " + " + m_Blin + "\n\n"); } // If machine linear, print weight vector if (m_KernelIsLinear) { text.append("Machine Linear: showing attribute weights, "); text.append("not support vectors.\n"); // We can assume that the weight vector is stored in sparse // format because the classifier has been built text.append(trans + m_data.classAttribute().name() + " =\n"); for (int i = 0; i < m_sparseWeights.length; i++) { if (m_sparseIndices[i] != (int)m_classIndex) { if (printed > 0) { text.append(" + "); } else { text.append(" "); } text.append(Utils.doubleToString(m_sparseWeights[i], 12, 4) + " * "); if (m_filterType == FILTER_STANDARDIZE) { text.append("(standardized) "); } else if (m_filterType == FILTER_NORMALIZE) { text.append("(normalized) "); } if (!m_checksTurnedOff) { text.append(m_data.attribute(m_sparseIndices[i]).name()+"\n"); } else { text.append("attribute with index " + m_sparseIndices[i] +"\n"); } printed++; } } } else { text.append("Support Vector Expansion :\n"); text.append(trans + m_data.classAttribute().name() + " =\n"); printed = 0; for (int i = 0; i < m_alpha.length; i++) { double val = m_alpha[i] - m_alpha_[i]; if (java.lang.Math.abs(val) < 1e-4) continue; if (printed > 0) { text.append(" + "); } else { text.append(" "); } text.append(Utils.doubleToString(val, 12, 4) + " * K[X(" + i + "), X]\n"); printed++; } } if (m_b > 0) { text.append(" + " + Utils.doubleToString(m_b, 12, 4)); } else { text.append(" - " + Utils.doubleToString(-m_b, 12, 4)); } if (!m_KernelIsLinear) { text.append("\n\nNumber of support vectors: " + printed); } int numEval = 0; int numCacheHits = -1; if(m_kernel != null) { numEval = m_kernel.numEvals(); numCacheHits = m_kernel.numCacheHits(); } text.append("\n\nNumber of kernel evaluations: " + numEval); if (numCacheHits >= 0 && numEval > 0) { double hitRatio = 1 - numEval/(numCacheHits+numEval); text.append(" (" + Utils.doubleToString(hitRatio*100, 7, 3) + "% cached)"); } } catch (Exception e) { return "Can't print the classifier."; } return text.toString(); } /** * Main method for testing this class. * * @param argv the commandline options */ public static void main(String[] argv) { Classifier scheme; try { scheme = new SMOreg(); System.out.println(Evaluation.evaluateModel(scheme, argv)); } catch (Exception e) { e.printStackTrace(); System.err.println(e.getMessage()); } } /** * Debuggage function. * Compute the value of the objective function. * * @return the value of the objective function * @throws Exception if computation fails */ protected double objFun() throws Exception { double res = 0; double t = 0, t2 = 0; for(int i = 0; i < m_alpha.length; i++){ for(int j = 0; j < m_alpha.length; j++){ t += (m_alpha[i] - m_alpha_[i]) * (m_alpha[j] - m_alpha_[j]) * m_kernel.eval(i,j,m_data.instance(i)); } t2 += m_data.instance(i).classValue() * (m_alpha[i] - m_alpha_[i]) - m_epsilon * (m_alpha[i] + m_alpha_[i]); } res += -0.5 * t + t2; return res; } /** * Debuggage function. * Compute the value of the objective function. * * @param i1 * @param i2 * @param alpha1 * @param alpha1_ * @param alpha2 * @param alpha2_ * @throws Exception if something goes wrong */ protected double objFun(int i1, int i2, double alpha1, double alpha1_, double alpha2, double alpha2_) throws Exception { double res = 0; double t = 0, t2 = 0; for(int i = 0; i < m_alpha.length; i++){ double alphai; double alphai_; if(i == i1){ alphai = alpha1; alphai_ = alpha1_; } else if(i == i2){ alphai = alpha2; alphai_ = alpha2_; } else { alphai = m_alpha[i]; alphai_ = m_alpha_[i]; } for(int j = 0; j < m_alpha.length; j++){ double alphaj; double alphaj_; if(j == i1){ alphaj = alpha1; alphaj_ = alpha1_; } else if(j == i2){ alphaj = alpha2; alphaj_ = alpha2_; } else { alphaj = m_alpha[j]; alphaj_ = m_alpha_[j]; } t += (alphai - alphai_) * (alphaj - alphaj_) * m_kernel.eval(i,j,m_data.instance(i)); } t2 += m_data.instance(i).classValue() * (alphai - alphai_) - m_epsilon * (alphai + alphai_); } res += -0.5 * t + t2; return res; } /** * Debuggage function. * Check that the set I0, I1, I2 and I3 cover the whole set of index * and that no attribute appears in two different sets. * * @throws Exception if check fails */ protected void checkSets() throws Exception{ boolean[] test = new boolean[m_data.numInstances()]; for (int i = m_I0.getNext(-1); i != -1; i = m_I0.getNext(i)) { if(test[i]){ throw new Exception("Fatal error! indice " + i + " appears in two different sets."); } else { test[i] = true; } if( !((0 < m_alpha[i] && m_alpha[i] < m_C * m_data.instance(i).weight()) || (0 < m_alpha_[i] && m_alpha_[i] < m_C * m_data.instance(i).weight())) ){ throw new Exception("Warning! I0 contains an incorrect indice."); } } for (int i = m_I1.getNext(-1); i != -1; i = m_I1.getNext(i)) { if(test[i]){ throw new Exception("Fatal error! indice " + i + " appears in two different sets."); } else { test[i] = true; } if( !( m_alpha[i] == 0 && m_alpha_[i] == 0) ){ throw new Exception("Fatal error! I1 contains an incorrect indice."); } } for (int i = m_I2.getNext(-1); i != -1; i = m_I2.getNext(i)) { if(test[i]){ throw new Exception("Fatal error! indice " + i + " appears in two different sets."); } else { test[i] = true; } if( !(m_alpha[i] == 0 && m_alpha_[i] == m_C * m_data.instance(i).weight()) ){ throw new Exception("Fatal error! I2 contains an incorrect indice."); } } for (int i = m_I3.getNext(-1); i != -1; i = m_I3.getNext(i)) { if(test[i]){ throw new Exception("Fatal error! indice " + i + " appears in two different sets."); } else { test[i] = true; } if( !(m_alpha_[i] == 0 && m_alpha[i] == m_C * m_data.instance(i).weight()) ){ throw new Exception("Fatal error! I3 contains an incorrect indice."); } } for (int i = 0; i < test.length; i++){ if(!test[i]){ throw new Exception("Fatal error! indice " + i + " doesn't belong to any set."); } } } /** * Debuggage function <br/> * Checks that : <br/> * alpha*alpha_=0 <br/> * sum(alpha[i] - alpha_[i]) = 0 * * @throws Exception if check fails */ protected void checkAlphas() throws Exception{ double sum = 0; for(int i = 0; i < m_alpha.length; i++){ if(!(0 == m_alpha[i] || m_alpha_[i] == 0)){ throw new Exception("Fatal error! Inconsistent alphas!"); } sum += (m_alpha[i] - m_alpha_[i]); } if(sum > 1e-10){ throw new Exception("Fatal error! Inconsistent alphas' sum = " + sum); } } /** * Debuggage function. * Display the current status of the program. * @param i1 the first current indice * @param i2 the second current indice * * @throws Exception if printing of current status fails */ protected void displayStat(int i1, int i2) throws Exception { System.err.println("\n-------- Status : ---------"); System.err.println("\n i, alpha, alpha'\n"); for(int i = 0; i < m_alpha.length; i++){ double result = (m_bLow + m_bUp)/2.0; for (int j = 0; j < m_alpha.length; j++) { result += (m_alpha[j] - m_alpha_[j]) * m_kernel.eval(i, j, m_data.instance(i)); } System.err.print(" " + i + ": (" + m_alpha[i] + ", " + m_alpha_[i] + "), " + (m_data.instance(i).classValue() - m_epsilon) + " <= " + result + " <= " + (m_data.instance(i).classValue() + m_epsilon)); if(i == i1){ System.err.print(" <-- i1"); } if(i == i2){ System.err.print(" <-- i2"); } System.err.println(); } System.err.println("bLow = " + m_bLow + " bUp = " + m_bUp); System.err.println("---------------------------\n"); } /** * Debuggage function * Compute and display bLow, lUp and so on... * * @throws Exception if display fails */ protected void displayB() throws Exception { //double bUp = Double.NEGATIVE_INFINITY; //double bLow = Double.POSITIVE_INFINITY; //int iUp = -1, iLow = -1; for(int i = 0; i < m_data.numInstances(); i++){ double Fi = m_data.instance(i).classValue(); for(int j = 0; j < m_alpha.length; j++){ Fi -= (m_alpha[j] - m_alpha_[j]) * m_kernel.eval(i, j, m_data.instance(i)); } System.err.print("(" + m_alpha[i] + ", " + m_alpha_[i] + ") : "); System.err.print((Fi - m_epsilon) + ", " + (Fi + m_epsilon)); double fim = Fi - m_epsilon, fip = Fi + m_epsilon; String s = ""; if (m_I0.contains(i)){ if ( 0 < m_alpha[i] && m_alpha[i] < m_C * m_data.instance(i).weight()){ s += "(in I0a) bUp = min(bUp, " + fim + ") bLow = max(bLow, " + fim + ")"; } if ( 0 < m_alpha_[i] && m_alpha_[i] < m_C * m_data.instance(i).weight()){ s += "(in I0a) bUp = min(bUp, " + fip + ") bLow = max(bLow, " + fip + ")"; } } if (m_I1.contains(i)){ s += "(in I1) bUp = min(bUp, " + fip + ") bLow = max(bLow, " + fim + ")"; } if (m_I2.contains(i)){ s += "(in I2) bLow = max(bLow, " + fip + ")"; } if (m_I3.contains(i)){ s += "(in I3) bUp = min(bUp, " + fim + ")"; } System.err.println(" " + s + " {" + (m_alpha[i]-1) + ", " + (m_alpha_[i]-1) + "}"); } System.err.println("\n\n"); } /** * Debuggage function. * Checks if the equations (6), (8a), (8b), (8c), (8d) hold. * (Refers to "Improvements to SMO Algorithm for SVM Regression".) * Prints warnings for each equation which doesn't hold. * * @throws Exception if check fails */ protected void checkOptimality() throws Exception { double bUp = Double.POSITIVE_INFINITY; double bLow = Double.NEGATIVE_INFINITY; int iUp = -1, iLow = -1; for(int i = 0; i < m_data.numInstances(); i++){ double Fi = m_data.instance(i).classValue(); for(int j = 0; j < m_alpha.length; j++){ Fi -= (m_alpha[j] - m_alpha_[j]) * m_kernel.eval(i, j, m_data.instance(i)); } double fitilde = 0, fibarre = 0; if(m_I0.contains(i) && 0 < m_alpha[i] && m_alpha[i] < m_C * m_data.instance(i).weight()){ fitilde = Fi - m_epsilon; fibarre = Fi - m_epsilon; } if(m_I0.contains(i) && 0 < m_alpha_[i] && m_alpha_[i] < m_C * m_data.instance(i).weight()){ fitilde = Fi + m_epsilon; fibarre = Fi + m_epsilon; } if(m_I1.contains(i)){ fitilde = Fi - m_epsilon; fibarre = Fi + m_epsilon; } if(m_I2.contains(i)){ fitilde = Fi + m_epsilon; fibarre = Double.POSITIVE_INFINITY; } if(m_I3.contains(i)){ fitilde = Double.NEGATIVE_INFINITY; fibarre = Fi - m_epsilon; } if(fibarre < bUp){ bUp = fibarre; iUp = i; } if(fitilde > bLow){ bLow = fitilde; iLow = i; } } if(!(bLow <= bUp + 2 * m_tol)){ System.err.println("Warning! Optimality not reached : inequation (6) doesn't hold!"); } boolean noPb = true; for(int i = 0; i < m_data.numInstances(); i++){ double Fi = m_data.instance(i).classValue(); for(int j = 0; j < m_alpha.length; j++){ Fi -= (m_alpha[j] - m_alpha_[j]) * m_kernel.eval(i, j, m_data.instance(i)); } double Ei = Fi - ((m_bUp + m_bLow) / 2.0); if((m_alpha[i] > 0) && !(Ei >= m_epsilon - m_tol)){ System.err.println("Warning! Optimality not reached : inequation (8a) doesn't hold for " + i); noPb = false; } if((m_alpha[i] < m_C * m_data.instance(i).weight()) && !(Ei <= m_epsilon + m_tol)){ System.err.println("Warning! Optimality not reached : inequation (8b) doesn't hold for " + i); noPb = false; } if((m_alpha_[i] > 0) && !(Ei <= -m_epsilon + m_tol)){ System.err.println("Warning! Optimality not reached : inequation (8c) doesn't hold for " + i); noPb = false; } if((m_alpha_[i] < m_C * m_data.instance(i).weight()) && !(Ei >= -m_epsilon - m_tol)){ System.err.println("Warning! Optimality not reached : inequation (8d) doesn't hold for " + i); noPb = false; } } if(!noPb){ System.err.println(); //displayStat(-1,-1); //displayB(); } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -