// smo.java (extracted fragment of weka's BinarySMO implementation)
if (examineAll) {
examineAll = false;
} else if (numChanged == 0) {
examineAll = true;
}
}
// Set threshold
m_b = (m_bLow + m_bUp) / 2.0;
// Save memory
m_kernel.clean();
m_errors = null;
m_I0 = m_I1 = m_I2 = m_I3 = m_I4 = null;
// If machine is linear, delete training data
// and store weight vector in sparse format
if (!m_useRBF && m_exponent == 1.0) {
// We don't need to store the set of support vectors
m_supportVectors = null;
// We don't need to store the class values either
m_class = null;
// Clean out training data
if (!m_checksTurnedOff) {
m_data = new Instances(m_data, 0);
} else {
m_data = null;
}
// Convert weight vector
double[] sparseWeights = new double[m_weights.length];
int[] sparseIndices = new int[m_weights.length];
int counter = 0;
for (int i = 0; i < m_weights.length; i++) {
if (m_weights[i] != 0.0) {
sparseWeights[counter] = m_weights[i];
sparseIndices[counter] = i;
counter++;
}
}
m_sparseWeights = new double[counter];
m_sparseIndices = new int[counter];
System.arraycopy(sparseWeights, 0, m_sparseWeights, 0, counter);
System.arraycopy(sparseIndices, 0, m_sparseIndices, 0, counter);
// Clean out weight vector
m_weights = null;
// We don't need the alphas in the linear case
m_alpha = null;
}
// Fit sigmoid if requested
if (fitLogistic) {
fitLogistic(insts, cl1, cl2, numFolds, new Random(randomSeed));
}
}
/**
 * Computes SVM output for given instance.
 *
 * For a linear machine the output is the (possibly sparse) weight
 * vector dotted with the instance; otherwise it is the kernel-weighted
 * sum over the support vectors. In both cases the threshold m_b is
 * subtracted at the end.
 *
 * @param index the instance for which output is to be computed
 * @param inst the instance
 * @return the output of the SVM for the given instance
 */
protected double SVMOutput(int index, Instance inst) throws Exception {

  double result = 0;

  if (!m_useRBF && m_exponent == 1.0) {
    // Linear machine.
    if (m_sparseWeights == null) {
      // Dense weight vector: walk the instance's non-zero values.
      int numVals = inst.numValues();
      for (int k = 0; k < numVals; k++) {
        int attIndex = inst.index(k);
        if (attIndex != m_classIndex) {
          result += m_weights[attIndex] * inst.valueSparse(k);
        }
      }
    } else {
      // Sparse weight vector: merge the two sorted index lists,
      // advancing whichever side has the smaller attribute index.
      int instLen = inst.numValues();
      int weightLen = m_sparseWeights.length;
      int pi = 0;
      int pw = 0;
      while (pi < instLen && pw < weightLen) {
        int instInd = inst.index(pi);
        int weightInd = m_sparseIndices[pw];
        if (instInd == weightInd) {
          if (instInd != m_classIndex) {
            result += inst.valueSparse(pi) * m_sparseWeights[pw];
          }
          pi++;
          pw++;
        } else if (instInd > weightInd) {
          pw++;
        } else {
          pi++;
        }
      }
    }
  } else {
    // Non-linear machine: kernel expansion over the support vectors.
    for (int i = m_supportVectors.getNext(-1); i != -1;
         i = m_supportVectors.getNext(i)) {
      result += m_class[i] * m_alpha[i] * m_kernel.eval(index, i, inst);
    }
  }
  result -= m_b;
  return result;
}
/**
 * Prints out the classifier.
 *
 * @return a description of the classifier as a string
 */
public String toString() {
StringBuffer text = new StringBuffer();
int printed = 0;
// A model exists iff either the alphas (non-linear machine) or the
// sparse weight vector (linear machine) have been stored.
if ((m_alpha == null) && (m_sparseWeights == null)) {
return "BinarySMO: No model built yet.\n";
}
try {
text.append("BinarySMO\n\n");
// If machine linear, print weight vector
if (!m_useRBF && m_exponent == 1.0) {
text.append("Machine linear: showing attribute weights, ");
text.append("not support vectors.\n\n");
// We can assume that the weight vector is stored in sparse
// format because the classifier has been built
for (int i = 0; i < m_sparseWeights.length; i++) {
// Skip the class attribute's entry.
if (m_sparseIndices[i] != (int)m_classIndex) {
// Leading " + " is suppressed for the first printed term.
if (printed > 0) {
text.append(" + ");
} else {
text.append(" ");
}
text.append(Utils.doubleToString(m_sparseWeights[i], 12, 4) +
" * ");
// Note which filtering was applied to the inputs, if any.
if (m_filterType == FILTER_STANDARDIZE) {
text.append("(standardized) ");
} else if (m_filterType == FILTER_NORMALIZE) {
text.append("(normalized) ");
}
// With checks on, the dataset header (and thus attribute names)
// is still available; with checks off m_data may be null.
if (!m_checksTurnedOff) {
text.append(m_data.attribute(m_sparseIndices[i]).name()+"\n");
} else {
text.append("attribute with index " +
m_sparseIndices[i] +"\n");
}
printed++;
}
}
} else {
// Non-linear machine: list each support vector with its
// signed Lagrange multiplier.
for (int i = 0; i < m_alpha.length; i++) {
if (m_supportVectors.contains(i)) {
double val = m_alpha[i];
// Sign follows the class value; the leading " + " is
// suppressed for the very first printed term.
if (m_class[i] == 1) {
if (printed > 0) {
text.append(" + ");
}
} else {
text.append(" - ");
}
text.append(Utils.doubleToString(val, 12, 4)
+ " * <");
// Print the support vector's attribute values (class excluded).
for (int j = 0; j < m_data.numAttributes(); j++) {
if (j != m_data.classIndex()) {
text.append(m_data.instance(i).toString(j));
}
if (j != m_data.numAttributes() - 1) {
text.append(" ");
}
}
text.append("> * X]\n");
printed++;
}
}
}
// Append the threshold with its sign flipped, since the machine
// computes (sum - m_b).
if (m_b > 0) {
text.append(" - " + Utils.doubleToString(m_b, 12, 4));
} else {
text.append(" + " + Utils.doubleToString(-m_b, 12, 4));
}
// Support-vector count is only meaningful for non-linear machines
// (the linear case discards the support-vector set after training).
if (m_useRBF || m_exponent != 1.0) {
text.append("\n\nNumber of support vectors: " +
m_supportVectors.numElements());
}
// Kernel statistics; m_kernel may be null after clean-up.
int numEval = 0;
int numCacheHits = -1;
if(m_kernel != null)
{
numEval = m_kernel.numEvals();
numCacheHits = m_kernel.numCacheHits();
}
text.append("\n\nNumber of kernel evaluations: " + numEval);
if (numCacheHits >= 0 && numEval > 0)
{
// Fraction of kernel lookups served from the cache.
double hitRatio = 1 - numEval*1.0/(numCacheHits+numEval);
text.append(" (" + Utils.doubleToString(hitRatio*100, 7, 3).trim() + "% cached)");
}
} catch (Exception e) {
// Printing a built model should never fail; report instead of
// propagating so toString() keeps its no-throw contract.
e.printStackTrace();
return "Can't print BinarySMO classifier.";
}
return text.toString();
}
/**
 * Examines instance.
 *
 * Checks whether instance i2 violates the optimality conditions with
 * respect to the two thresholds m_bUp/m_bLow and, if it does, selects
 * a partner index i1 and attempts a joint optimization step via
 * takeStep() (Keerthi et al.'s modified SMO).
 *
 * @param i2 index of instance to examine
 * @return true if examination was successfull
 * @exception Exception if something goes wrong
 */
protected boolean examineExample(int i2) throws Exception {
double y2, alph2, F2;
int i1 = -1;
y2 = m_class[i2];
alph2 = m_alpha[i2];
// Error for i2 is cached for instances in I0 (unbound alphas,
// 0 < alpha < C); otherwise it is recomputed. Adding m_b cancels
// the "- m_b" inside SVMOutput(), so F2 is the thresholdless
// output minus the target value.
if (m_I0.contains(i2)) {
F2 = m_errors[i2];
} else {
F2 = SVMOutput(i2, m_data.instance(i2)) + m_b - y2;
m_errors[i2] = F2;
// Update thresholds
if ((m_I1.contains(i2) || m_I2.contains(i2)) && (F2 < m_bUp)) {
m_bUp = F2; m_iUp = i2;
} else if ((m_I3.contains(i2) || m_I4.contains(i2)) && (F2 > m_bLow)) {
m_bLow = F2; m_iLow = i2;
}
}
// Check optimality using current bLow and bUp and, if
// violated, find an index i1 to do joint optimization
// with i2...
boolean optimal = true;
if (m_I0.contains(i2) || m_I1.contains(i2) || m_I2.contains(i2)) {
if (m_bLow - F2 > 2 * m_tol) {
optimal = false; i1 = m_iLow;
}
}
if (m_I0.contains(i2) || m_I3.contains(i2) || m_I4.contains(i2)) {
if (F2 - m_bUp > 2 * m_tol) {
optimal = false; i1 = m_iUp;
}
}
if (optimal) {
return false;
}
// For i2 unbound choose the better i1...
if (m_I0.contains(i2)) {
if (m_bLow - F2 > F2 - m_bUp) {
i1 = m_iLow;
} else {
i1 = m_iUp;
}
}
// Every path that sets optimal = false also sets i1, so this
// should be unreachable; kept as a sanity check.
if (i1 == -1) {
throw new Exception("This should never happen!");
}
return takeStep(i1, i2, F2);
}
/**
* Method solving for the Lagrange multipliers for
* two instances.
*
* @param i1 index of the first instance
* @param i2 index of the second instance
* @return true if multipliers could be found
* @exception Exception if something goes wrong
*/
protected boolean takeStep(int i1, int i2, double F2) throws Exception {
double alph1, alph2, y1, y2, F1, s, L, H, k11, k12, k22, eta,
a1, a2, f1, f2, v1, v2, Lobj, Hobj, b1, b2, bOld;
double C1 = m_C * m_data.instance(i1).weight();
double C2 = m_C * m_data.instance(i2).weight();
// Don't do anything if the two instances are the same
if (i1 == i2) {
return false;
}
// Initialize variables
alph1 = m_alpha[i1]; alph2 = m_alpha[i2];
y1 = m_class[i1]; y2 = m_class[i2];
F1 = m_errors[i1];
s = y1 * y2;
// Find the constraints on a2
if (y1 != y2) {
L = Math.max(0, alph2 - alph1);
H = Math.min(C2, C1 + alph2 - alph1);
} else {
L = Math.max(0, alph1 + alph2 - C1);
H = Math.min(C2, alph1 + alph2);
}
if (L >= H) {
return false;
}
// Compute second derivative of objective function
k11 = m_kernel.eval(i1, i1, m_data.instance(i1));
k12 = m_kernel.eval(i1, i2, m_data.instance(i1));
k22 = m_kernel.eval(i2, i2, m_data.instance(i2));
eta = 2 * k12 - k11 - k22;
// Check if second derivative is negative
if (eta < 0) {
// Compute unconstrained maximum
a2 = alph2 - y2 * (F1 - F2) / eta;
// Compute constrained maximum
if (a2 < L) {
a2 = L;
} else if (a2 > H) {
a2 = H;
}
} else {
// Look at endpoints of diagonal
f1 = SVMOutput(i1, m_data.instance(i1));
f2 = SVMOutput(i2, m_data.instance(i2));
v1 = f1 + m_b - y1 * alph1 * k11 - y2 * alph2 * k12;
v2 = f2 + m_b - y1 * alph1 * k12 - y2 * alph2 * k22;
double gamma = alph1 + s * alph2;
Lobj = (gamma - s * L) + L - 0.5 * k11 * (gamma - s * L) * (gamma - s * L) -
0.5 * k22 * L * L - s * k12 * (gamma - s * L) * L -
y1 * (gamma - s * L) * v1 - y2 * L * v2;
Hobj = (gamma - s * H) + H - 0.5 * k11 * (gamma - s * H) * (gamma - s * H) -
0.5 * k22 * H * H - s * k12 * (gamma - s * H) * H -
y1 * (gamma - s * H) * v1 - y2 * H * v2;
if (Lobj > Hobj + m_eps) {
a2 = L;
} else if (Lobj < Hobj - m_eps) {
a2 = H;
} else {
a2 = alph2;
}
}
if (Math.abs(a2 - alph2) < m_eps * (a2 + alph2 + m_eps)) {
return false;
}
// To prevent precision problems
if (a2 > C2 - m_Del * C2) {
a2 = C2;
} else if (a2 <= m_Del * C2) {
a2 = 0;
}
// Recompute a1
a1 = alph1 + s * (alph2 - a2);
// To prevent precision problems
if (a1 > C1 - m_Del * C1) {
a1 = C1;
} else if (a1 <= m_Del * C1) {
a1 = 0;
}
// Update sets
if (a1 > 0) {
m_supportVectors.insert(i1);
} else {
m_supportVectors.delete(i1);
}
if ((a1 > 0) && (a1 < C1)) {
m_I0.insert(i1);
} else {
m_I0.delete(i1);
}
if ((y1 == 1) && (a1 == 0)) {
m_I1.insert(i1);
} else {
m_I1.delete(i1);
}
if ((y1 == -1) && (a1 == C1)) {
m_I2.insert(i1);
} else {
m_I2.delete(i1);
}
if ((y1 == 1) && (a1 == C1)) {
m_I3.insert(i1);
} else {
m_I3.delete(i1);
}
if ((y1 == -1) && (a1 == 0)) {
m_I4.insert(i1);
} else {
m_I4.delete(i1);
}
if (a2 > 0) {
m_supportVectors.insert(i2);
} else {
m_supportVectors.delete(i2);
}
if ((a2 > 0) && (a2 < C2)) {
m_I0.insert(i2);
} else {
m_I0.delete(i2);
}
if ((y2 == 1) && (a2 == 0)) {
m_I1.insert(i2);
} else {
m_I1.delete(i2);
}
// NOTE: HTML-extraction artifacts (website keyboard-shortcut help) removed here;
// the remainder of takeStep() (set updates for i2, error-cache and threshold
// updates) is missing from this fragment.