📄 svm.java
字号:
int i;
for(i=0;i<l;i++)
{
alpha2[i] = 0;
linear_term[i] = param.p - prob.y[i];
y[i] = 1;
alpha2[i+l] = 0;
linear_term[i+l] = param.p + prob.y[i];
y[i+l] = -1;
}
Solver s = new Solver();
s.Solve(2*l, new SVR_Q(prob,param), linear_term, y,
alpha2, param.C, param.C, param.eps, si, param.shrinking);
double sum_alpha = 0;
for(i=0;i<l;i++)
{
alpha[i] = alpha2[i] - alpha2[i+l];
sum_alpha += Math.abs(alpha[i]);
}
System.out.print("nu = "+sum_alpha/(param.C*l)+"\n");
}
/**
 * Sets up and solves the nu-SVR dual problem.
 *
 * The problem is expanded to 2*l variables: entry k and its mirror k+l are
 * paired with labels +1 and -1 and linear terms -y[k] and +y[k]. The initial
 * point spreads a budget of C*nu*l/2 over the pairs, at most C per pair.
 * After Solver_NU returns, the difference of each pair is stored in alpha.
 *
 * @param prob  training problem (reads prob.l and prob.y)
 * @param param training parameters (reads nu, C, eps and shrinking)
 * @param alpha output array; elements 0..l-1 receive alpha2[k] - alpha2[k+l]
 * @param si    solution info filled in by the solver
 */
private static void solve_nu_svr(svm_problem prob, svm_parameter param,
        double[] alpha, Solver.SolutionInfo si)
{
    // Written as two plain comparisons so that a NaN nu is not rejected,
    // exactly as before.
    if (param.nu > 1 || param.nu < 0)
    {
        System.err.print("specified nu is out of range\n");
        System.exit(1);
    }

    final int n = prob.l;
    final double upper = param.C;
    final double[] dual = new double[2 * n];
    final double[] linear = new double[2 * n];
    final byte[] signs = new byte[2 * n];

    // Budget still to be handed out to the initial dual variables.
    double remaining = upper * param.nu * n / 2;
    for (int k = 0, mirror = n; k < n; ++k, ++mirror)
    {
        final double initial = Math.min(remaining, upper);
        dual[k] = initial;
        dual[mirror] = initial;
        remaining -= initial;

        final double target = prob.y[k];
        linear[k] = -target;
        linear[mirror] = target;
        signs[k] = 1;
        signs[mirror] = -1;
    }

    Solver_NU solver = new Solver_NU();
    solver.Solve(2 * n, new SVR_Q(prob, param), linear, signs,
            dual, upper, upper, param.eps, si, param.shrinking);

    System.out.print("epsilon = " + (-si.r) + "\n");

    for (int k = 0; k < n; ++k)
    {
        alpha[k] = dual[k] - dual[k + n];
    }
}
//
// decision_function
//
/**
 * Plain holder for the result of training a single decision function:
 * one coefficient per training vector plus the offset term.
 */
static class decision_function
{
    /** Coefficient for each training vector of the (sub-)problem. */
    double[] alpha;

    /** Offset that callers subtract from the kernel sum. */
    double rho;
}
/**
 * Trains one decision function by dispatching on param.svm_type to the
 * matching solver, then reports the objective, rho and support-vector
 * counts on standard output.
 *
 * @param prob  the (sub-)problem to train on
 * @param param training parameters; svm_type selects the solver
 * @param Cp    passed through to solve_c_svc (only used for C_SVC)
 * @param Cn    passed through to solve_c_svc (only used for C_SVC)
 * @return a decision_function holding the alpha array and si.rho
 */
