📄 svm.java
字号:
{
++nr_free1;
sum_free1 += G[i];
}
}
else
{
if(is_lower_bound(i))
ub2 = Math.min(ub2,G[i]);
else if(is_upper_bound(i))
lb2 = Math.max(lb2,G[i]);
else
{
++nr_free2;
sum_free2 += G[i];
}
}
}
double r1,r2;
if(nr_free1 > 0)
r1 = sum_free1/nr_free1;
else
r1 = (ub1+lb1)/2;
if(nr_free2 > 0)
r2 = sum_free2/nr_free2;
else
r2 = (ub2+lb2)/2;
si.r = (r1+r2)/2;
return (r1-r2)/2;
}
}
//
// Q matrices for various formulations
//
class SVC_Q extends Kernel
{
private final byte[] y;
private final Cache cache;
private final float[] QD;
SVC_Q(svm_problem prob, svm_parameter param, byte[] y_)
{
super(prob.l, prob.x, param);
y = (byte[])y_.clone();
cache = new Cache(prob.l,(int)(param.cache_size*(1<<20)));
QD = new float[prob.l];
for(int i=0;i<prob.l;i++)
QD[i]= (float)kernel_function(i,i);
}
float[] get_Q(int i, int len)
{
float[][] data = new float[1][];
int start;
if((start = cache.get_data(i,data,len)) < len)
{
for(int j=start;j<len;j++)
data[0][j] = (float)(y[i]*y[j]*kernel_function(i,j));
}
return data[0];
}
float[] get_QD()
{
return QD;
}
void swap_index(int i, int j)
{
cache.swap_index(i,j);
super.swap_index(i,j);
do {byte _=y[i]; y[i]=y[j]; y[j]=_;} while(false);
do {float _=QD[i]; QD[i]=QD[j]; QD[j]=_;} while(false);
}
}
class ONE_CLASS_Q extends Kernel
{
private final Cache cache;
private final float[] QD;
ONE_CLASS_Q(svm_problem prob, svm_parameter param)
{
super(prob.l, prob.x, param);
cache = new Cache(prob.l,(int)(param.cache_size*(1<<20)));
QD = new float[prob.l];
for(int i=0;i<prob.l;i++)
QD[i]= (float)kernel_function(i,i);
}
float[] get_Q(int i, int len)
{
float[][] data = new float[1][];
int start;
if((start = cache.get_data(i,data,len)) < len)
{
for(int j=start;j<len;j++)
data[0][j] = (float)kernel_function(i,j);
}
return data[0];
}
float[] get_QD()
{
return QD;
}
void swap_index(int i, int j)
{
cache.swap_index(i,j);
super.swap_index(i,j);
do {float _=QD[i]; QD[i]=QD[j]; QD[j]=_;} while(false);
}
}
class SVR_Q extends Kernel
{
private final int l;
private final Cache cache;
private final byte[] sign;
private final int[] index;
private int next_buffer;
private float[][] buffer;
private final float[] QD;
SVR_Q(svm_problem prob, svm_parameter param)
{
super(prob.l, prob.x, param);
l = prob.l;
cache = new Cache(l,(int)(param.cache_size*(1<<20)));
QD = new float[2*l];
sign = new byte[2*l];
index = new int[2*l];
for(int k=0;k<l;k++)
{
sign[k] = 1;
sign[k+l] = -1;
index[k] = k;
index[k+l] = k;
QD[k] = (float)kernel_function(k,k);
QD[k+l] = QD[k];
}
buffer = new float[2][2*l];
next_buffer = 0;
}
void swap_index(int i, int j)
{
do {byte _=sign[i]; sign[i]=sign[j]; sign[j]=_;} while(false);
do {int _=index[i]; index[i]=index[j]; index[j]=_;} while(false);
do {float _=QD[i]; QD[i]=QD[j]; QD[j]=_;} while(false);
}
float[] get_Q(int i, int len)
{
float[][] data = new float[1][];
int real_i = index[i];
if(cache.get_data(real_i,data,l) < l)
{
for(int j=0;j<l;j++)
data[0][j] = (float)kernel_function(real_i,j);
}
// reorder and copy
float buf[] = buffer[next_buffer];
next_buffer = 1 - next_buffer;
byte si = sign[i];
for(int j=0;j<len;j++)
buf[j] = si * sign[j] * data[0][index[j]];
return buf;
}
float[] get_QD()
{
return QD;
}
}
public class svm {
//
// construct and solve various formulations
//
private static void solve_c_svc(svm_problem prob, svm_parameter param,
		double[] alpha, Solver.SolutionInfo si,
		double Cp, double Cn)
{
	// Solves the C-SVC dual: linear term all -1, labels mapped to +1/-1.
	// Cp / Cn are the upper bounds for positive / negative instances.
	// On return alpha holds y[i] * alpha[i].
	final int n = prob.l;
	final double[] linear = new double[n];
	final byte[] labels = new byte[n];
	for (int k = 0; k < n; k++)
	{
		alpha[k] = 0;
		linear[k] = -1;
		labels[k] = prob.y[k] > 0 ? (byte) 1 : (byte) -1;
	}

	final Solver solver = new Solver();
	final SVC_Q q = new SVC_Q(prob, param, labels);
	solver.Solve(n, q, linear, labels,
			alpha, Cp, Cn, param.eps, si, param.shrinking);

	// With equal bounds, report the equivalent nu.
	double total = 0;
	for (int k = 0; k < n; k++)
		total += alpha[k];
	if (Cp == Cn)
		System.out.print("nu = "+total/(Cp*prob.l)+"\n");

	// Fold the labels into the coefficients.
	for (int k = 0; k < n; k++)
		alpha[k] *= labels[k];
}
private static void solve_nu_svc(svm_problem prob, svm_parameter param,
		double[] alpha, Solver.SolutionInfo si)
{
	// Solves the nu-SVC dual with Solver_NU.
	// The result is rescaled by 1/r (r taken from si.r) to the C-SVC form.
	final int n = prob.l;
	final byte[] labels = new byte[n];
	for (int k = 0; k < n; k++)
		labels[k] = prob.y[k] > 0 ? (byte) 1 : (byte) -1;

	// Feasible start: each class hands out a budget of nu*l/2 in chunks of at most 1.
	double budgetPos = param.nu * n / 2;
	double budgetNeg = param.nu * n / 2;
	for (int k = 0; k < n; k++)
	{
		if (labels[k] != 1)
		{
			alpha[k] = Math.min(1.0, budgetNeg);
			budgetNeg -= alpha[k];
		}
		else
		{
			alpha[k] = Math.min(1.0, budgetPos);
			budgetPos -= alpha[k];
		}
	}

	// Newly allocated Java arrays are already zero-filled.
	final double[] zeros = new double[n];

	final Solver_NU solver = new Solver_NU();
	solver.Solve(n, new SVC_Q(prob, param, labels), zeros, labels,
			alpha, 1.0, 1.0, param.eps, si, param.shrinking);

	final double r = si.r;
	System.out.print("C = "+1/r+"\n");

	for (int k = 0; k < n; k++)
		alpha[k] *= labels[k] / r;
	si.rho /= r;
	si.obj /= (r*r);
	si.upper_bound_p = 1/r;
	si.upper_bound_n = 1/r;
}
private static void solve_one_class(svm_problem prob, svm_parameter param,
		double[] alpha, Solver.SolutionInfo si)
{
	// Solves the one-class dual: zero linear term, all labels +1, bound 1.0.
	final int len = prob.l;
	final double[] zeros = new double[len];   // zero-filled by allocation
	final byte[] ones = new byte[len];

	// Feasible start: floor(nu*l) alphas at the upper bound and one fractional alpha.
	// All remaining alphas are zero.
	final double budget = param.nu * prob.l;
	final int atBound = (int) budget;         // # of alpha's at upper bound

	int k = 0;
	while (k < atBound)
		alpha[k++] = 1;
	if (atBound < prob.l)
		alpha[atBound] = budget - atBound;
	for (k = atBound + 1; k < len; k++)
		alpha[k] = 0;

	for (k = 0; k < len; k++)
		ones[k] = 1;

	final Solver solver = new Solver();
	solver.Solve(len, new ONE_CLASS_Q(prob, param), zeros, ones,
			alpha, 1.0, 1.0, param.eps, si, param.shrinking);
}
private static void solve_epsilon_svr(svm_problem prob, svm_parameter param,
		double[] alpha, Solver.SolutionInfo si)
{
	// Solves epsilon-SVR as a 2*l-variable problem.
	// Variables [0,l) carry sign +1 with linear term p - y.
	// Variables [l,2l) carry sign -1 with linear term p + y.
	// On return alpha[k] = alpha2[k] - alpha2[k+l].
	final int n = prob.l;
	final int twice = 2 * n;
	final double[] alpha2 = new double[twice];  // zero-filled by allocation
	final double[] linear = new double[twice];
	final byte[] signs = new byte[twice];
	for (int k = 0; k < n; k++)
	{
		linear[k] = param.p - prob.y[k];
		linear[k + n] = param.p + prob.y[k];
		signs[k] = 1;
		signs[k + n] = -1;
	}

	final Solver solver = new Solver();
	solver.Solve(twice, new SVR_Q(prob, param), linear, signs,
			alpha2, param.C, param.C, param.eps, si, param.shrinking);

	double totalAbs = 0;
	for (int k = 0; k < n; k++)
	{
		final double coef = alpha2[k] - alpha2[k + n];
		alpha[k] = coef;
		totalAbs += Math.abs(coef);
	}
	System.out.print("nu = "+totalAbs/(param.C*n)+"\n");
}
private static void solve_nu_svr(svm_problem prob, svm_parameter param,
		double[] alpha, Solver.SolutionInfo si)
{
	// Solves nu-SVR as a 2*l-variable problem with Solver_NU.
	// Variables [0,l) carry sign +1 with linear term -y.
	// Variables [l,2l) carry sign -1 with linear term +y.
	// The solver's r is reported as -epsilon.
	final int n = prob.l;
	final double C = param.C;
	final int twice = 2 * n;
	final double[] alpha2 = new double[twice];
	final double[] linear = new double[twice];
	final byte[] signs = new byte[twice];

	// Feasible start: both halves share a budget of C*nu*l/2, capped at C each.
	double remaining = C * param.nu * n / 2;
	for (int k = 0; k < n; k++)
	{
		final double start = Math.min(remaining, C);
		alpha2[k] = start;
		alpha2[k + n] = start;
		remaining -= start;

		linear[k] = -prob.y[k];
		signs[k] = 1;
		linear[k + n] = prob.y[k];
		signs[k + n] = -1;
	}

	final Solver_NU solver = new Solver_NU();
	solver.Solve(twice, new SVR_Q(prob, param), linear, signs,
			alpha2, C, C, param.eps, si, param.shrinking);
	System.out.print("epsilon = "+(-si.r)+"\n");

	for (int k = 0; k < n; k++)
		alpha[k] = alpha2[k] - alpha2[k + n];
}
//
// decision_function
//
static class decision_function
{
	/** Dual coefficients, one per training instance (filled by svm_train_one). */
	double[] alpha;
	/** The rho value reported by the solver (copied from SolutionInfo.rho). */
	double rho;
}
static decision_function svm_train_one(
		svm_problem prob, svm_parameter param,
		double Cp, double Cn)
{
	// Trains one model of the configured svm_type.
	// Prints solver and support-vector statistics.
	// Returns the coefficients together with rho.
	final double[] alpha = new double[prob.l];
	final Solver.SolutionInfo si = new Solver.SolutionInfo();

	final int type = param.svm_type;
	if (type == svm_parameter.C_SVC)
		solve_c_svc(prob, param, alpha, si, Cp, Cn);
	else if (type == svm_parameter.NU_SVC)
		solve_nu_svc(prob, param, alpha, si);
	else if (type == svm_parameter.ONE_CLASS)
		solve_one_class(prob, param, alpha, si);
	else if (type == svm_parameter.EPSILON_SVR)
		solve_epsilon_svr(prob, param, alpha, si);
	else if (type == svm_parameter.NU_SVR)
		solve_nu_svr(prob, param, alpha, si);

	System.out.print("obj = "+si.obj+", rho = "+si.rho+"\n");

	// Count support vectors (nonzero alpha).
	// Count bounded support vectors: |alpha| at or above the bound for the instance's class.
	int supportCount = 0;
	int boundedCount = 0;
	for (int k = 0; k < prob.l; k++)
	{
		final double magnitude = Math.abs(alpha[k]);
		if (!(magnitude > 0))
			continue;
		supportCount++;
		final double bound = prob.y[k] > 0 ? si.upper_bound_p : si.upper_bound_n;
		if (magnitude >= bound)
			boundedCount++;
	}
	System.out.print("nSV = "+supportCount+", nBSV = "+boundedCount+"\n");

	final decision_function result = new decision_function();
	result.alpha = alpha;
	result.rho = si.rho;
	return result;
}
// Platt's binary SVM Probablistic Output: an improvement from Lin et al.
private static void sigmoid_train(int l, double[] dec_values, double[] labels,
		double[] probAB)
{
	// Fits a sigmoid P(y=1|f) = 1/(1+exp(A*f+B)) to decision values.
	// Uses Newton's method with backtracking line search.
	// This is Platt's method with the improvements of Lin et al.
	// Writes the result to probAB[0] = A and probAB[1] = B.

	// Class counts: labels > 0 are positive, everything else negative.
	double prior1 = 0, prior0 = 0;
	for (int i = 0; i < l; i++)
	{
		if (labels[i] > 0)
			prior1 += 1;
		else
			prior0 += 1;
	}

	final int max_iter = 100;      // Maximal number of iterations
	final double min_step = 1e-10; // Minimal step taken in line search
	final double sigma = 1e-3;     // For numerically strict PD of Hessian
	final double eps = 1e-5;       // Gradient tolerance for stopping

	// Smoothed targets instead of hard 0/1 labels.
	final double hiTarget = (prior1 + 1.0) / (prior1 + 2.0);
	final double loTarget = 1 / (prior0 + 2.0);
	final double[] t = new double[l];
	for (int i = 0; i < l; i++)
		t[i] = labels[i] > 0 ? hiTarget : loTarget;

	// Initial point and initial objective value.
	// Both log-terms are written so exp() never takes a large positive argument.
	double A = 0.0;
	double B = Math.log((prior0 + 1.0) / (prior1 + 1.0));
	double fval = 0.0;
	for (int i = 0; i < l; i++)
	{
		final double fApB = dec_values[i] * A + B;
		if (fApB >= 0)
			fval += t[i]*fApB + Math.log(1+Math.exp(-fApB));
		else
			fval += (t[i] - 1)*fApB + Math.log(1+Math.exp(fApB));
	}

	int iter = 0;
	for (; iter < max_iter; iter++)
	{
		// Gradient and Hessian; H' = H + sigma*I keeps H' strictly PD.
		double h11 = sigma, h22 = sigma, h21 = 0.0;
		double g1 = 0.0, g2 = 0.0;
		for (int i = 0; i < l; i++)
		{
			final double fApB = dec_values[i] * A + B;
			final double p, q;
			if (fApB >= 0)
			{
				final double e = Math.exp(-fApB);
				p = e / (1.0 + e);
				q = 1.0 / (1.0 + e);
			}
			else
			{
				final double e = Math.exp(fApB);
				p = 1.0 / (1.0 + e);
				q = e / (1.0 + e);
			}
			final double d2 = p * q;
			h11 += dec_values[i]*dec_values[i]*d2;
			h22 += d2;
			h21 += dec_values[i]*d2;
			final double d1 = t[i] - p;
			g1 += dec_values[i]*d1;
			g2 += d1;
		}

		// Stopping criterion: both gradient components are small.
		if (Math.abs(g1) < eps && Math.abs(g2) < eps)
			break;

		// Newton direction: -inv(H') * g.
		final double det = h11*h22 - h21*h21;
		final double dA = -(h22*g1 - h21 * g2) / det;
		final double dB = -(-h21*g1 + h11 * g2) / det;
		final double gd = g1*dA + g2*dB;

		// Backtracking line search with a sufficient-decrease (Armijo) test.
		double stepsize = 1;
		while (stepsize >= min_step)
		{
			final double newA = A + stepsize * dA;
			final double newB = B + stepsize * dB;
			double newf = 0.0;
			for (int i = 0; i < l; i++)
			{
				final double fApB = dec_values[i] * newA + newB;
				if (fApB >= 0)
					newf += t[i]*fApB + Math.log(1+Math.exp(-fApB));
				else
					newf += (t[i] - 1)*fApB + Math.log(1+Math.exp(fApB));
			}
			if (newf < fval + 0.0001*stepsize*gd)
			{
				A = newA;
				B = newB;
				fval = newf;
				break;
			}
			stepsize /= 2.0;
		}

		if (stepsize < min_step)
		{
			System.err.print("Line search fails in two-class probability estimates\n");
			break;
		}
	}

	if (iter >= max_iter)
		System.err.print("Reaching maximal iterations in two-class probability estimates\n");
	probAB[0] = A;
	probAB[1] = B;
}
private static double sigmoid_predict(double decision_value, double A, double B)
{
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -