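// svm.java
// (The page header of the web code viewer this file was copied from also read "Font size:" — viewer residue, not source.)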
int j,k;
svm_problem subprob = new svm_problem();
subprob.l = l-(end-begin);
subprob.x = new svm_node[subprob.l][];
subprob.y = new double[subprob.l];
k=0;
for(j=0;j<begin;j++)
{
subprob.x[k] = prob.x[perm[j]];
subprob.y[k] = prob.y[perm[j]];
++k;
}
for(j=end;j<l;j++)
{
subprob.x[k] = prob.x[perm[j]];
subprob.y[k] = prob.y[perm[j]];
++k;
}
svm_model submodel = svm_train(subprob,param);
if(param.probability==1 &&
(param.svm_type == svm_parameter.C_SVC ||
param.svm_type == svm_parameter.NU_SVC))
{
double[] prob_estimates= new double[svm_get_nr_class(submodel)];
for(j=begin;j<end;j++)
target[perm[j]] = svm_predict_probability(submodel,prob.x[perm[j]],prob_estimates);
}
else
for(j=begin;j<end;j++)
target[perm[j]] = svm_predict(submodel,prob.x[perm[j]]);
}
}
/**
 * Returns the formulation the model was trained with.
 *
 * @param model trained or loaded model
 * @return the svm_type stored in the model's parameter set (an svm_parameter constant)
 */
public static int svm_get_svm_type(svm_model model)
{
	final svm_parameter params = model.param;
	return params.svm_type;
}
/**
 * Returns the number of classes recorded in the model.
 *
 * @param model trained or loaded model
 * @return model.nr_class
 */
public static int svm_get_nr_class(svm_model model)
{
	final int classCount = model.nr_class;
	return classCount;
}
/**
 * Copies the model's class labels into the caller's array.
 *
 * <p>If the model has no label array, {@code label} is left untouched. svm_load_model only
 * sets the label array when the model file contains a "label" line.
 *
 * @param model trained or loaded model
 * @param label destination; must hold at least model.nr_class entries
 */
public static void svm_get_labels(svm_model model, int[] label)
{
	final int[] source = model.label;
	if (source == null)
		return;
	for (int c = 0; c < model.nr_class; c++)
		label[c] = source[c];
}
/**
 * Returns the SVR probability parameter (probA[0]) of a regression model.
 *
 * <p>For non-regression models, or regression models without probA, a message is printed to
 * stderr and 0 is returned.
 *
 * @param model trained or loaded model
 * @return probA[0], or 0 when unavailable
 */
public static double svm_get_svr_probability(svm_model model)
{
	final int type = model.param.svm_type;
	final boolean isRegression = type == svm_parameter.EPSILON_SVR || type == svm_parameter.NU_SVR;
	if (!isRegression || model.probA == null)
	{
		System.err.print("Model doesn't contain information for SVR probability inference\n");
		return 0;
	}
	return model.probA[0];
}
/**
 * Computes the one-vs-one decision values of a multi-class model for one instance.
 *
 * <p>Class pairs (a, b) with a &lt; b are visited in row-major order. For pair number k:
 * dec_values[k] = sum over class a's SVs of sv_coef[b-1][s]*K(x,SV[s])
 *               + sum over class b's SVs of sv_coef[a][s]*K(x,SV[s]) - rho[k].
 * Class c's support vectors are taken to be SV[start[c]] .. SV[start[c]+nSV[c]-1].
 *
 * @param model      model with nSV, sv_coef, rho and SV populated
 * @param x          instance as sparse nodes
 * @param dec_values output; must hold nr_class*(nr_class-1)/2 entries
 * @param start      start[c] = index in model.SV of class c's first support vector
 *                   (presumably the running sum of model.nSV — supplied by the caller)
 */
public static void svm_predict_values(svm_model model, svm_node[] x, double[] dec_values,int start[])
{
	final int nr_class = model.nr_class;
	final int total = model.l;
	final int xlen = x.length;

	// Kernel value against every support vector, computed once and shared by all class pairs.
	final double[] kvalue = new double[total];
	for (int s = 0; s < total; s++)
		kvalue[s] = Kernel.k_function(x, model.SV[s], model.param, xlen);

	int pair = 0; // shared index into rho and dec_values: one entry per class pair
	for (int a = 0; a < nr_class; a++)
	{
		for (int b = a + 1; b < nr_class; b++)
		{
			final int offA = start[a];
			final int offB = start[b];
			final int countA = model.nSV[a];
			final int countB = model.nSV[b];
			final double[] coefA = model.sv_coef[b - 1];
			final double[] coefB = model.sv_coef[a];

			double acc = 0;
			for (int k = 0; k < countA; k++)
				acc += coefA[offA + k] * kvalue[offA + k];
			for (int k = 0; k < countB; k++)
				acc += coefB[offB + k] * kvalue[offB + k];

			acc -= model.rho[pair];
			dec_values[pair] = acc;
			pair++;
		}
	}
}
/**
 * Predicts the output of the model for one instance.
 *
 * <ul>
 * <li>ONE_CLASS: +1 if the decision value is positive, otherwise -1.</li>
 * <li>EPSILON_SVR / NU_SVR: the decision value itself.</li>
 * <li>C_SVC / NU_SVC: the label winning the most one-vs-one votes (ties go to the lower class
 *     index, because only a strictly larger count replaces the current best).</li>
 * </ul>
 *
 * @param model trained or loaded model
 * @param x     instance as sparse nodes
 * @return predicted label, regression value, or +1/-1
 */
public static double svm_predict(svm_model model, svm_node[] x)
{
	int svm_type = model.param.svm_type;
	if(svm_type == svm_parameter.ONE_CLASS ||
	   svm_type == svm_parameter.EPSILON_SVR ||
	   svm_type == svm_parameter.NU_SVR)
	{
		// Single decision function: sum_i sv_coef[0][i] * K(x, SV[i]) - rho[0].
		// Computed inline instead of via svm_predict_values, which walks per-class nSV blocks;
		// model.nSV may be null for these types (the model file's nr_sv line is optional).
		double[] coef = model.sv_coef[0];
		int xlen = x.length;
		double sum = 0;
		for(int i=0;i<model.l;i++)
			sum += coef[i] * Kernel.k_function(x, model.SV[i], model.param, xlen);
		sum -= model.rho[0];
		if(svm_type == svm_parameter.ONE_CLASS)
			return (sum>0)?1:-1;
		return sum;
	}

	int nr_class = model.nr_class;
	// start[c] = index of class c's first support vector. svm_predict_values reads
	// SV[start[c] .. start[c]+nSV[c]-1] for class c, so this is the running sum of nSV.
	int[] start = new int[nr_class];
	for(int i=1;i<nr_class;i++)
		start[i] = start[i-1]+model.nSV[i-1];

	double[] dec_values = new double[nr_class*(nr_class-1)/2];
	svm_predict_values(model, x, dec_values, start);

	// One vote per class pair, for i on a positive decision value and for j otherwise.
	int[] vote = new int[nr_class]; // zero-initialized by Java
	int pos=0;
	for(int i=0;i<nr_class;i++)
		for(int j=i+1;j<nr_class;j++)
		{
			if(dec_values[pos++] > 0)
				++vote[i];
			else
				++vote[j];
		}

	int vote_max_idx = 0;
	for(int i=1;i<nr_class;i++)
		if(vote[i] > vote[vote_max_idx])
			vote_max_idx = i;
	return model.label[vote_max_idx];
}
/**
 * Predicts a class label and fills per-class probability estimates.
 *
 * <p>Applies only to C_SVC / NU_SVC models that carry both probA and probB. Each pairwise
 * decision value is passed through sigmoid_predict (defined outside this view) with that pair's
 * probA/probB. The result is clamped to [1e-7, 1-1e-7], and the pairwise matrix is combined by
 * multiclass_probability into prob_estimates. For any other model this simply returns
 * svm_predict(model, x) and leaves prob_estimates untouched.
 *
 * @param model          trained or loaded model
 * @param x              instance as sparse nodes
 * @param prob_estimates output; must hold nr_class entries (class order as in model.label)
 * @return the label with the highest estimated probability (ties go to the lower index)
 */
public static double svm_predict_probability(svm_model model, svm_node[] x, double[] prob_estimates)
{
	if ((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) &&
	    model.probA!=null && model.probB!=null)
	{
		int i;
		int nr_class = model.nr_class;

		// start[c] = index of class c's first support vector. svm_predict_values reads
		// SV[start[c] .. start[c]+nSV[c]-1] for class c, so this is the running sum of nSV.
		int[] start = new int[nr_class];
		for(i=1;i<nr_class;i++)
			start[i] = start[i-1]+model.nSV[i-1];

		double[] dec_values = new double[nr_class*(nr_class-1)/2];
		svm_predict_values(model, x, dec_values, start);

		// Keep pairwise probabilities away from exactly 0/1.
		double min_prob=1e-7;
		double[][] pairwise_prob=new double[nr_class][nr_class];
		int k=0;
		for(i=0;i<nr_class;i++)
			for(int j=i+1;j<nr_class;j++)
			{
				pairwise_prob[i][j]=Math.min(Math.max(sigmoid_predict(dec_values[k],model.probA[k],model.probB[k]),min_prob),1-min_prob);
				pairwise_prob[j][i]=1-pairwise_prob[i][j];
				k++;
			}
		multiclass_probability(nr_class,pairwise_prob,prob_estimates);

		int prob_max_idx = 0;
		for(i=1;i<nr_class;i++)
			if(prob_estimates[i] > prob_estimates[prob_max_idx])
				prob_max_idx = i;
		return model.label[prob_max_idx];
	}
	else
		return svm_predict(model, x);
}
static final String svm_type_table[] =
{
"c_svc","nu_svc","one_class","epsilon_svr","nu_svr",
};
static final String kernel_type_table[]=
{
"linear","polynomial","rbf","sigmoid",
};
/**
 * Writes a model to a text file in the format read back by svm_load_model.
 *
 * <p>The file starts with these header lines:
 * <ul>
 * <li>svm_type and kernel_type;</li>
 * <li>degree / gamma / coef0, as relevant to the kernel;</li>
 * <li>nr_class, total_sv and rho;</li>
 * <li>the label, probA, probB and nr_sv lines, each only when the array is non-null.</li>
 * </ul>
 * These are followed by "SV" and one line per support vector: its nr_class-1 coefficients,
 * then index:value pairs.
 *
 * @param model_file_name path of the file to create (an existing file is overwritten)
 * @param model           model to serialize
 * @throws IOException if the file cannot be opened or written; the stream is closed either way
 */
public static void svm_save_model(String model_file_name, svm_model model) throws IOException
{
	// Buffered: DataOutputStream.writeBytes writes one byte at a time to the wrapped stream,
	// which on a bare FileOutputStream is one native write per character.
	DataOutputStream fp = new DataOutputStream(
		new java.io.BufferedOutputStream(new FileOutputStream(model_file_name)));
	try
	{
		svm_parameter param = model.param;
		fp.writeBytes("svm_type "+svm_type_table[param.svm_type]+"\n");
		fp.writeBytes("kernel_type "+kernel_type_table[param.kernel_type]+"\n");
		if(param.kernel_type == svm_parameter.POLY)
			fp.writeBytes("degree "+param.degree+"\n");
		if(param.kernel_type == svm_parameter.POLY ||
		   param.kernel_type == svm_parameter.RBF ||
		   param.kernel_type == svm_parameter.SIGMOID)
			fp.writeBytes("gamma "+param.gamma+"\n");
		if(param.kernel_type == svm_parameter.POLY ||
		   param.kernel_type == svm_parameter.SIGMOID)
			fp.writeBytes("coef0 "+param.coef0+"\n");

		int nr_class = model.nr_class;
		int l = model.l;
		int nr_pairs = nr_class*(nr_class-1)/2;
		fp.writeBytes("nr_class "+nr_class+"\n");
		fp.writeBytes("total_sv "+l+"\n");

		fp.writeBytes("rho");
		for(int i=0;i<nr_pairs;i++)
			fp.writeBytes(" "+model.rho[i]);
		fp.writeBytes("\n");

		if(model.label != null)
		{
			fp.writeBytes("label");
			for(int i=0;i<nr_class;i++)
				fp.writeBytes(" "+model.label[i]);
			fp.writeBytes("\n");
		}

		if(model.probA != null) // regression has probA only
		{
			fp.writeBytes("probA");
			for(int i=0;i<nr_pairs;i++)
				fp.writeBytes(" "+model.probA[i]);
			fp.writeBytes("\n");
		}
		if(model.probB != null)
		{
			fp.writeBytes("probB");
			for(int i=0;i<nr_pairs;i++)
				fp.writeBytes(" "+model.probB[i]);
			fp.writeBytes("\n");
		}

		if(model.nSV != null)
		{
			fp.writeBytes("nr_sv");
			for(int i=0;i<nr_class;i++)
				fp.writeBytes(" "+model.nSV[i]);
			fp.writeBytes("\n");
		}

		fp.writeBytes("SV\n");
		double[][] sv_coef = model.sv_coef;
		svm_node[][] SV = model.SV;
		for(int i=0;i<l;i++)
		{
			for(int j=0;j<nr_class-1;j++)
				fp.writeBytes(sv_coef[j][i]+" ");
			svm_node[] p = SV[i];
			for(int j=0;j<p.length;j++)
				fp.writeBytes(p[j].index+":"+p[j].value+" ");
			fp.writeBytes("\n");
		}
	}
	finally
	{
		// Flushes the buffer and releases the file handle even when a write above fails.
		fp.close();
	}
}
/**
 * Parses a decimal floating-point token from a model file.
 *
 * @throws NumberFormatException if {@code s} is not a valid double
 */
private static double atof(String s)
{
	// Same parsing as Double.valueOf(s).doubleValue(), without the intermediate box.
	return Double.parseDouble(s);
}
/**
 * Parses a base-10 integer token from a model file.
 *
 * @throws NumberFormatException if {@code s} is not a valid int
 */
private static int atoi(String s)
{
	final int radix = 10;
	return Integer.parseInt(s, radix);
}
/**
 * Reads a model file in the format produced by svm_save_model.
 *
 * <p>Header lines are read until the "SV" line. Then total_sv lines follow, each holding
 * nr_class-1 coefficients and index:value pairs. A malformed number, or a list with too few
 * entries, surfaces as an unchecked exception from atof/atoi or StringTokenizer.nextToken.
 *
 * @param model_file_name path of the model file
 * @return the model, or null after printing a message to stderr when any of these occur:
 *         an unknown svm_type or kernel_type; an unrecognized header line; or the file
 *         ending before all support vectors are read
 * @throws IOException if the file cannot be opened or read; the reader is closed on every path
 */
public static svm_model svm_load_model(String model_file_name) throws IOException
{
	BufferedReader fp = new BufferedReader(new FileReader(model_file_name));
	try
	{
		// read parameters
		svm_model model = new svm_model();
		svm_parameter param = new svm_parameter();
		model.param = param;
		model.rho = null;
		model.probA = null;
		model.probB = null;
		model.label = null;
		model.nSV = null;

		while(true)
		{
			String cmd = fp.readLine();
			if(cmd == null)
			{
				// EOF before the "SV" marker.
				System.err.print("unexpected end of model file\n");
				return null;
			}
			String arg = cmd.substring(cmd.indexOf(' ')+1);

			if(cmd.startsWith("svm_type"))
			{
				// Substring match, so trailing whitespace after the name is tolerated.
				int i;
				for(i=0;i<svm_type_table.length;i++)
				{
					if(arg.indexOf(svm_type_table[i])!=-1)
					{
						param.svm_type=i;
						break;
					}
				}
				if(i == svm_type_table.length)
				{
					System.err.print("unknown svm type.\n");
					return null;
				}
			}
			else if(cmd.startsWith("kernel_type"))
			{
				int i;
				for(i=0;i<kernel_type_table.length;i++)
				{
					if(arg.indexOf(kernel_type_table[i])!=-1)
					{
						param.kernel_type=i;
						break;
					}
				}
				if(i == kernel_type_table.length)
				{
					System.err.print("unknown kernel function.\n");
					return null;
				}
			}
			else if(cmd.startsWith("degree"))
				param.degree = atof(arg);
			else if(cmd.startsWith("gamma"))
				param.gamma = atof(arg);
			else if(cmd.startsWith("coef0"))
				param.coef0 = atof(arg);
			else if(cmd.startsWith("nr_class"))
				model.nr_class = atoi(arg);
			else if(cmd.startsWith("total_sv"))
				model.l = atoi(arg);
			else if(cmd.startsWith("rho"))
			{
				// One entry per class pair; relies on nr_class having been read earlier.
				int n = model.nr_class * (model.nr_class-1)/2;
				model.rho = new double[n];
				StringTokenizer st = new StringTokenizer(arg);
				for(int i=0;i<n;i++)
					model.rho[i] = atof(st.nextToken());
			}
			else if(cmd.startsWith("label"))
			{
				int n = model.nr_class;
				model.label = new int[n];
				StringTokenizer st = new StringTokenizer(arg);
				for(int i=0;i<n;i++)
					model.label[i] = atoi(st.nextToken());
			}
			else if(cmd.startsWith("probA"))
			{
				int n = model.nr_class*(model.nr_class-1)/2;
				model.probA = new double[n];
				StringTokenizer st = new StringTokenizer(arg);
				for(int i=0;i<n;i++)
					model.probA[i] = atof(st.nextToken());
			}
			else if(cmd.startsWith("probB"))
			{
				int n = model.nr_class*(model.nr_class-1)/2;
				model.probB = new double[n];
				StringTokenizer st = new StringTokenizer(arg);
				for(int i=0;i<n;i++)
					model.probB[i] = atof(st.nextToken());
			}
			else if(cmd.startsWith("nr_sv"))
			{
				int n = model.nr_class;
				model.nSV = new int[n];
				StringTokenizer st = new StringTokenizer(arg);
				for(int i=0;i<n;i++)
					model.nSV[i] = atoi(st.nextToken());
			}
			else if(cmd.startsWith("SV"))
			{
				break;
			}
			else
			{
				System.err.print("unknown text in model file\n");
				return null;
			}
		}

		// read sv_coef and SV
		int m = model.nr_class - 1;
		int l = model.l;
		model.sv_coef = new double[m][l];
		model.SV = new svm_node[l][];
		for(int i=0;i<l;i++)
		{
			String line = fp.readLine();
			if(line == null)
			{
				// Fewer support-vector lines than total_sv announced.
				System.err.print("unexpected end of model file\n");
				return null;
			}
			// ':' is a delimiter too, so "index:value" splits into two tokens.
			StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
			for(int k=0;k<m;k++)
				model.sv_coef[k][i] = atof(st.nextToken());
			int n = st.countTokens()/2;
			model.SV[i] = new svm_node[n];
			for(int j=0;j<n;j++)
			{
				model.SV[i][j] = new svm_node();
				model.SV[i][j].index = atoi(st.nextToken());
				model.SV[i][j].value = atof(st.nextToken());
			}
		}
		return model;
	}
	finally
	{
		// Release the file handle on success, on every early "return null", and on exceptions.
		fp.close();
	}
}
/**
 * Validates training parameters against the problem.
 *
 * <p>Checks are made in this order:
 * <ol>
 * <li>svm_type, then kernel_type;</li>
 * <li>cache_size, eps, C, nu and p (each only for the formulations that use it);</li>
 * <li>the shrinking and probability flags;</li>
 * <li>for NU_SVC, whether nu is feasible for every pair of classes in {@code prob}.</li>
 * </ol>
 *
 * @param prob  training data (only read for NU_SVC)
 * @param param parameters to validate
 * @return null if the parameters are acceptable, otherwise a description of the first problem
 */
public static String svm_check_parameter(svm_problem prob, svm_parameter param)
{
	final int svm_type = param.svm_type;
	final boolean isCSvc = svm_type == svm_parameter.C_SVC;
	final boolean isNuSvc = svm_type == svm_parameter.NU_SVC;
	final boolean isOneClass = svm_type == svm_parameter.ONE_CLASS;
	final boolean isEpsSvr = svm_type == svm_parameter.EPSILON_SVR;
	final boolean isNuSvr = svm_type == svm_parameter.NU_SVR;
	if (!(isCSvc || isNuSvc || isOneClass || isEpsSvr || isNuSvr))
		return "unknown svm type";

	final int kernel = param.kernel_type;
	final boolean knownKernel = kernel == svm_parameter.LINEAR
		|| kernel == svm_parameter.POLY
		|| kernel == svm_parameter.RBF
		|| kernel == svm_parameter.SIGMOID;
	if (!knownKernel)
		return "unknown kernel type";

	if (param.cache_size <= 0)
		return "cache_size <= 0";
	if (param.eps <= 0)
		return "eps <= 0";
	if ((isCSvc || isEpsSvr || isNuSvr) && param.C <= 0)
		return "C <= 0";
	if ((isNuSvc || isOneClass || isNuSvr) && (param.nu <= 0 || param.nu > 1))
		return "nu <= 0 or nu > 1";
	if (isEpsSvr && param.p < 0)
		return "p < 0";
	if (param.shrinking != 0 && param.shrinking != 1)
		return "shrinking != 0 and shrinking != 1";
	if (param.probability != 0 && param.probability != 1)
		return "probability != 0 and probability != 1";
	if (param.probability == 1 && isOneClass)
		return "one-class SVM probability output not supported yet";

	if (isNuSvc)
	{
		// Count examples per distinct (int-truncated) label, in order of first appearance.
		int capacity = 16;
		int distinct = 0;
		int[] labels = new int[capacity];
		int[] counts = new int[capacity];
		for (int idx = 0; idx < prob.l; idx++)
		{
			final int y = (int) prob.y[idx];
			int slot = 0;
			while (slot < distinct && labels[slot] != y)
				slot++;
			if (slot < distinct)
			{
				counts[slot]++;
				continue;
			}
			if (distinct == capacity)
			{
				capacity *= 2;
				final int[] grownLabels = new int[capacity];
				System.arraycopy(labels, 0, grownLabels, 0, labels.length);
				labels = grownLabels;
				final int[] grownCounts = new int[capacity];
				System.arraycopy(counts, 0, grownCounts, 0, counts.length);
				counts = grownCounts;
			}
			labels[distinct] = y;
			counts[distinct] = 1;
			distinct++;
		}

		// nu is infeasible if, for some class pair, nu*(n1+n2)/2 exceeds the smaller count.
		for (int a = 0; a < distinct; a++)
		{
			final int na = counts[a];
			for (int b = a + 1; b < distinct; b++)
			{
				final int nb = counts[b];
				if (param.nu * (na + nb) / 2 > Math.min(na, nb))
					return "specified nu is infeasible";
			}
		}
	}
	return null;
}
/**
 * Reports whether the model can produce probability estimates.
 *
 * @param model trained or loaded model
 * @return 1 for a C_SVC/NU_SVC model with both probA and probB, or an EPSILON_SVR/NU_SVR
 *         model with probA; otherwise 0
 */
public static int svm_check_probability_model(svm_model model)
{
	final int type = model.param.svm_type;
	final boolean classification = type == svm_parameter.C_SVC || type == svm_parameter.NU_SVC;
	final boolean regression = type == svm_parameter.EPSILON_SVR || type == svm_parameter.NU_SVR;
	if (classification && model.probA != null && model.probB != null)
		return 1;
	if (regression && model.probA != null)
		return 1;
	return 0;
}
}
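// Residue from the web code viewer this file was copied from (keyboard-shortcut help, translated):
//   Copy code:           Ctrl + C
//   Search code:         Ctrl + F
//   Full-screen mode:    F11
//   Toggle theme:        Ctrl + Shift + D
//   Show shortcuts:      ?
//   Increase font size:  Ctrl + =
//   Decrease font size:  Ctrl + -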