static decision_function svm_train_one(
        svm_problem prob, svm_parameter param,
        double Cp, double Cn)
{
    double[] alpha = new double[prob.l];
    Solver.SolutionInfo si = new Solver.SolutionInfo();

    final int type = param.svm_type;
    if (type == svm_parameter.C_SVC)
    {
        solve_c_svc(prob, param, alpha, si, Cp, Cn);
    }
    else if (type == svm_parameter.NU_SVC)
    {
        solve_nu_svc(prob, param, alpha, si);
    }
    else if (type == svm_parameter.ONE_CLASS)
    {
        solve_one_class(prob, param, alpha, si);
    }
    else if (type == svm_parameter.EPSILON_SVR)
    {
        solve_epsilon_svr(prob, param, alpha, si);
    }
    else if (type == svm_parameter.NU_SVR)
    {
        solve_nu_svr(prob, param, alpha, si);
    }
    // Any other svm_type runs no solver and leaves alpha all zero.

    System.out.print("obj = " + si.obj + ", rho = " + si.rho + "\n");

    // Count support vectors (non-zero alpha) and, among them, those whose
    // magnitude reaches the upper bound reported for the sign of y.
    int nSV = 0;
    int nBSV = 0;
    for (int k = 0; k < prob.l; k++)
    {
        final double magnitude = Math.abs(alpha[k]);
        if (magnitude > 0)
        {
            ++nSV;
            final double bound = (prob.y[k] > 0) ? si.upper_bound_p : si.upper_bound_n;
            if (magnitude >= bound)
            {
                ++nBSV;
            }
        }
    }
    System.out.print("nSV = " + nSV + ", nBSV = " + nBSV + "\n");

    decision_function result = new decision_function();
    result.alpha = alpha;
    result.rho = si.rho;
    return result;
}
//
// Interface functions
//
/**
 * Trains a model for the given problem.
 *
 * For ONE_CLASS, EPSILON_SVR and NU_SVR a single decision function is
 * trained and only vectors with non-zero coefficient are kept. Otherwise
 * the labels in prob.y are cast to int, the data is grouped by label, one
 * decision function is trained for every pair of classes, and the vectors
 * that are non-zero in at least one pairwise function are stored together
 * with their coefficients.
 *
 * @param prob  training data (l, x, y)
 * @param param training parameters; stored by reference in the model
 * @return the trained model
 */
public static svm_model svm_train(svm_problem prob, svm_parameter param)
{
    svm_model model = new svm_model();
    model.param = param;

    final int type = param.svm_type;
    if (type == svm_parameter.ONE_CLASS ||
        type == svm_parameter.EPSILON_SVR ||
        type == svm_parameter.NU_SVR)
    {
        // regression or one-class-svm: exactly one decision function
        model.nr_class = 2;
        model.label = null;
        model.nSV = null;

        decision_function single = svm_train_one(prob, param, 0, 0);
        model.rho = new double[] { single.rho };

        int kept = 0;
        for (int k = 0; k < prob.l; k++)
        {
            if (Math.abs(single.alpha[k]) > 0)
            {
                kept++;
            }
        }

        model.l = kept;
        model.SV = new svm_node[kept][];
        double[] coef = new double[kept];
        int out = 0;
        for (int k = 0; k < prob.l; k++)
        {
            final double a = single.alpha[k];
            if (Math.abs(a) > 0)
            {
                model.SV[out] = prob.x[k];
                coef[out] = a;
                out++;
            }
        }
        model.sv_coef = new double[][] { coef };
        return model;
    }

    // ---- classification ----

    // Discover the distinct labels in order of first appearance and record,
    // for every sample, the position of its label in that list.
    final int total = prob.l;
    int capacity = 16;
    int nr_class = 0;
    int[] label = new int[capacity];
    int[] count = new int[capacity];
    int[] index = new int[total];
    for (int n = 0; n < total; n++)
    {
        final int this_label = (int) prob.y[n];
        int slot = 0;
        while (slot < nr_class && label[slot] != this_label)
        {
            slot++;
        }
        index[n] = slot;
        if (slot < nr_class)
        {
            count[slot]++;
            continue;
        }
        // New label: double both tables when they are full.
        if (nr_class == capacity)
        {
            capacity *= 2;
            int[] grown = new int[capacity];
            System.arraycopy(label, 0, grown, 0, label.length);
            label = grown;
            grown = new int[capacity];
            System.arraycopy(count, 0, grown, 0, count.length);
            count = grown;
        }
        label[nr_class] = this_label;
        count[nr_class] = 1;
        nr_class++;
    }

    // Group training data of the same class: start[c] is the offset of the
    // first sample of class c in the regrouped array x.
    int[] start = new int[nr_class];
    start[0] = 0;
    for (int c = 1; c < nr_class; c++)
    {
        start[c] = start[c - 1] + count[c - 1];
    }
    int[] cursor = start.clone();
    svm_node[][] x = new svm_node[total][];
    for (int n = 0; n < total; n++)
    {
        x[cursor[index[n]]++] = prob.x[n];
    }

    // Per-class C: param.C scaled by every matching weight entry.
    double[] weighted_C = new double[nr_class];
    for (int c = 0; c < nr_class; c++)
    {
        weighted_C[c] = param.C;
    }
    for (int w = 0; w < param.nr_weight; w++)
    {
        int hit = 0;
        while (hit < nr_class && param.weight_label[w] != label[hit])
        {
            hit++;
        }
        if (hit < nr_class)
        {
            weighted_C[hit] *= param.weight[w];
        }
        else
        {
            System.err.print("warning: class label "+param.weight_label[w]+" specified in weight is not found\n");
        }
    }

    // Train n*(n-1)/2 pairwise models; nonzero[] marks every regrouped
    // sample that has a non-zero coefficient in at least one of them.
    boolean[] nonzero = new boolean[total];
    final int pairs = nr_class * (nr_class - 1) / 2;
    decision_function[] f = new decision_function[pairs];
    int p = 0;
    for (int a = 0; a < nr_class; a++)
    {
        for (int b = a + 1; b < nr_class; b++)
        {
            final int off_a = start[a];
            final int off_b = start[b];
            final int n_a = count[a];
            final int n_b = count[b];

            // Class a samples first (label +1), then class b (label -1).
            svm_problem sub_prob = new svm_problem();
            sub_prob.l = n_a + n_b;
            sub_prob.x = new svm_node[sub_prob.l][];
            sub_prob.y = new double[sub_prob.l];
            System.arraycopy(x, off_a, sub_prob.x, 0, n_a);
            System.arraycopy(x, off_b, sub_prob.x, n_a, n_b);
            for (int k = 0; k < sub_prob.l; k++)
            {
                sub_prob.y[k] = (k < n_a) ? +1 : -1;
            }

            decision_function trained = svm_train_one(sub_prob, param, weighted_C[a], weighted_C[b]);
            f[p++] = trained;

            for (int k = 0; k < n_a; k++)
            {
                if (Math.abs(trained.alpha[k]) > 0)
                {
                    nonzero[off_a + k] = true;
                }
            }
            for (int k = 0; k < n_b; k++)
            {
                if (Math.abs(trained.alpha[n_a + k]) > 0)
                {
                    nonzero[off_b + k] = true;
                }
            }
        }
    }

    // ---- build output ----
    model.nr_class = nr_class;
    model.label = new int[nr_class];
    System.arraycopy(label, 0, model.label, 0, nr_class);

    model.rho = new double[pairs];
    for (int q = 0; q < pairs; q++)
    {
        model.rho[q] = f[q].rho;
    }

    int nnz = 0;
    int[] nz_count = new int[nr_class];
    model.nSV = new int[nr_class];
    for (int c = 0; c < nr_class; c++)
    {
        int class_sv = 0;
        for (int k = 0; k < count[c]; k++)
        {
            if (nonzero[start[c] + k])
            {
                class_sv++;
            }
        }
        model.nSV[c] = class_sv;
        nz_count[c] = class_sv;
        nnz += class_sv;
    }
    System.out.print("Total nSV = " + nnz + "\n");

    model.l = nnz;
    model.SV = new svm_node[nnz][];
    int filled = 0;
    for (int n = 0; n < total; n++)
    {
        if (nonzero[n])
        {
            model.SV[filled++] = x[n];
        }
    }

    // nz_start[c]: offset of class c inside the compacted SV list.
    int[] nz_start = new int[nr_class];
    nz_start[0] = 0;
    for (int c = 1; c < nr_class; c++)
    {
        nz_start[c] = nz_start[c - 1] + nz_count[c - 1];
    }

    model.sv_coef = new double[nr_class - 1][nnz];

    p = 0;
    for (int a = 0; a < nr_class; a++)
    {
        for (int b = a + 1; b < nr_class; b++)
        {
            // classifier (a,b): coefficients with
            // a are in sv_coef[b-1][nz_start[a]...],
            // b are in sv_coef[a][nz_start[b]...]
            final double[] pair_alpha = f[p++].alpha;
            final int n_a = count[a];

            int q = nz_start[a];
            for (int k = 0; k < n_a; k++)
            {
                if (nonzero[start[a] + k])
                {
                    model.sv_coef[b - 1][q++] = pair_alpha[k];
                }
            }

            q = nz_start[b];
            for (int k = 0; k < count[b]; k++)
            {
                if (nonzero[start[b] + k])
                {
                    model.sv_coef[a][q++] = pair_alpha[n_a + k];
                }
            }
        }
    }
    return model;
}
/**
 * Evaluates a trained model on one input vector.
 *
 * For EPSILON_SVR and NU_SVR the result is the kernel expansion minus
 * rho[0]; for ONE_CLASS it is 1 when that value is positive and -1
 * otherwise. For the remaining types every pairwise decision function casts
 * one vote and the label of the class with the most votes is returned (the
 * lowest class index wins ties).
 *
 * @param model the trained or loaded model
 * @param x     the vector to evaluate
 * @return the decision value, +/-1, or the winning class label
 */
public static double svm_predict(svm_model model, svm_node[] x)
{
    final int type = model.param.svm_type;
    if (type == svm_parameter.ONE_CLASS ||
        type == svm_parameter.EPSILON_SVR ||
        type == svm_parameter.NU_SVR)
    {
        final double[] coef = model.sv_coef[0];
        double value = 0;
        for (int s = 0; s < model.l; s++)
        {
            value += coef[s] * Kernel.k_function(x, model.SV[s], model.param);
        }
        value -= model.rho[0];

        if (type != svm_parameter.ONE_CLASS)
        {
            return value;
        }
        return (value > 0) ? 1 : -1;
    }

    final int nr_class = model.nr_class;
    final int total = model.l;

    // Kernel value of x against every stored support vector, computed once.
    double[] kvalue = new double[total];
    for (int s = 0; s < total; s++)
    {
        kvalue[s] = Kernel.k_function(x, model.SV[s], model.param);
    }

    // start[c]: offset of the first support vector of class c.
    int[] start = new int[nr_class];
    start[0] = 0;
    for (int c = 1; c < nr_class; c++)
    {
        start[c] = start[c - 1] + model.nSV[c - 1];
    }

    int[] vote = new int[nr_class];
    int pair = 0;
    for (int a = 0; a < nr_class; a++)
    {
        for (int b = a + 1; b < nr_class; b++)
        {
            final double[] coef_a = model.sv_coef[b - 1];
            final double[] coef_b = model.sv_coef[a];

            double sum = 0;
            for (int k = start[a], end = start[a] + model.nSV[a]; k < end; k++)
            {
                sum += coef_a[k] * kvalue[k];
            }
            for (int k = start[b], end = start[b] + model.nSV[b]; k < end; k++)
            {
                sum += coef_b[k] * kvalue[k];
            }
            sum -= model.rho[pair++];

            // A strictly positive value votes for a; anything else for b.
            vote[(sum > 0) ? a : b]++;
        }
    }

    int best = 0;
    for (int c = 1; c < nr_class; c++)
    {
        if (vote[c] > vote[best])
        {
            best = c;
        }
    }
    return model.label[best];
}
static final String svm_type_table[] =
{
"c_svc","nu_svc","one_class","epsilon_svr","nu_svr",
};
static final String kernel_type_table[]=
{
"linear","polynomial","rbf","sigmoid",
};
/**
 * Writes a model to a text file.
 *
 * The file starts with one "key value" line per header entry (svm_type,
 * kernel_type, the kernel parameters that apply to that kernel, nr_class,
 * total_sv, rho and, when present, label and nr_sv), followed by a line
 * "SV" and one line per support vector holding its nr_class-1 coefficients
 * and its index:value pairs.
 *
 * The output is buffered, and the stream is closed even when a write fails.
 *
 * @param model_file_name path of the file to create or overwrite
 * @param model           the model to write
 * @throws IOException if the file cannot be opened, written or closed
 */
public static void svm_save_model(String model_file_name, svm_model model) throws IOException
{
    // DataOutputStream.writeBytes emits one byte at a time, so without a
    // buffer every character would be a separate write on the file.
    DataOutputStream fp = new DataOutputStream(
        new java.io.BufferedOutputStream(new FileOutputStream(model_file_name)));
    try
    {
        svm_parameter param = model.param;

        fp.writeBytes("svm_type "+svm_type_table[param.svm_type]+"\n");
        fp.writeBytes("kernel_type "+kernel_type_table[param.kernel_type]+"\n");

        // Only the parameters the selected kernel uses are written.
        if(param.kernel_type == svm_parameter.POLY)
            fp.writeBytes("degree "+param.degree+"\n");

        if(param.kernel_type == svm_parameter.POLY ||
           param.kernel_type == svm_parameter.RBF ||
           param.kernel_type == svm_parameter.SIGMOID)
            fp.writeBytes("gamma "+param.gamma+"\n");

        if(param.kernel_type == svm_parameter.POLY ||
           param.kernel_type == svm_parameter.SIGMOID)
            fp.writeBytes("coef0 "+param.coef0+"\n");

        int nr_class = model.nr_class;
        int l = model.l;
        fp.writeBytes("nr_class "+nr_class+"\n");
        fp.writeBytes("total_sv "+l+"\n");

        // One rho per pairwise decision function.
        fp.writeBytes("rho");
        for(int i=0;i<nr_class*(nr_class-1)/2;i++)
            fp.writeBytes(" "+model.rho[i]);
        fp.writeBytes("\n");

        if(model.label != null)
        {
            fp.writeBytes("label");
            for(int i=0;i<nr_class;i++)
                fp.writeBytes(" "+model.label[i]);
            fp.writeBytes("\n");
        }

        if(model.nSV != null)
        {
            fp.writeBytes("nr_sv");
            for(int i=0;i<nr_class;i++)
                fp.writeBytes(" "+model.nSV[i]);
            fp.writeBytes("\n");
        }

        fp.writeBytes("SV\n");
        double[][] sv_coef = model.sv_coef;
        svm_node[][] SV = model.SV;

        for(int i=0;i<l;i++)
        {
            // Coefficients first, then the sparse vector itself.
            for(int j=0;j<nr_class-1;j++)
                fp.writeBytes(sv_coef[j][i]+" ");

            svm_node[] p = SV[i];
            for(int j=0;j<p.length;j++)
                fp.writeBytes(p[j].index+":"+p[j].value+" ");
            fp.writeBytes("\n");
        }
    }
    finally
    {
        // Flushes the buffer and releases the file handle on every path.
        fp.close();
    }
}
/**
 * Parses a decimal floating-point string.
 *
 * @param s text to parse
 * @return the parsed value
 * @throws NumberFormatException if s is not a parsable number
 */
private static double atof(String s)
{
    return Double.parseDouble(s);
}
/**
 * Parses a base-10 integer string.
 *
 * @param s text to parse
 * @return the parsed value
 * @throws NumberFormatException if s is not a parsable integer
 */
private static int atoi(String s)
{
    final int decimal = 10;
    return Integer.parseInt(s, decimal);
}
/**
 * Reads a model from a text file in the layout produced by svm_save_model.
 *
 * Header lines are read until a line starting with "SV"; after that,
 * total_sv lines are read, each holding nr_class-1 coefficients followed
 * by index:value pairs.
 *
 * A file that ends before the "SV" line, or before all support-vector
 * lines have been read, is reported as an EOFException. The reader is
 * closed on every return or exception path. An unknown svm type, kernel
 * type or header keyword still prints a message and terminates the JVM
 * via System.exit(1), as before.
 *
 * @param model_file_name path of the model file
 * @return the model described by the file
 * @throws IOException if the file cannot be read or is truncated
 */
public static svm_model svm_load_model(String model_file_name) throws IOException
{
    BufferedReader fp = new BufferedReader(new FileReader(model_file_name));
    try
    {
        // read parameters

        svm_model model = new svm_model();
        svm_parameter param = new svm_parameter();
        model.param = param;
        model.label = null;
        model.nSV = null;

        while(true)
        {
            String cmd = fp.readLine();
            // readLine() returns null at end of stream; without this check
            // a truncated header fails with a NullPointerException.
            if(cmd == null)
                throw new java.io.EOFException(
                    "unexpected end of file in model header: "+model_file_name);

            // Everything after the first space is the argument; a line
            // without a space yields the whole line.
            String arg = cmd.substring(cmd.indexOf(' ')+1);

            if(cmd.startsWith("svm_type"))
            {
                int i;
                for(i=0;i<svm_type_table.length;i++)
                {
                    if(arg.indexOf(svm_type_table[i])!=-1)
                    {
                        param.svm_type=i;
                        break;
                    }
                }
                if(i == svm_type_table.length)
                {
                    System.err.print("unknown svm type.\n");
                    System.exit(1);
                }
            }
            else if(cmd.startsWith("kernel_type"))
            {
                int i;
                for(i=0;i<kernel_type_table.length;i++)
                {
                    if(arg.indexOf(kernel_type_table[i])!=-1)
                    {
                        param.kernel_type=i;
                        break;
                    }
                }
                if(i == kernel_type_table.length)
                {
                    System.err.print("unknown kernel function.\n");
                    System.exit(1);
                }
            }
            else if(cmd.startsWith("degree"))
                param.degree = atof(arg);
            else if(cmd.startsWith("gamma"))
                param.gamma = atof(arg);
            else if(cmd.startsWith("coef0"))
                param.coef0 = atof(arg);
            else if(cmd.startsWith("nr_class"))
                model.nr_class = atoi(arg);
            else if(cmd.startsWith("total_sv"))
                model.l = atoi(arg);
            else if(cmd.startsWith("rho"))
            {
                // Sized from nr_class, so nr_class must already be read.
                int n = model.nr_class * (model.nr_class-1)/2;
                model.rho = new double[n];
                StringTokenizer st = new StringTokenizer(arg);
                for(int i=0;i<n;i++)
                    model.rho[i] = atof(st.nextToken());
            }
            else if(cmd.startsWith("label"))
            {
                int n = model.nr_class;
                model.label = new int[n];
                StringTokenizer st = new StringTokenizer(arg);
                for(int i=0;i<n;i++)
                    model.label[i] = atoi(st.nextToken());
            }
            else if(cmd.startsWith("nr_sv"))
            {
                int n = model.nr_class;
                model.nSV = new int[n];
                StringTokenizer st = new StringTokenizer(arg);
                for(int i=0;i<n;i++)
                    model.nSV[i] = atoi(st.nextToken());
            }
            else if(cmd.startsWith("SV"))
            {
                break;
            }
            else
            {
                System.err.print("unknown text in model file\n");
                System.exit(1);
            }
        }

        // read sv_coef and SV

        int m = model.nr_class - 1;
        int l = model.l;
        model.sv_coef = new double[m][l];
        model.SV = new svm_node[l][];

        for(int i=0;i<l;i++)
        {
            String line = fp.readLine();
            if(line == null)
                throw new java.io.EOFException(
                    "model file ends after "+i+" of "+l+" support vectors: "+model_file_name);

            // ':' is a delimiter too, so index and value arrive as
            // consecutive tokens.
            StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");

            for(int k=0;k<m;k++)
                model.sv_coef[k][i] = atof(st.nextToken());

            int n = st.countTokens()/2;
            model.SV[i] = new svm_node[n];
            for(int j=0;j<n;j++)
            {
                model.SV[i][j] = new svm_node();
                model.SV[i][j].index = atoi(st.nextToken());
                model.SV[i][j].value = atof(st.nextToken());
            }
        }

        return model;
    }
    finally
    {
        // Release the file handle on success and on every failure path.
        fp.close();
    }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -