📄 svm.java
字号:
for(j=0;j<begin;j++) { subprob.x[k] = prob.x[perm[j]]; subprob.y[k] = prob.y[perm[j]]; ++k; } for(j=end;j<l;j++) { subprob.x[k] = prob.x[perm[j]]; subprob.y[k] = prob.y[perm[j]]; ++k; } svm_model submodel = svm_train(subprob,param); if(param.probability==1 && (param.svm_type == svm_parameter.C_SVC || param.svm_type == svm_parameter.NU_SVC)) { double[] prob_estimates= new double[svm_get_nr_class(submodel)]; for(j=begin;j<end;j++) target[perm[j]] = svm_predict_probability(submodel,prob.x[perm[j]],prob_estimates); } else for(j=begin;j<end;j++) target[perm[j]] = svm_predict(submodel,prob.x[perm[j]]); } } public static int svm_get_svm_type(svm_model model) { return model.param.svm_type; } public static int svm_get_nr_class(svm_model model) { return model.nr_class; } public static void svm_get_labels(svm_model model, int[] label) { if (model.label != null) for(int i=0;i<model.nr_class;i++) label[i] = model.label[i]; } public static double svm_get_svr_probability(svm_model model) { if ((model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) && model.probA!=null) return model.probA[0]; else { System.err.print("Model doesn't contain information for SVR probability inference\n"); return 0; } } public static void svm_predict_values(svm_model model, svm_node[] x, double[] dec_values) { if(model.param.svm_type == svm_parameter.ONE_CLASS || model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) { double[] sv_coef = model.sv_coef[0]; double sum = 0; for(int i=0;i<model.l;i++) sum += sv_coef[i] * Kernel.k_function(x,model.SV[i],model.param); sum -= model.rho[0]; dec_values[0] = sum; } else { int i; int nr_class = model.nr_class; int l = model.l; double[] kvalue = new double[l]; for(i=0;i<l;i++) kvalue[i] = Kernel.k_function(x,model.SV[i],model.param); int[] start = new int[nr_class]; start[0] = 0; for(i=1;i<nr_class;i++) start[i] = start[i-1]+model.nSV[i-1]; int p=0; int pos=0; for(i=0;i<nr_class;i++) for(int j=i+1;j<nr_class;j++) { double sum = 0; int si = start[i]; int sj = start[j]; int ci = model.nSV[i]; int cj = model.nSV[j]; int k; double[] coef1 = model.sv_coef[j-1]; double[] coef2 = model.sv_coef[i]; for(k=0;k<ci;k++) sum += coef1[si+k] * kvalue[si+k]; for(k=0;k<cj;k++) sum += coef2[sj+k] * kvalue[sj+k]; sum -= model.rho[p++]; dec_values[pos++] = sum; } } } public static double svm_predict(svm_model model, svm_node[] x) { if(model.param.svm_type == svm_parameter.ONE_CLASS || model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) { double[] res = new double[1]; svm_predict_values(model, x, res); if(model.param.svm_type == svm_parameter.ONE_CLASS) return (res[0]>0)?1:-1; else return res[0]; } else { int i; int nr_class = model.nr_class; double[] dec_values = new double[nr_class*(nr_class-1)/2]; svm_predict_values(model, x, dec_values); int[] vote = new int[nr_class]; for(i=0;i<nr_class;i++) vote[i] = 0; int pos=0; for(i=0;i<nr_class;i++) for(int j=i+1;j<nr_class;j++) { if(dec_values[pos++] > 0) ++vote[i]; else ++vote[j]; } int vote_max_idx = 0; for(i=1;i<nr_class;i++) if(vote[i] > vote[vote_max_idx]) vote_max_idx = i; return model.label[vote_max_idx]; } } public static double svm_predict_probability(svm_model model, svm_node[] x, double[] prob_estimates) { if ((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) && model.probA!=null && model.probB!=null) { int i; int nr_class = model.nr_class; double[] dec_values = new double[nr_class*(nr_class-1)/2]; svm_predict_values(model, x, dec_values); double min_prob=1e-7; double[][] pairwise_prob=new double[nr_class][nr_class]; int k=0; for(i=0;i<nr_class;i++) for(int j=i+1;j<nr_class;j++) { pairwise_prob[i][j]=Math.min(Math.max(sigmoid_predict(dec_values[k],model.probA[k],model.probB[k]),min_prob),1-min_prob); pairwise_prob[j][i]=1-pairwise_prob[i][j]; k++; } multiclass_probability(nr_class,pairwise_prob,prob_estimates); int prob_max_idx = 0; for(i=1;i<nr_class;i++) if(prob_estimates[i] > prob_estimates[prob_max_idx]) prob_max_idx = i; return model.label[prob_max_idx]; } else return svm_predict(model, x); } static final String svm_type_table[] = { "c_svc","nu_svc","one_class","epsilon_svr","nu_svr", }; static final String kernel_type_table[]= { "linear","polynomial","rbf","sigmoid","precomputed" }; public static void svm_save_model(String model_file_name, svm_model model) throws IOException { DataOutputStream fp = new DataOutputStream(new FileOutputStream(model_file_name)); svm_parameter param = model.param; fp.writeBytes("svm_type "+svm_type_table[param.svm_type]+"\n"); fp.writeBytes("kernel_type "+kernel_type_table[param.kernel_type]+"\n"); if(param.kernel_type == svm_parameter.POLY) fp.writeBytes("degree "+param.degree+"\n"); if(param.kernel_type == svm_parameter.POLY || param.kernel_type == svm_parameter.RBF || param.kernel_type == svm_parameter.SIGMOID) fp.writeBytes("gamma "+param.gamma+"\n"); if(param.kernel_type == svm_parameter.POLY || param.kernel_type == svm_parameter.SIGMOID) fp.writeBytes("coef0 "+param.coef0+"\n"); int nr_class = model.nr_class; int l = model.l; fp.writeBytes("nr_class "+nr_class+"\n"); fp.writeBytes("total_sv "+l+"\n"); { fp.writeBytes("rho"); for(int i=0;i<nr_class*(nr_class-1)/2;i++) fp.writeBytes(" "+model.rho[i]); fp.writeBytes("\n"); } if(model.label != null) { fp.writeBytes("label"); for(int i=0;i<nr_class;i++) fp.writeBytes(" "+model.label[i]); fp.writeBytes("\n"); } if(model.probA != null) // regression has probA only { fp.writeBytes("probA"); for(int i=0;i<nr_class*(nr_class-1)/2;i++) fp.writeBytes(" "+model.probA[i]); fp.writeBytes("\n"); } if(model.probB != null) { fp.writeBytes("probB"); for(int i=0;i<nr_class*(nr_class-1)/2;i++) fp.writeBytes(" "+model.probB[i]); fp.writeBytes("\n"); } if(model.nSV != null) { fp.writeBytes("nr_sv"); for(int i=0;i<nr_class;i++) fp.writeBytes(" "+model.nSV[i]); fp.writeBytes("\n"); } fp.writeBytes("SV\n"); double[][] sv_coef = model.sv_coef; svm_node[][] SV = model.SV; for(int i=0;i<l;i++) { for(int j=0;j<nr_class-1;j++) fp.writeBytes(sv_coef[j][i]+" "); svm_node[] p = SV[i]; if(param.kernel_type == svm_parameter.PRECOMPUTED) fp.writeBytes("0:"+(int)(p[0].value)); else for(int j=0;j<p.length;j++) fp.writeBytes(p[j].index+":"+p[j].value+" "); fp.writeBytes("\n"); } fp.close(); } private static double atof(String s) { return Double.valueOf(s).doubleValue(); } private static int atoi(String s) { return Integer.parseInt(s); } public static svm_model svm_load_model(String model_file_name) throws IOException { BufferedReader fp = new BufferedReader(new FileReader(model_file_name)); // read parameters svm_model model = new svm_model(); svm_parameter param = new svm_parameter(); model.param = param; model.rho = null; model.probA = null; model.probB = null; model.label = null; model.nSV = null; while(true) { String cmd = fp.readLine(); String arg = cmd.substring(cmd.indexOf(' ')+1); if(cmd.startsWith("svm_type")) { int i; for(i=0;i<svm_type_table.length;i++) { if(arg.indexOf(svm_type_table[i])!=-1) { param.svm_type=i; break; } } if(i == svm_type_table.length) { System.err.print("unknown svm type.\n"); return null; } } else if(cmd.startsWith("kernel_type")) { int i; for(i=0;i<kernel_type_table.length;i++) { if(arg.indexOf(kernel_type_table[i])!=-1) { param.kernel_type=i; break; } } if(i == kernel_type_table.length) { System.err.print("unknown kernel function.\n"); return null; } } else if(cmd.startsWith("degree")) param.degree = atoi(arg); else if(cmd.startsWith("gamma")) param.gamma = atof(arg); else if(cmd.startsWith("coef0")) param.coef0 = atof(arg); else if(cmd.startsWith("nr_class")) model.nr_class = atoi(arg); else if(cmd.startsWith("total_sv")) model.l = atoi(arg); else if(cmd.startsWith("rho")) { int n = model.nr_class * (model.nr_class-1)/2; model.rho = new double[n]; StringTokenizer st = new StringTokenizer(arg); for(int i=0;i<n;i++) model.rho[i] = atof(st.nextToken()); } else if(cmd.startsWith("label")) { int n = model.nr_class; model.label = new int[n]; StringTokenizer st = new StringTokenizer(arg); for(int i=0;i<n;i++) model.label[i] = atoi(st.nextToken()); } else if(cmd.startsWith("probA")) { int n = model.nr_class*(model.nr_class-1)/2; model.probA = new double[n]; StringTokenizer st = new StringTokenizer(arg); for(int i=0;i<n;i++) model.probA[i] = atof(st.nextToken()); } else if(cmd.startsWith("probB")) { int n = model.nr_class*(model.nr_class-1)/2; model.probB = new double[n]; StringTokenizer st = new StringTokenizer(arg); for(int i=0;i<n;i++) model.probB[i] = atof(st.nextToken()); } else if(cmd.startsWith("nr_sv")) { int n = model.nr_class; model.nSV = new int[n]; StringTokenizer st = new StringTokenizer(arg); for(int i=0;i<n;i++) model.nSV[i] = atoi(st.nextToken()); } else if(cmd.startsWith("SV")) { break; } else { System.err.print("unknown text in model file\n"); return null; } } // read sv_coef and SV int m = model.nr_class - 1; int l = model.l; model.sv_coef = new double[m][l]; model.SV = new svm_node[l][]; for(int i=0;i<l;i++) { String line = fp.readLine(); StringTokenizer st = new StringTokenizer(line," \t\n\r\f:"); for(int k=0;k<m;k++) model.sv_coef[k][i] = atof(st.nextToken()); int n = st.countTokens()/2; model.SV[i] = new svm_node[n]; for(int j=0;j<n;j++) { model.SV[i][j] = new svm_node(); model.SV[i][j].index = atoi(st.nextToken()); model.SV[i][j].value = atof(st.nextToken()); } } fp.close(); return model; } public static String svm_check_parameter(svm_problem prob, svm_parameter param) { // svm_type int svm_type = param.svm_type; if(svm_type != svm_parameter.C_SVC && svm_type != svm_parameter.NU_SVC && svm_type != svm_parameter.ONE_CLASS && svm_type != svm_parameter.EPSILON_SVR && svm_type != svm_parameter.NU_SVR) return "unknown svm type"; // kernel_type, degree int kernel_type = param.kernel_type; if(kernel_type != svm_parameter.LINEAR && kernel_type != svm_parameter.POLY && kernel_type != svm_parameter.RBF && kernel_type != svm_parameter.SIGMOID && kernel_type != svm_parameter.PRECOMPUTED) return "unknown kernel type"; if(param.degree < 0) return "degree of polynomial kernel < 0"; // cache_size,eps,C,nu,p,shrinking if(param.cache_size <= 0) return "cache_size <= 0"; if(param.eps <= 0) return "eps <= 0"; if(svm_type == svm_parameter.C_SVC || svm_type == svm_parameter.EPSILON_SVR || svm_type == svm_parameter.NU_SVR) if(param.C <= 0) return "C <= 0"; if(svm_type == svm_parameter.NU_SVC || svm_type == svm_parameter.ONE_CLASS || svm_type == svm_parameter.NU_SVR) if(param.nu <= 0 || param.nu > 1) return "nu <= 0 or nu > 1"; if(svm_type == svm_parameter.EPSILON_SVR) if(param.p < 0) return "p < 0"; if(param.shrinking != 0 && param.shrinking != 1) return "shrinking != 0 and shrinking != 1"; if(param.probability != 0 && param.probability != 1) return "probability != 0 and probability != 1"; if(param.probability == 1 && svm_type == svm_parameter.ONE_CLASS) return "one-class SVM probability output not supported yet"; // check whether nu-svc is feasible if(svm_type == svm_parameter.NU_SVC) { int l = prob.l; int max_nr_class = 16; int nr_class = 0; int[] label = new int[max_nr_class]; int[] count = new int[max_nr_class]; int i; for(i=0;i<l;i++) { int this_label = (int)prob.y[i]; int j; for(j=0;j<nr_class;j++) if(this_label == label[j]) { ++count[j]; break; } if(j == nr_class) { if(nr_class == max_nr_class) { max_nr_class *= 2; int[] new_data = new int[max_nr_class]; System.arraycopy(label,0,new_data,0,label.length); label = new_data; new_data = new int[max_nr_class]; System.arraycopy(count,0,new_data,0,count.length); count = new_data; } label[nr_class] = this_label; count[nr_class] = 1; ++nr_class; } } for(i=0;i<nr_class;i++) { int n1 = count[i]; for(int j=i+1;j<nr_class;j++) { int n2 = count[j]; if(param.nu*(n1+n2)/2 > Math.min(n1,n2)) return "specified nu is infeasible"; } } } return null; } public static int svm_check_probability_model(svm_model model) { if (((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) && model.probA!=null && model.probB!=null) || ((model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) && model.probA!=null)) return 1; else return 0; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -