⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 svm.java

📁 SVM是一种常用的模式分类机器学习算法
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
			for(j=0;j<begin;j++)			{				subprob.x[k] = prob.x[perm[j]];				subprob.y[k] = prob.y[perm[j]];				++k;			}			for(j=end;j<l;j++)			{				subprob.x[k] = prob.x[perm[j]];				subprob.y[k] = prob.y[perm[j]];				++k;			}			svm_model submodel = svm_train(subprob,param);			if(param.probability==1 &&			   (param.svm_type == svm_parameter.C_SVC ||			    param.svm_type == svm_parameter.NU_SVC))			{				double[] prob_estimates= new double[svm_get_nr_class(submodel)];				for(j=begin;j<end;j++)					target[perm[j]] = svm_predict_probability(submodel,prob.x[perm[j]],prob_estimates);			}			else				for(j=begin;j<end;j++)					target[perm[j]] = svm_predict(submodel,prob.x[perm[j]]);		}	}	public static int svm_get_svm_type(svm_model model)	{		return model.param.svm_type;	}	public static int svm_get_nr_class(svm_model model)	{		return model.nr_class;	}	public static void svm_get_labels(svm_model model, int[] label)	{		if (model.label != null)			for(int i=0;i<model.nr_class;i++)				label[i] = model.label[i];	}	public static double svm_get_svr_probability(svm_model model)	{		if ((model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) &&		    model.probA!=null)		return model.probA[0];		else		{			System.err.print("Model doesn't contain information for SVR probability inference\n");			return 0;		}	}	public static void svm_predict_values(svm_model model, svm_node[] x, double[] dec_values)	{		if(model.param.svm_type == svm_parameter.ONE_CLASS ||		   model.param.svm_type == svm_parameter.EPSILON_SVR ||		   model.param.svm_type == svm_parameter.NU_SVR)		{			double[] sv_coef = model.sv_coef[0];			double sum = 0;			for(int i=0;i<model.l;i++)				sum += sv_coef[i] * Kernel.k_function(x,model.SV[i],model.param);			sum -= model.rho[0];			dec_values[0] = sum;		}		else		{			int i;			int nr_class = model.nr_class;			int l = model.l;					double[] kvalue = new double[l];			for(i=0;i<l;i++)				kvalue[i] = Kernel.k_function(x,model.SV[i],model.param);			int[] start = new int[nr_class];			start[0] = 0;			for(i=1;i<nr_class;i++)				start[i] = start[i-1]+model.nSV[i-1];			int p=0;			int pos=0;			for(i=0;i<nr_class;i++)				for(int j=i+1;j<nr_class;j++)				{					double sum = 0;					int si = start[i];					int sj = start[j];					int ci = model.nSV[i];					int cj = model.nSV[j];									int k;					double[] coef1 = model.sv_coef[j-1];					double[] coef2 = model.sv_coef[i];					for(k=0;k<ci;k++)						sum += coef1[si+k] * kvalue[si+k];					for(k=0;k<cj;k++)						sum += coef2[sj+k] * kvalue[sj+k];					sum -= model.rho[p++];					dec_values[pos++] = sum;									}		}	}	public static double svm_predict(svm_model model, svm_node[] x)	{		if(model.param.svm_type == svm_parameter.ONE_CLASS ||		   model.param.svm_type == svm_parameter.EPSILON_SVR ||		   model.param.svm_type == svm_parameter.NU_SVR)		{			double[] res = new double[1];			svm_predict_values(model, x, res);			if(model.param.svm_type == svm_parameter.ONE_CLASS)				return (res[0]>0)?1:-1;			else				return res[0];		}		else		{			int i;			int nr_class = model.nr_class;			double[] dec_values = new double[nr_class*(nr_class-1)/2];			svm_predict_values(model, x, dec_values);			int[] vote = new int[nr_class];			for(i=0;i<nr_class;i++)				vote[i] = 0;			int pos=0;			for(i=0;i<nr_class;i++)				for(int j=i+1;j<nr_class;j++)				{					if(dec_values[pos++] > 0)						++vote[i];					else						++vote[j];				}			int vote_max_idx = 0;			for(i=1;i<nr_class;i++)				if(vote[i] > vote[vote_max_idx])					vote_max_idx = i;			return model.label[vote_max_idx];		}	}	public static double svm_predict_probability(svm_model model, svm_node[] x, double[] prob_estimates)	{		if ((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) &&		    model.probA!=null && model.probB!=null)		{			int i;			int nr_class = model.nr_class;			double[] dec_values = new double[nr_class*(nr_class-1)/2];			svm_predict_values(model, x, dec_values);			double min_prob=1e-7;			double[][] pairwise_prob=new double[nr_class][nr_class];						int k=0;			for(i=0;i<nr_class;i++)				for(int j=i+1;j<nr_class;j++)				{					pairwise_prob[i][j]=Math.min(Math.max(sigmoid_predict(dec_values[k],model.probA[k],model.probB[k]),min_prob),1-min_prob);					pairwise_prob[j][i]=1-pairwise_prob[i][j];					k++;				}			multiclass_probability(nr_class,pairwise_prob,prob_estimates);			int prob_max_idx = 0;			for(i=1;i<nr_class;i++)				if(prob_estimates[i] > prob_estimates[prob_max_idx])					prob_max_idx = i;			return model.label[prob_max_idx];		}		else 			return svm_predict(model, x);	}	static final String svm_type_table[] =	{		"c_svc","nu_svc","one_class","epsilon_svr","nu_svr",	};	static final String kernel_type_table[]=	{		"linear","polynomial","rbf","sigmoid","precomputed"	};	public static void svm_save_model(String model_file_name, svm_model model) throws IOException	{		DataOutputStream fp = new DataOutputStream(new FileOutputStream(model_file_name));		svm_parameter param = model.param;		fp.writeBytes("svm_type "+svm_type_table[param.svm_type]+"\n");		fp.writeBytes("kernel_type "+kernel_type_table[param.kernel_type]+"\n");		if(param.kernel_type == svm_parameter.POLY)			fp.writeBytes("degree "+param.degree+"\n");		if(param.kernel_type == svm_parameter.POLY ||		   param.kernel_type == svm_parameter.RBF ||		   param.kernel_type == svm_parameter.SIGMOID)			fp.writeBytes("gamma "+param.gamma+"\n");		if(param.kernel_type == svm_parameter.POLY ||		   param.kernel_type == svm_parameter.SIGMOID)			fp.writeBytes("coef0 "+param.coef0+"\n");		int nr_class = model.nr_class;		int l = model.l;		fp.writeBytes("nr_class "+nr_class+"\n");		fp.writeBytes("total_sv "+l+"\n");			{			fp.writeBytes("rho");			for(int i=0;i<nr_class*(nr_class-1)/2;i++)				fp.writeBytes(" "+model.rho[i]);			fp.writeBytes("\n");		}			if(model.label != null)		{			fp.writeBytes("label");			for(int i=0;i<nr_class;i++)				fp.writeBytes(" "+model.label[i]);			fp.writeBytes("\n");		}		if(model.probA != null) // regression has probA only		{			fp.writeBytes("probA");			for(int i=0;i<nr_class*(nr_class-1)/2;i++)				fp.writeBytes(" "+model.probA[i]);			fp.writeBytes("\n");		}		if(model.probB != null) 		{			fp.writeBytes("probB");			for(int i=0;i<nr_class*(nr_class-1)/2;i++)				fp.writeBytes(" "+model.probB[i]);			fp.writeBytes("\n");		}		if(model.nSV != null)		{			fp.writeBytes("nr_sv");			for(int i=0;i<nr_class;i++)				fp.writeBytes(" "+model.nSV[i]);			fp.writeBytes("\n");		}		fp.writeBytes("SV\n");		double[][] sv_coef = model.sv_coef;		svm_node[][] SV = model.SV;		for(int i=0;i<l;i++)		{			for(int j=0;j<nr_class-1;j++)				fp.writeBytes(sv_coef[j][i]+" ");			svm_node[] p = SV[i];			if(param.kernel_type == svm_parameter.PRECOMPUTED)				fp.writeBytes("0:"+(int)(p[0].value));			else					for(int j=0;j<p.length;j++)					fp.writeBytes(p[j].index+":"+p[j].value+" ");			fp.writeBytes("\n");		}		fp.close();	}	private static double atof(String s)	{		return Double.valueOf(s).doubleValue();	}	private static int atoi(String s)	{		return Integer.parseInt(s);	}	public static svm_model svm_load_model(String model_file_name) throws IOException	{		BufferedReader fp = new BufferedReader(new FileReader(model_file_name));		// read parameters		svm_model model = new svm_model();		svm_parameter param = new svm_parameter();		model.param = param;		model.rho = null;		model.probA = null;		model.probB = null;		model.label = null;		model.nSV = null;		while(true)		{			String cmd = fp.readLine();			String arg = cmd.substring(cmd.indexOf(' ')+1);			if(cmd.startsWith("svm_type"))			{				int i;				for(i=0;i<svm_type_table.length;i++)				{					if(arg.indexOf(svm_type_table[i])!=-1)					{						param.svm_type=i;						break;					}				}				if(i == svm_type_table.length)				{					System.err.print("unknown svm type.\n");					return null;				}			}			else if(cmd.startsWith("kernel_type"))			{				int i;				for(i=0;i<kernel_type_table.length;i++)				{					if(arg.indexOf(kernel_type_table[i])!=-1)					{						param.kernel_type=i;						break;					}				}				if(i == kernel_type_table.length)				{					System.err.print("unknown kernel function.\n");					return null;				}			}			else if(cmd.startsWith("degree"))				param.degree = atoi(arg);			else if(cmd.startsWith("gamma"))				param.gamma = atof(arg);			else if(cmd.startsWith("coef0"))				param.coef0 = atof(arg);			else if(cmd.startsWith("nr_class"))				model.nr_class = atoi(arg);			else if(cmd.startsWith("total_sv"))				model.l = atoi(arg);			else if(cmd.startsWith("rho"))			{				int n = model.nr_class * (model.nr_class-1)/2;				model.rho = new double[n];				StringTokenizer st = new StringTokenizer(arg);				for(int i=0;i<n;i++)					model.rho[i] = atof(st.nextToken());			}			else if(cmd.startsWith("label"))			{				int n = model.nr_class;				model.label = new int[n];				StringTokenizer st = new StringTokenizer(arg);				for(int i=0;i<n;i++)					model.label[i] = atoi(st.nextToken());								}			else if(cmd.startsWith("probA"))			{				int n = model.nr_class*(model.nr_class-1)/2;				model.probA = new double[n];				StringTokenizer st = new StringTokenizer(arg);				for(int i=0;i<n;i++)					model.probA[i] = atof(st.nextToken());								}			else if(cmd.startsWith("probB"))			{				int n = model.nr_class*(model.nr_class-1)/2;				model.probB = new double[n];				StringTokenizer st = new StringTokenizer(arg);				for(int i=0;i<n;i++)					model.probB[i] = atof(st.nextToken());								}			else if(cmd.startsWith("nr_sv"))			{				int n = model.nr_class;				model.nSV = new int[n];				StringTokenizer st = new StringTokenizer(arg);				for(int i=0;i<n;i++)					model.nSV[i] = atoi(st.nextToken());			}			else if(cmd.startsWith("SV"))			{				break;			}			else			{				System.err.print("unknown text in model file\n");				return null;			}		}		// read sv_coef and SV		int m = model.nr_class - 1;		int l = model.l;		model.sv_coef = new double[m][l];		model.SV = new svm_node[l][];		for(int i=0;i<l;i++)		{			String line = fp.readLine();			StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");			for(int k=0;k<m;k++)				model.sv_coef[k][i] = atof(st.nextToken());			int n = st.countTokens()/2;			model.SV[i] = new svm_node[n];			for(int j=0;j<n;j++)			{				model.SV[i][j] = new svm_node();				model.SV[i][j].index = atoi(st.nextToken());				model.SV[i][j].value = atof(st.nextToken());			}		}		fp.close();		return model;	}	public static String svm_check_parameter(svm_problem prob, svm_parameter param)	{		// svm_type		int svm_type = param.svm_type;		if(svm_type != svm_parameter.C_SVC &&		   svm_type != svm_parameter.NU_SVC &&		   svm_type != svm_parameter.ONE_CLASS &&		   svm_type != svm_parameter.EPSILON_SVR &&		   svm_type != svm_parameter.NU_SVR)		return "unknown svm type";		// kernel_type, degree			int kernel_type = param.kernel_type;		if(kernel_type != svm_parameter.LINEAR &&		   kernel_type != svm_parameter.POLY &&		   kernel_type != svm_parameter.RBF &&		   kernel_type != svm_parameter.SIGMOID &&		   kernel_type != svm_parameter.PRECOMPUTED)			return "unknown kernel type";		if(param.degree < 0)			return "degree of polynomial kernel < 0";		// cache_size,eps,C,nu,p,shrinking		if(param.cache_size <= 0)			return "cache_size <= 0";		if(param.eps <= 0)			return "eps <= 0";		if(svm_type == svm_parameter.C_SVC ||		   svm_type == svm_parameter.EPSILON_SVR ||		   svm_type == svm_parameter.NU_SVR)			if(param.C <= 0)				return "C <= 0";		if(svm_type == svm_parameter.NU_SVC ||		   svm_type == svm_parameter.ONE_CLASS ||		   svm_type == svm_parameter.NU_SVR)			if(param.nu <= 0 || param.nu > 1)				return "nu <= 0 or nu > 1";		if(svm_type == svm_parameter.EPSILON_SVR)			if(param.p < 0)				return "p < 0";		if(param.shrinking != 0 &&		   param.shrinking != 1)			return "shrinking != 0 and shrinking != 1";		if(param.probability != 0 &&		   param.probability != 1)			return "probability != 0 and probability != 1";		if(param.probability == 1 &&		   svm_type == svm_parameter.ONE_CLASS)			return "one-class SVM probability output not supported yet";				// check whether nu-svc is feasible			if(svm_type == svm_parameter.NU_SVC)		{			int l = prob.l;			int max_nr_class = 16;			int nr_class = 0;			int[] label = new int[max_nr_class];			int[] count = new int[max_nr_class];			int i;			for(i=0;i<l;i++)			{				int this_label = (int)prob.y[i];				int j;				for(j=0;j<nr_class;j++)					if(this_label == label[j])					{						++count[j];						break;					}				if(j == nr_class)				{					if(nr_class == max_nr_class)					{						max_nr_class *= 2;						int[] new_data = new int[max_nr_class];						System.arraycopy(label,0,new_data,0,label.length);						label = new_data;												new_data = new int[max_nr_class];						System.arraycopy(count,0,new_data,0,count.length);						count = new_data;					}					label[nr_class] = this_label;					count[nr_class] = 1;					++nr_class;				}			}			for(i=0;i<nr_class;i++)			{				int n1 = count[i];				for(int j=i+1;j<nr_class;j++)				{					int n2 = count[j];					if(param.nu*(n1+n2)/2 > Math.min(n1,n2))						return "specified nu is infeasible";				}			}		}		return null;	}	public static int svm_check_probability_model(svm_model model)	{		if (((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) &&		model.probA!=null && model.probB!=null) ||		((model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) &&		 model.probA!=null))			return 1;		else			return 0;	}}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -