⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 smo.java

📁 数据挖掘classifiers算法
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
   * @param v  Value to assign to lowerOrder.   */  public void setLowerOrderTerms(boolean v) {        if (m_exponent == 1.0) {      m_lowerOrder = false;    } else {      m_lowerOrder = v;    }  }  /**   * Computes the result of the kernel function for two instances.   *   * @param id1 the index of the first instance   * @param id2 the index of the second instance   * @param inst the instance corresponding to id1   * @return the result of the kernel function   */  private double kernel(int id1, int id2, Instance inst1) throws Exception {    double result = 0;    long key = -1;    int location = -1;    // we can only cache if we know the indexes    if (id1 >= 0) {      if (id1 > id2) {	key = (long)id1 * m_alpha.length + id2;      } else {	key = (long)id2 * m_alpha.length + id1;      }      if (key < 0) {	throw new Exception("Cache overflow detected!");      }      location = (int)(key % m_keys.length);      if (m_keys[location] == (key + 1)) {	return m_storage[location];      }    }	    // we can do a fast dot product    Instance inst2 = m_data.instance(id2);    int n1 = inst1.numValues(); int n2 = inst2.numValues();    int classIndex = m_data.classIndex();    for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) {      int ind1 = inst1.index(p1);       int ind2 = inst2.index(p2);      if (ind1 == ind2) {	if (ind1 != classIndex) {	  result += inst1.valueSparse(p1) * inst2.valueSparse(p2);	}	p1++; p2++;      } else if (ind1 > ind2) {	p2++;      } else { 	p1++;      }    }        // Use lower order terms?    if (m_lowerOrder) {      result += 1.0;    }    // Rescale kernel?    if (m_rescale) {      result /= (double)m_data.numAttributes() - 1;    }              if (m_exponent != 1.0) {      result = Math.pow(result, m_exponent);    }    m_kernelEvals++;        // store result in cache 	    if (key != -1){      m_storage[location] = result;      m_keys[location] = (key + 1);    }    return result;  }  /**   * Examines instance.   *   * @param i2 index of instance to examine   * @return true if examination was successfull   * @exception Exception if something goes wrong   */  private boolean examineExample(int i2) throws Exception {        double y2, alph2, F2;    int i1 = -1;        y2 = m_class[i2];    alph2 = m_alpha[i2];    if (m_I0.contains(i2)) {      F2 = m_errors[i2];    } else {      F2 = SVMOutput(i2, m_data.instance(i2)) + m_b - y2;      m_errors[i2] = F2;            // Update thresholds      if ((m_I1.contains(i2) || m_I2.contains(i2)) && (F2 < m_bUp)) {	m_bUp = F2; m_iUp = i2;      } else if ((m_I3.contains(i2) || m_I4.contains(i2)) && (F2 > m_bLow)) {	m_bLow = F2; m_iLow = i2;      }    }    // Check optimality using current bLow and bUp and, if    // violated, find an index i1 to do joint optimization    // with i2...    boolean optimal = true;    if (m_I0.contains(i2) || m_I1.contains(i2) || m_I2.contains(i2)) {      if (m_bLow - F2 > 2 * m_tol) {	optimal = false; i1 = m_iLow;      }    }    if (m_I0.contains(i2) || m_I3.contains(i2) || m_I4.contains(i2)) {      if (F2 - m_bUp > 2 * m_tol) {	optimal = false; i1 = m_iUp;      }    }    if (optimal) {      return false;    }    // For i2 unbound choose the better i1...    if (m_I0.contains(i2)) {      if (m_bLow - F2 > F2 - m_bUp) {	i1 = m_iLow;      } else {	i1 = m_iUp;      }    }    if (i1 == -1) {      throw new Exception("This should never happen!");    }    return takeStep(i1, i2, F2);  }  /**   * Method solving for the Lagrange multipliers for   * two instances.   *   * @param i1 index of the first instance   * @param i2 index of the second instance   * @return true if multipliers could be found   * @exception Exception if something goes wrong   */  private boolean takeStep(int i1, int i2, double F2) throws Exception {    double alph1, alph2, y1, y2, F1, s, L, H, k11, k12, k22, eta,      a1, a2, f1, f2, v1, v2, Lobj, Hobj, b1, b2, bOld;    // Don't do anything if the two instances are the same    if (i1 == i2) {      return false;    }    // Initialize variables    alph1 = m_alpha[i1]; alph2 = m_alpha[i2];    y1 = m_class[i1]; y2 = m_class[i2];    F1 = m_errors[i1];    s = y1 * y2;    // Find the constraints on a2    if (y1 != y2) {      L = Math.max(0, alph2 - alph1);       H = Math.min(m_C, m_C + alph2 - alph1);    } else {      L = Math.max(0, alph1 + alph2 - m_C);      H = Math.min(m_C, alph1 + alph2);    }    if (L >= H) {             return false;    }    // Compute second derivative of objective function    k11 = kernel(i1, i1, m_data.instance(i1));    k12 = kernel(i1, i2, m_data.instance(i1));    k22 = kernel(i2, i2, m_data.instance(i2));    eta = 2 * k12 - k11 - k22;    // Check if second derivative is negative    if (eta < 0) {      // Compute unconstrained maximum      a2 = alph2 - y2 * (F1 - F2) / eta;      // Compute constrained maximum      if (a2 < L) {	a2 = L;      } else if (a2 > H) {	a2 = H;      }    } else {      // Look at endpoints of diagonal      f1 = SVMOutput(i1, m_data.instance(i1));      f2 = SVMOutput(i2, m_data.instance(i2));      v1 = f1 + m_b - y1 * alph1 * k11 - y2 * alph2 * k12;       v2 = f2 + m_b - y1 * alph1 * k12 - y2 * alph2 * k22;       double gamma = alph1 + s * alph2;      Lobj = (gamma - s * L) + L - 0.5 * k11 * (gamma - s * L) * (gamma - s * L) - 	0.5 * k22 * L * L - s * k12 * (gamma - s * L) * L - 	y1 * (gamma - s * L) * v1 - y2 * L * v2;      Hobj = (gamma - s * H) + H - 0.5 * k11 * (gamma - s * H) * (gamma - s * H) - 	0.5 * k22 * H * H - s * k12 * (gamma - s * H) * H - 	y1 * (gamma - s * H) * v1 - y2 * H * v2;      if (Lobj > Hobj + m_eps) {	a2 = L;      } else if (Lobj < Hobj - m_eps) {	a2 = H;      } else {	a2 = alph2;      }    }    if (Math.abs(a2 - alph2) < m_eps * (a2 + alph2 + m_eps)) {      return false;    }    // To prevent precision problems    if (a2 > m_C - m_Del * m_C) {      a2 = m_C;    } else if (a2 <= m_Del * m_C) {      a2 = 0;    }    // Recompute a1    a1 = alph1 + s * (alph2 - a2);    // To prevent precision problems    if (a1 > m_C - m_Del * m_C) {      a1 = m_C;    } else if (a1 <= m_Del * m_C) {      a1 = 0;    }    // Update sets    if (a1 > 0) {      m_supportVectors.insert(i1);    } else {      m_supportVectors.delete(i1);    }    if ((a1 > 0) && (a1 < m_C)) {      m_I0.insert(i1);    } else {      m_I0.delete(i1);    }    if ((y1 == 1) && (a1 == 0)) {      m_I1.insert(i1);    } else {      m_I1.delete(i1);    }    if ((y1 == -1) && (a1 == m_C)) {      m_I2.insert(i1);    } else {      m_I2.delete(i1);    }    if ((y1 == 1) && (a1 == m_C)) {      m_I3.insert(i1);    } else {      m_I3.delete(i1);    }    if ((y1 == -1) && (a1 == 0)) {      m_I4.insert(i1);    } else {      m_I4.delete(i1);    }    if (a2 > 0) {      m_supportVectors.insert(i2);    } else {      m_supportVectors.delete(i2);    }    if ((a2 > 0) && (a2 < m_C)) {      m_I0.insert(i2);    } else {      m_I0.delete(i2);    }    if ((y2 == 1) && (a2 == 0)) {      m_I1.insert(i2);    } else {      m_I1.delete(i2);    }    if ((y2 == -1) && (a2 == m_C)) {      m_I2.insert(i2);    } else {      m_I2.delete(i2);    }    if ((y2 == 1) && (a2 == m_C)) {      m_I3.insert(i2);    } else {      m_I3.delete(i2);    }    if ((y2 == -1) && (a2 == 0)) {      m_I4.insert(i2);    } else {      m_I4.delete(i2);    }        // Update weight vector to reflect change a1 and a2, if linear SVM    if (m_exponent == 1.0) {      Instance inst1 = m_data.instance(i1);      for (int p1 = 0; p1 < inst1.numValues(); p1++) {	if (inst1.index(p1) != m_data.classIndex()) {	  m_weights[inst1.index(p1)] += 	    y1 * (a1 - alph1) * inst1.valueSparse(p1);	}      }      Instance inst2 = m_data.instance(i2);      for (int p2 = 0; p2 < inst2.numValues(); p2++) {	if (inst2.index(p2) != m_data.classIndex()) {	  m_weights[inst2.index(p2)] += 	    y2 * (a2 - alph2) * inst2.valueSparse(p2);	}      }    }    // Update error cache using new Lagrange multipliers    for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) {      if ((j != i1) && (j != i2)) {	m_errors[j] += 	  y1 * (a1 - alph1) * kernel(i1, j, m_data.instance(i1)) + 	  y2 * (a2 - alph2) * kernel(i2, j, m_data.instance(i2));      }    }    // Update error cache for i1 and i2    m_errors[i1] += y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12;    m_errors[i2] += y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22;    // Update array with Lagrange multipliers    m_alpha[i1] = a1;    m_alpha[i2] = a2;          // Update thresholds    m_bLow = -Double.MAX_VALUE; m_bUp = Double.MAX_VALUE;    m_iLow = -1; m_iUp = -1;    for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) {      if (m_errors[j] < m_bUp) {	m_bUp = m_errors[j]; m_iUp = j;      }      if (m_errors[j] > m_bLow) {	m_bLow = m_errors[j]; m_iLow = j;      }    }    if (!m_I0.contains(i1)) {      if (m_I3.contains(i1) || m_I4.contains(i1)) {	if (m_errors[i1] > m_bLow) {	  m_bLow = m_errors[i1]; m_iLow = i1;	}       } else {	if (m_errors[i1] < m_bUp) {	  m_bUp = m_errors[i1]; m_iUp = i1;	}      }    }    if (!m_I0.contains(i2)) {      if (m_I3.contains(i2) || m_I4.contains(i2)) {	if (m_errors[i2] > m_bLow) {	  m_bLow = m_errors[i2]; m_iLow = i2;	}      } else {	if (m_errors[i2] < m_bUp) {	  m_bUp = m_errors[i2]; m_iUp = i2;	}      }    }    if ((m_iLow == -1) || (m_iUp == -1)) {      throw new Exception("This should never happen!");    }    // Made some progress.    return true;  }    /**   * Quick and dirty check whether the quadratic programming problem is solved.   */  private void checkClassifier() throws Exception {    double sum = 0;    for (int i = 0; i < m_alpha.length; i++) {      if (m_alpha[i] > 0) {	sum += m_class[i] * m_alpha[i];      }    }    System.err.println("Sum of y(i) * alpha(i): " + sum);    for (int i = 0; i < m_alpha.length; i++) {      double output = SVMOutput(i, m_data.instance(i));      if (Utils.eq(m_alpha[i], 0)) {	if (Utils.sm(m_class[i] * output, 1)) {	  System.err.println("KKT condition 1 violated: " + m_class[i] * output);	}      }       if (Utils.gr(m_alpha[i], 0) && Utils.sm(m_alpha[i], m_C)) {	if (!Utils.eq(m_class[i] * output, 1)) {	  System.err.println("KKT condition 2 violated: " + m_class[i] * output);	}      }       if (Utils.eq(m_alpha[i], m_C)) {	if (Utils.gr(m_class[i] * output, 1)) {	  System.err.println("KKT condition 3 violated: " + m_class[i] * output);	}      }     }  }    /**   * Main method for testing this class.   */  public static void main(String[] argv) {    Classifier scheme;    try {      scheme = new SMO();      System.out.println(Evaluation.evaluateModel(scheme, argv));    } catch (Exception e) {      System.err.println(e.getMessage());    }  }}    

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -