⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 smo.java

📁 MacroWeka扩展了著名数据挖掘工具weka
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
	if (examineAll) {
	  examineAll = false;
	} else if (numChanged == 0) {
	  examineAll = true;
	}
      }
      
      // Set threshold
      m_b = (m_bLow + m_bUp) / 2.0;
      
      // Save memory
      m_kernel.clean(); 
      
      m_errors = null;
      m_I0 = m_I1 = m_I2 = m_I3 = m_I4 = null;
      
      // If machine is linear, delete training data
      // and store weight vector in sparse format
      if (!m_useRBF && m_exponent == 1.0) {
	
	// We don't need to store the set of support vectors
	m_supportVectors = null;

	// We don't need to store the class values either
	m_class = null;
	
	// Clean out training data
	if (!m_checksTurnedOff) {
	  m_data = new Instances(m_data, 0);
	} else {
	  m_data = null;
	}
	
	// Convert weight vector
	double[] sparseWeights = new double[m_weights.length];
	int[] sparseIndices = new int[m_weights.length];
	int counter = 0;
	for (int i = 0; i < m_weights.length; i++) {
	  if (m_weights[i] != 0.0) {
	    sparseWeights[counter] = m_weights[i];
	    sparseIndices[counter] = i;
	    counter++;
	  }
	}
	m_sparseWeights = new double[counter];
	m_sparseIndices = new int[counter];
	System.arraycopy(sparseWeights, 0, m_sparseWeights, 0, counter);
	System.arraycopy(sparseIndices, 0, m_sparseIndices, 0, counter);
	
	// Clean out weight vector
	m_weights = null;
	
	// We don't need the alphas in the linear case
	m_alpha = null;
      }
      
      // Fit sigmoid if requested
      if (fitLogistic) {
	fitLogistic(insts, cl1, cl2, numFolds, new Random(randomSeed));
      }

    }
    
    /**
     * Evaluates the SVM decision function for a single instance.
     *
     * <p>A linear machine (no RBF kernel and exponent 1.0) takes the dot
     * product of the instance with its weight vector. The weights are stored
     * either densely in m_weights or sparsely in m_sparseWeights /
     * m_sparseIndices. In both linear cases the class attribute is skipped.
     * Any other machine sums class * alpha * kernel over the current support
     * vectors. In all cases the threshold m_b is subtracted at the end.
     *
     * @param index index of the instance, handed through to the kernel
     * @param inst the instance to evaluate
     * @return the output of the SVM for the given instance
     * @exception Exception propagated from the kernel evaluation
     */
    protected double SVMOutput(int index, Instance inst) throws Exception {

      double sum = 0;
      boolean linear = !m_useRBF && m_exponent == 1.0;

      if (!linear) {
        // Kernel expansion over the support vectors
        for (int sv = m_supportVectors.getNext(-1); sv != -1;
             sv = m_supportVectors.getNext(sv)) {
          sum += m_class[sv] * m_alpha[sv] * m_kernel.eval(index, sv, inst);
        }
      } else if (m_sparseWeights != null) {
        // Merge-style walk over two index lists. m_sparseIndices is ascending;
        // inst's sparse indices are assumed ascending too.
        int numInstVals = inst.numValues();
        int numWeights = m_sparseWeights.length;
        int pos = 0;
        int wpos = 0;
        while (pos < numInstVals && wpos < numWeights) {
          int attInst = inst.index(pos);
          int attWeight = m_sparseIndices[wpos];
          if (attInst < attWeight) {
            pos++;
          } else if (attInst > attWeight) {
            wpos++;
          } else {
            if (attInst != m_classIndex) {
              sum += inst.valueSparse(pos) * m_sparseWeights[wpos];
            }
            pos++;
            wpos++;
          }
        }
      } else {
        // Dense weight vector, indexed directly by attribute index
        int numInstVals = inst.numValues();
        for (int k = 0; k < numInstVals; k++) {
          int att = inst.index(k);
          if (att != m_classIndex) {
            sum += m_weights[att] * inst.valueSparse(k);
          }
        }
      }

      return sum - m_b;
    }

    /**
     * Prints out the classifier.
     *
     * <p>For a linear machine (no RBF kernel and exponent 1.0), the sparse
     * attribute weights are listed. Otherwise every support vector is listed
     * with its multiplier and its non-class attribute values. The threshold
     * follows. The number of support vectors is added for non-linear machines
     * only. Kernel-evaluation statistics come last.
     *
     * @return a description of the classifier as a string, the message
     *         "BinarySMO: No model built yet." if there is no model, or
     *         "Can't print BinarySMO classifier." if printing fails
     */
    public String toString() {

      StringBuffer text = new StringBuffer();
      int printed = 0;

      if ((m_alpha == null) && (m_sparseWeights == null)) {
        return "BinarySMO: No model built yet.\n";
      }
      try {
        text.append("BinarySMO\n\n");

        // If machine linear, print weight vector
        if (!m_useRBF && m_exponent == 1.0) {
          text.append("Machine linear: showing attribute weights, ");
          text.append("not support vectors.\n\n");

          // For linear machines, training ends by storing the weights in
          // sparse form. If that has not happened yet, m_sparseWeights is
          // null and the exception is caught below.
          for (int i = 0; i < m_sparseWeights.length; i++) {
            if (m_sparseIndices[i] != (int)m_classIndex) {
              if (printed > 0) {
                text.append(" + ");
              } else {
                text.append("   ");
              }
              text.append(Utils.doubleToString(m_sparseWeights[i], 12, 4) +
                          " * ");
              if (m_filterType == FILTER_STANDARDIZE) {
                text.append("(standardized) ");
              } else if (m_filterType == FILTER_NORMALIZE) {
                text.append("(normalized) ");
              }
              if (!m_checksTurnedOff) {
                text.append(m_data.attribute(m_sparseIndices[i]).name()+"\n");
              } else {
                // Training data (and hence attribute names) were discarded
                text.append("attribute with index " +
                            m_sparseIndices[i] +"\n");
              }
              printed++;
            }
          }
        } else {
          for (int i = 0; i < m_alpha.length; i++) {
            if (m_supportVectors.contains(i)) {
              double val = m_alpha[i];
              if (m_class[i] == 1) {
                if (printed > 0) {
                  text.append(" + ");
                }
              } else {
                text.append(" - ");
              }
              text.append(Utils.doubleToString(val, 12, 4)
                          + " * <");
              // Separate printed values by exactly one space. Skipping the
              // class value must not leave a doubled, leading or trailing
              // separator.
              boolean firstValue = true;
              for (int j = 0; j < m_data.numAttributes(); j++) {
                if (j == m_data.classIndex()) {
                  continue;
                }
                if (!firstValue) {
                  text.append(" ");
                }
                text.append(m_data.instance(i).toString(j));
                firstValue = false;
              }
              // NOTE(review): the ']' has no matching '['; this is kept
              // unchanged for output compatibility.
              text.append("> * X]\n");
              printed++;
            }
          }
        }
        // SVMOutput subtracts m_b, so print it with the matching sign
        if (m_b > 0) {
          text.append(" - " + Utils.doubleToString(m_b, 12, 4));
        } else {
          text.append(" + " + Utils.doubleToString(-m_b, 12, 4));
        }

        if (m_useRBF || m_exponent != 1.0) {
          text.append("\n\nNumber of support vectors: " +
                      m_supportVectors.numElements());
        }
        int numEval = 0;
        int numCacheHits = -1;
        if (m_kernel != null) {
          numEval = m_kernel.numEvals();
          numCacheHits = m_kernel.numCacheHits();
        }
        text.append("\n\nNumber of kernel evaluations: " + numEval);
        if (numCacheHits >= 0 && numEval > 0) {
          // Fraction of kernel lookups served from the cache
          double hitRatio = 1 - numEval*1.0/(numCacheHits+numEval);
          text.append(" (" + Utils.doubleToString(hitRatio*100, 7, 3).trim() + "% cached)");
        }

      } catch (Exception e) {
        e.printStackTrace();

        return "Can't print BinarySMO classifier.";
      }

      return text.toString();
    }

    /**
     * Examines one training instance. If it violates the optimality
     * conditions, this method optimizes it jointly with a partner instance.
     *
     * <p>The instance's error F2 is taken from m_errors if the instance is in
     * I0. Otherwise it is recomputed, cached, and used to tighten m_bUp or
     * m_bLow. The instance is considered optimal when it is within 2 * m_tol
     * of the relevant threshold(s). Otherwise a partner i1 is chosen from
     * m_iLow / m_iUp and takeStep is called.
     *
     * @param i2 index of instance to examine
     * @return false if the instance already satisfies the optimality
     *         conditions; otherwise the result of takeStep(i1, i2, F2)
     * @exception Exception if no partner index can be determined, or if
     *         SVMOutput / takeStep throws
     */
    protected boolean examineExample(int i2) throws Exception {

      double y2 = m_class[i2];
      double F2;
      int i1 = -1;

      // None of the index sets is modified before takeStep, so membership
      // can be looked up once.
      boolean inI0 = m_I0.contains(i2);
      boolean inI1 = m_I1.contains(i2);
      boolean inI2 = m_I2.contains(i2);
      boolean inI3 = m_I3.contains(i2);
      boolean inI4 = m_I4.contains(i2);

      if (inI0) {
        // Use the cached error for I0 members (presumably kept current by
        // takeStep).
        F2 = m_errors[i2];
      } else {
        // SVMOutput subtracts m_b, so add it back before taking the error
        F2 = SVMOutput(i2, m_data.instance(i2)) + m_b - y2;
        m_errors[i2] = F2;

        // Update thresholds
        if ((inI1 || inI2) && (F2 < m_bUp)) {
          m_bUp = F2; m_iUp = i2;
        } else if ((inI3 || inI4) && (F2 > m_bLow)) {
          m_bLow = F2; m_iLow = i2;
        }
      }

      // Check optimality using current bLow and bUp and, if
      // violated, find an index i1 to do joint optimization
      // with i2...
      boolean optimal = true;
      if ((inI0 || inI1 || inI2) && (m_bLow - F2 > 2 * m_tol)) {
        optimal = false; i1 = m_iLow;
      }
      if ((inI0 || inI3 || inI4) && (F2 - m_bUp > 2 * m_tol)) {
        optimal = false; i1 = m_iUp;
      }
      if (optimal) {
        return false;
      }

      // For i2 unbound, pick the partner on the side with the larger gap
      if (inI0) {
        if (m_bLow - F2 > F2 - m_bUp) {
          i1 = m_iLow;
        } else {
          i1 = m_iUp;
        }
      }
      if (i1 == -1) {
        throw new Exception("This should never happen!");
      }
      return takeStep(i1, i2, F2);
    }

    /**
     * Method solving for the Lagrange multipliers for
     * two instances.
     *
     * @param i1 index of the first instance
     * @param i2 index of the second instance
     * @return true if multipliers could be found
     * @exception Exception if something goes wrong
     */
    protected boolean takeStep(int i1, int i2, double F2) throws Exception {

      double alph1, alph2, y1, y2, F1, s, L, H, k11, k12, k22, eta,
	a1, a2, f1, f2, v1, v2, Lobj, Hobj, b1, b2, bOld;
      double C1 = m_C * m_data.instance(i1).weight();
      double C2 = m_C * m_data.instance(i2).weight();

      // Don't do anything if the two instances are the same
      if (i1 == i2) {
	return false;
      }

      // Initialize variables
      alph1 = m_alpha[i1]; alph2 = m_alpha[i2];
      y1 = m_class[i1]; y2 = m_class[i2];
      F1 = m_errors[i1];
      s = y1 * y2;

      // Find the constraints on a2
      if (y1 != y2) {
	L = Math.max(0, alph2 - alph1); 
	H = Math.min(C2, C1 + alph2 - alph1);
      } else {
	L = Math.max(0, alph1 + alph2 - C1);
	H = Math.min(C2, alph1 + alph2);
      }
      if (L >= H) {
	return false;
      }

      // Compute second derivative of objective function
      k11 = m_kernel.eval(i1, i1, m_data.instance(i1));
      k12 = m_kernel.eval(i1, i2, m_data.instance(i1));
      k22 = m_kernel.eval(i2, i2, m_data.instance(i2));
      eta = 2 * k12 - k11 - k22;

      // Check if second derivative is negative
      if (eta < 0) {

	// Compute unconstrained maximum
	a2 = alph2 - y2 * (F1 - F2) / eta;

	// Compute constrained maximum
	if (a2 < L) {
	  a2 = L;
	} else if (a2 > H) {
	  a2 = H;
	}
      } else {

	// Look at endpoints of diagonal
	f1 = SVMOutput(i1, m_data.instance(i1));
	f2 = SVMOutput(i2, m_data.instance(i2));
	v1 = f1 + m_b - y1 * alph1 * k11 - y2 * alph2 * k12; 
	v2 = f2 + m_b - y1 * alph1 * k12 - y2 * alph2 * k22; 
	double gamma = alph1 + s * alph2;
	Lobj = (gamma - s * L) + L - 0.5 * k11 * (gamma - s * L) * (gamma - s * L) - 
	  0.5 * k22 * L * L - s * k12 * (gamma - s * L) * L - 
	  y1 * (gamma - s * L) * v1 - y2 * L * v2;
	Hobj = (gamma - s * H) + H - 0.5 * k11 * (gamma - s * H) * (gamma - s * H) - 
	  0.5 * k22 * H * H - s * k12 * (gamma - s * H) * H - 
	  y1 * (gamma - s * H) * v1 - y2 * H * v2;
	if (Lobj > Hobj + m_eps) {
	  a2 = L;
	} else if (Lobj < Hobj - m_eps) {
	  a2 = H;
	} else {
	  a2 = alph2;
	}
      }
      if (Math.abs(a2 - alph2) < m_eps * (a2 + alph2 + m_eps)) {
	return false;
      }
      
      // To prevent precision problems
      if (a2 > C2 - m_Del * C2) {
	a2 = C2;
      } else if (a2 <= m_Del * C2) {
	a2 = 0;
      }
      
      // Recompute a1
      a1 = alph1 + s * (alph2 - a2);
      
      // To prevent precision problems
      if (a1 > C1 - m_Del * C1) {
	a1 = C1;
      } else if (a1 <= m_Del * C1) {
	a1 = 0;
      }
      
      // Update sets
      if (a1 > 0) {
	m_supportVectors.insert(i1);
      } else {
	m_supportVectors.delete(i1);
      }
      if ((a1 > 0) && (a1 < C1)) {
	m_I0.insert(i1);
      } else {
	m_I0.delete(i1);
      }
      if ((y1 == 1) && (a1 == 0)) {
	m_I1.insert(i1);
      } else {
	m_I1.delete(i1);
      }
      if ((y1 == -1) && (a1 == C1)) {
	m_I2.insert(i1);
      } else {
	m_I2.delete(i1);
      }
      if ((y1 == 1) && (a1 == C1)) {
	m_I3.insert(i1);
      } else {
	m_I3.delete(i1);
      }
      if ((y1 == -1) && (a1 == 0)) {
	m_I4.insert(i1);
      } else {
	m_I4.delete(i1);
      }
      if (a2 > 0) {
	m_supportVectors.insert(i2);
      } else {
	m_supportVectors.delete(i2);
      }
      if ((a2 > 0) && (a2 < C2)) {
	m_I0.insert(i2);
      } else {
	m_I0.delete(i2);
      }
      if ((y2 == 1) && (a2 == 0)) {
	m_I1.insert(i2);
      } else {
	m_I1.delete(i2);
      }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -