📄 solver.cs
字号:
//Copyright (C) 2007 Matthew Johnson
//This program is free software; you can redistribute it and/or modify
//it under the terms of the GNU General Public License as published by
//the Free Software Foundation; either version 2 of the License, or
//(at your option) any later version.
//This program is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
//GNU General Public License for more details.
//You should have received a copy of the GNU General Public License along
//with this program; if not, write to the Free Software Foundation, Inc.,
//51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
using System;
using System.Collections.Generic;
using System.Text;
using System.IO;
using System.Diagnostics;
namespace SVM
{
//
// Kernel evaluation
//
// The static method k_function evaluates the kernel on a single pair of vectors.
// The Kernel constructor prepares for computing entries of the l*l kernel matrix.
// The member function get_Q returns one column of the Q matrix.
//
/// <summary>
/// Abstract source of the Q matrix consumed by the solver. Q is read one column
/// at a time, and the order of the variables can be permuted through swap_index.
/// </summary>
internal abstract class QMatrix
{
    /// <summary>
    /// Exchanges the entries for indices i and j. The solver calls this whenever
    /// it reorders its own per-variable arrays.
    /// </summary>
    public abstract void swap_index(int i, int j);

    /// <summary>
    /// Returns the array the solver stores as QD.
    /// NOTE(review): presumably the diagonal Q[i,i]; confirm against the implementations.
    /// </summary>
    public abstract float[] get_QD();

    /// <summary>
    /// Returns column <paramref name="column"/> of Q. <paramref name="len"/> is the
    /// number of leading entries requested (the solver passes either l or the
    /// current active-set size).
    /// </summary>
    public abstract float[] get_Q(int column, int len);
}
/// <summary>
/// Kernel-backed Q matrix: keeps a private copy of the training vectors and the
/// kernel parameters, and evaluates K(x_i, x_j) on demand. Concrete subclasses
/// supply get_Q and get_QD.
/// </summary>
internal abstract class Kernel : QMatrix
{
    // Shallow copy of the training vectors; swap_index permutes this copy only.
    private Node[][] _x;

    // Cached dot(x_i, x_i) for each vector. Only allocated for the RBF kernel
    // (see the constructor); null for every other kernel type.
    private double[] _x_square;

    // Kernel parameters, copied from the Parameter passed to the constructor.
    private KernelType kernel_type;
    private int degree;
    private double gamma;
    private double coef0;

    /// <summary>
    /// Exchanges vectors i and j, together with their cached squared norms when
    /// those exist, so later kernel_function calls see the new order.
    /// </summary>
    public override void swap_index(int i, int j)
    {
        Swap(_x, i, j);
        if (_x_square != null)
            Swap(_x_square, i, j);
    }

    // Exchanges array[i] and array[j] in place.
    private static void Swap<T>(T[] array, int i, int j)
    {
        T tmp = array[i];
        array[i] = array[j];
        array[j] = tmp;
    }

    /// <summary>
    /// Raises baseValue to an integer power by repeated squaring.
    /// Returns 1.0 when times is zero or negative.
    /// </summary>
    private static double powi(double baseValue, int times)
    {
        double tmp = baseValue, ret = 1.0;
        for (int t = times; t > 0; t /= 2)
        {
            if (t % 2 == 1) ret *= tmp;
            tmp = tmp * tmp;
        }
        return ret;
    }

    /// <summary>
    /// Evaluates the kernel between stored vectors i and j (in their current,
    /// possibly swapped, order).
    /// </summary>
    /// <returns>The kernel value, or 0 for an unrecognized kernel type.</returns>
    public double kernel_function(int i, int j)
    {
        switch (kernel_type)
        {
            case KernelType.LINEAR:
                return dot(_x[i], _x[j]);
            case KernelType.POLY:
                return powi(gamma * dot(_x[i], _x[j]) + coef0, degree);
            case KernelType.RBF:
                // ||x_i - x_j||^2 = <x_i,x_i> + <x_j,x_j> - 2<x_i,x_j>, using the cached norms.
                return Math.Exp(-gamma * (_x_square[i] + _x_square[j] - 2 * dot(_x[i], _x[j])));
            case KernelType.SIGMOID:
                // Math.Tanh avoids the cancellation that the 1 - 2/(e^(2x) + 1) form suffers near 0.
                return Math.Tanh(gamma * dot(_x[i], _x[j]) + coef0);
            case KernelType.PRECOMPUTED:
                // x_i holds precomputed kernel values; x_j[0].Value selects which entry to read.
                return _x[i][(int)(_x[j][0].Value)].Value;
            default:
                return 0;
        }
    }

    /// <summary>
    /// Copies the kernel parameters and (shallowly) the vector array. For the RBF
    /// kernel it also precomputes the squared norm of each of the first l vectors.
    /// </summary>
    /// <param name="l">Number of vectors whose squared norms are cached (RBF only).</param>
    /// <param name="x_">Training vectors.</param>
    /// <param name="param">Kernel type plus its gamma, degree and coef0 settings.</param>
    public Kernel(int l, Node[][] x_, Parameter param)
    {
        this.kernel_type = param.KernelType;
        this.degree = param.Degree;
        this.gamma = param.Gamma;
        this.coef0 = param.Coefficient0;
        _x = (Node[][])x_.Clone();
        if (kernel_type == KernelType.RBF)
        {
            _x_square = new double[l];
            for (int i = 0; i < l; i++)
                _x_square[i] = dot(_x[i], _x[i]);
        }
        else
        {
            _x_square = null;
        }
    }

    /// <summary>
    /// Sparse dot product of two vectors. Walks both arrays in a single merge
    /// pass, so it assumes each array's nodes are sorted by ascending Index.
    /// </summary>
    public static double dot(Node[] x, Node[] y)
    {
        double sum = 0;
        int xlen = x.Length;
        int ylen = y.Length;
        int i = 0;
        int j = 0;
        while (i < xlen && j < ylen)
        {
            if (x[i].Index == y[j].Index)
                sum += x[i++].Value * y[j++].Value;
            else if (x[i].Index > y[j].Index)
                ++j;
            else
                ++i;
        }
        return sum;
    }

    /// <summary>
    /// Evaluates the kernel selected by param on a single pair of sparse vectors,
    /// without using any cached state.
    /// </summary>
    /// <returns>The kernel value, or 0 for an unrecognized kernel type.</returns>
    public static double k_function(Node[] x, Node[] y, Parameter param)
    {
        switch (param.KernelType)
        {
            case KernelType.LINEAR:
                return dot(x, y);
            case KernelType.POLY:
                return powi(param.Gamma * dot(x, y) + param.Coefficient0, param.Degree);
            case KernelType.RBF:
                return Math.Exp(-param.Gamma * squared_distance(x, y));
            case KernelType.SIGMOID:
                return Math.Tanh(param.Gamma * dot(x, y) + param.Coefficient0);
            case KernelType.PRECOMPUTED:
                return x[(int)(y[0].Value)].Value;
            default:
                return 0;
        }
    }

    // Squared Euclidean distance between two sparse vectors, computed in one merge
    // pass (same ordering assumption as dot). An index present in only one vector
    // contributes value^2.
    private static double squared_distance(Node[] x, Node[] y)
    {
        double sum = 0;
        int xlen = x.Length;
        int ylen = y.Length;
        int i = 0;
        int j = 0;
        while (i < xlen && j < ylen)
        {
            if (x[i].Index == y[j].Index)
            {
                double d = x[i++].Value - y[j++].Value;
                sum += d * d;
            }
            else if (x[i].Index > y[j].Index)
            {
                sum += y[j].Value * y[j].Value;
                ++j;
            }
            else
            {
                sum += x[i].Value * x[i].Value;
                ++i;
            }
        }
        // Whatever remains in either vector has no counterpart in the other.
        while (i < xlen)
        {
            sum += x[i].Value * x[i].Value;
            ++i;
        }
        while (j < ylen)
        {
            sum += y[j].Value * y[j].Value;
            ++j;
        }
        return sum;
    }
}
// An SMO algorithm in Fan et al., JMLR 6 (2005), pp. 1889--1918.
// Solves:
//
// min 0.5(\alpha^T Q \alpha) + p^T \alpha
//
// subject to
// y^T \alpha = \delta
// y_i = +1 or -1
// 0 <= \alpha_i <= Cp for y_i = 1
// 0 <= \alpha_i <= Cn for y_i = -1
//
// Given:
//
// Q, p, y, Cp, Cn, and an initial feasible point \alpha
// l is the size of vectors and matrices
// eps is the stopping tolerance
//
// The solution will be put in \alpha; the objective value will be put in obj.
//
internal class Solver
{
// Number of leading variables (positions 0..active_size-1) currently being optimized.
// Starts at l and is reset to l before the final optimality re-check in Solve.
// NOTE(review): presumably lowered by do_shrinking, which is not visible here.
protected int active_size;
// Labels y_i, expected to be +1 or -1.
protected short[] y;
protected double[] G; // gradient of the objective w.r.t. alpha (initialized to p + Q*alpha in Solve)
// Possible values of alpha_status[i].
protected const byte LOWER_BOUND = 0;
protected const byte UPPER_BOUND = 1;
protected const byte FREE = 2;
protected byte[] alpha_status; // LOWER_BOUND, UPPER_BOUND, or FREE for each alpha_i
protected double[] alpha; // working copy of the dual variables (cloned from alpha_ in Solve)
protected QMatrix Q;
protected float[] QD; // result of Q.get_QD(); NOTE(review): presumably the diagonal of Q
protected double eps; // stopping tolerance
protected double Cp, Cn; // upper bounds on alpha_i for y_i = +1 and y_i = -1 respectively
protected double[] p; // linear term of the objective
protected int[] active_set; // original index of the variable now at each position (identity at start, permuted by swap_index)
protected double[] G_bar; // sum of C_i * Q_i over upper-bound variables, i.e. the quadratic part of the gradient with free variables treated as 0 (p is added separately)
protected int l; // number of variables
protected bool unshrinked; // XXX: set to false at the start of Solve; NOTE(review): its other uses are outside this view
protected const double INF = double.PositiveInfinity;
/// <summary>
/// Upper bound on alpha[i]: Cp for a positive label, Cn otherwise.
/// </summary>
protected double get_C(int i)
{
    if (y[i] > 0)
        return Cp;
    return Cn;
}
/// <summary>
/// Recomputes alpha_status[i] from the current alpha[i]. The upper-bound test
/// runs first, so a variable with a non-positive bound is classed as UPPER_BOUND.
/// </summary>
protected void update_alpha_status(int i)
{
    double a = alpha[i];
    alpha_status[i] = (a >= get_C(i)) ? UPPER_BOUND
                    : (a <= 0) ? LOWER_BOUND
                    : FREE;
}
/// <summary>True when alpha[i] was last classified as sitting at its upper bound.</summary>
protected bool is_upper_bound(int i)
{
    return UPPER_BOUND == alpha_status[i];
}

/// <summary>True when alpha[i] was last classified as sitting at zero.</summary>
protected bool is_lower_bound(int i)
{
    return LOWER_BOUND == alpha_status[i];
}

/// <summary>True when alpha[i] was last classified as strictly between its bounds.</summary>
protected bool is_free(int i)
{
    return FREE == alpha_status[i];
}
// Carried over from the Java port: bundles every result of a solve except alpha,
// because the Java code could not return multiple values from one call.
internal class SolutionInfo
{
    /// <summary>Objective value at the solution.</summary>
    public double obj;

    /// <summary>NOTE(review): presumably the bias/threshold term; verify where it is filled in.</summary>
    public double rho;

    /// <summary>
    /// NOTE(review): presumably the upper bounds used for positive and negative
    /// variables respectively; confirm against the code that sets them.
    /// </summary>
    public double upper_bound_p, upper_bound_n;

    /// <summary>Extra result used by Solver_NU.</summary>
    public double r;
}
/// <summary>
/// Exchanges variables i and j everywhere the solver keeps per-variable state:
/// first in the Q matrix, then in every local array.
/// </summary>
protected void swap_index(int i, int j)
{
    Q.swap_index(i, j);

    short savedLabel = y[i];
    y[i] = y[j];
    y[j] = savedLabel;

    double savedGrad = G[i];
    G[i] = G[j];
    G[j] = savedGrad;

    byte savedStatus = alpha_status[i];
    alpha_status[i] = alpha_status[j];
    alpha_status[j] = savedStatus;

    double savedAlpha = alpha[i];
    alpha[i] = alpha[j];
    alpha[j] = savedAlpha;

    double savedLinear = p[i];
    p[i] = p[j];
    p[j] = savedLinear;

    int savedOrigin = active_set[i];
    active_set[i] = active_set[j];
    active_set[j] = savedOrigin;

    double savedGradBar = G_bar[i];
    G_bar[i] = G_bar[j];
    G_bar[j] = savedGradBar;
}
/// <summary>
/// Rebuilds G for the inactive positions (active_size..l-1). Each entry starts
/// from G_bar + p, and then the contribution of every free active variable is
/// added back in.
/// </summary>
protected void reconstruct_gradient()
{
    // No inactive positions, nothing to rebuild.
    if (active_size == l) return;

    for (int k = active_size; k < l; ++k)
    {
        G[k] = G_bar[k] + p[k];
    }

    for (int f = 0; f < active_size; ++f)
    {
        if (!is_free(f)) continue;

        float[] column = Q.get_Q(f, l);
        double weight = alpha[f];
        for (int k = active_size; k < l; ++k)
        {
            G[k] += weight * column[k];
        }
    }
}
public virtual void Solve(int l, QMatrix Q, double[] p_, short[] y_,
double[] alpha_, double Cp, double Cn, double eps, SolutionInfo si, bool shrinking)
{
this.l = l;
this.Q = Q;
QD = Q.get_QD();
p = (double[])p_.Clone();
y = (short[])y_.Clone();
alpha = (double[])alpha_.Clone();
this.Cp = Cp;
this.Cn = Cn;
this.eps = eps;
this.unshrinked = false;
// initialize alpha_status
{
alpha_status = new byte[l];
for (int i = 0; i < l; i++)
update_alpha_status(i);
}
// initialize active set (for shrinking)
{
active_set = new int[l];
for (int i = 0; i < l; i++)
active_set[i] = i;
active_size = l;
}
// initialize gradient
{
G = new double[l];
G_bar = new double[l];
int i;
for (i = 0; i < l; i++)
{
G[i] = p[i];
G_bar[i] = 0;
}
for (i = 0; i < l; i++)
if (!is_lower_bound(i))
{
float[] Q_i = Q.get_Q(i, l);
double alpha_i = alpha[i];
int j;
for (j = 0; j < l; j++)
G[j] += alpha_i * Q_i[j];
if (is_upper_bound(i))
for (j = 0; j < l; j++)
G_bar[j] += get_C(i) * Q_i[j];
}
}
// optimization step
int iter = 0;
int counter = Math.Min(l, 1000) + 1;
int[] working_set = new int[2];
while (true)
{
// show progress and do shrinking
if (--counter == 0)
{
counter = Math.Min(l, 1000);
if (shrinking) do_shrinking();
Debug.Write(".");
}
if (select_working_set(working_set) != 0)
{
// reconstruct the whole gradient
reconstruct_gradient();
// reset active set size and check
active_size = l;
Debug.Write("*");
if (select_working_set(working_set) != 0)
break;
else
counter = 1; // do shrinking next iteration
}
int i = working_set[0];
int j = working_set[1];
++iter;
// update alpha[i] and alpha[j], handle bounds carefully
float[] Q_i = Q.get_Q(i, active_size);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -