⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 smo.cpp

📁 支持向量机的SMO算法VC实现,在模式识别和 许多方面很有用
💻 CPP
字号:
/******************************************************
 * 基于SMO算法的支持向量机
 * 详情请见D:\下载资料\计算语言学\支持向量机\smo.pdf
 *****************************************************/
#include "smo.h"

// Global state shared by training (smo/examineExample/takeStep), model I/O
// (read_svm/write_svm/load_svm) and prediction (predict_func).
int N=0;                      // total number of samples currently loaded
int d=-1;                     // dimension of the sample space (used to validate input rows)
double C=0.05;                // penalty factor: upper bound on every multiplier alph[i]
double tolerance=0.001;       // tolerance used when testing the KKT conditions
double eps=0.001;             // minimum meaningful change of alpha / of the objective in takeStep()
double two_sigma_squared=2;   // RBF kernel width parameter 2*sigma^2
double delta_b=0;             // change of the threshold b made by the last successful takeStep()

vector<double> alph;             // Lagrange multipliers, indexed like the samples
double b;                        // threshold; every decision function returns (sum) - b
vector<double> w;                   // weight vector w, only used with the linear kernel

// Cached errors E_i = f(x_i) - target[i]; only read for non-bound samples (0 < alph < C).
vector<double> error_cache;



// Sample storage: only one of these vectors is used, selected by
// is_sparse_data / is_binary.
vector<sparse_binary_vector> sparse_binary_points;
vector<sparse_vector> sparse_points;
vector<dense_vector> dense_points;

vector<int> target;   // class label of each sample; its sign decides the class (presumably +1/-1)
bool is_sparse_data=false;   // samples are stored as sparse index lists
bool is_binary=false;        // sparse samples carry indices only (implicit feature value 1)
bool is_test_only=false;     // evaluate a saved model instead of training
bool is_linear_kernel=false; // use the plain dot product (and w) instead of the RBF kernel

int first_test_i=0;          // first sample index evaluated by error_rate()
int end_support_i=-1;        // samples [0, end_support_i) may carry non-zero multipliers
vector<double> precomputed_self_dot_product;  // x_i . x_i per sample, cached for rbf_kernel()

double (*dot_product_func)(int,int)=NULL;// dot product of two samples (selected in smo())
double (*learned_func)(int)=NULL;// decision function f(x_k) (selected in smo())
double (*kernel_func)(int,int)=NULL;// kernel K(x_i, x_j) (selected in smo())
 
/*
 * Uniformly distributed start offset in [0, n) for the wrap-around scans in
 * examineExample(); returns 0 when n <= 0. The division is done in floating
 * point because rand()/RAND_MAX in integer arithmetic is almost always 0.
 */
static int smo_random_offset(int n)
{
	if(n<=0)
		return 0;
	int k=(int)((double)rand()/((double)RAND_MAX+1.0)*n);
	return k<n?k:n-1;
}

/*
 * examineExample: outer-loop step of SMO for sample i1.
 *
 * If sample i1 violates the KKT conditions by more than `tolerance`, look for
 * a partner i2 and optimize the pair with takeStep(). Partner heuristics, in
 * order:
 *   1. the non-bound sample (0 < alph < C) that maximizes |E1 - E2|   (18b)
 *   2. every non-bound sample, scanning from a random position          (19b)
 *   3. every sample in [0, end_support_i), from a random position       (19c)
 *
 * Returns 1 if some pair was successfully optimized, 0 otherwise.
 */
int examineExample(int i1)
{
	double y1,alph1,E1,r1;
	y1=target[i1];
	alph1=alph[i1];
	// Non-bound samples have an up-to-date cached error; others are recomputed.
	if(alph1>0&&alph1<C)
		E1=error_cache[i1];
	else
		E1=learned_func(i1)-y1;

	r1=y1*E1;
	// KKT violation (for labels +1/-1): alpha < C requires r1 >= 0 and
	// alpha > 0 requires r1 <= 0, each up to `tolerance`.
	if((r1<-tolerance&&alph1<C)||(r1>tolerance&&alph1>0))
	{
		// Heuristic 1: non-bound sample with the largest |E1 - E2| (18b).
		{
			int i2=-1;
			double tmax=0;
			for(int k=0;k<end_support_i;k++)
			{
				if(alph[k]>0&&alph[k]<C)
				{
					double temp=fabs(E1-error_cache[k]);
					if(temp>tmax)
					{
						tmax=temp;
						i2=k;
					}
				}
			}
			if(i2>=0&&takeStep(i1,i2))
				return 1;
		}
		// Heuristic 2: all non-bound samples, random starting point (19b).
		{
			int k0=smo_random_offset(end_support_i);
			for(int k=k0;k<end_support_i+k0;k++)
			{
				int i2=k%end_support_i;
				if(alph[i2]>0&&alph[i2]<C&&takeStep(i1,i2))
					return 1;
			}
		}
		// Heuristic 3: the whole sample set, random starting point (19c).
		{
			int k0=smo_random_offset(end_support_i);
			for(int k=k0;k<end_support_i+k0;k++)
			{
				if(takeStep(i1,k%end_support_i))
					return 1;
			}
		}
	}
	return 0;
}

/*
 * takeStep: jointly optimize the Lagrange multipliers of samples i1 and i2
 * (the analytic two-variable step of SMO).
 *
 * The new alph[i2] is clipped to the feasible segment [L, H] implied by
 * 0 <= alpha <= C and the equality constraint, and alph[i1] is derived from
 * it. The threshold b (and delta_b), the weight vector w (linear kernel only)
 * and the error cache of non-bound samples are then updated.
 *
 * Returns 1 if the multipliers changed, 0 if no progress was possible
 * (i1 == i2, empty feasible segment, or a change smaller than eps).
 */
int takeStep(int i1,int i2)
{
	int y1,y2,s;
	double alph1,alph2;   // old multipliers
	double a1,a2;         // new multipliers
	double E1,E2,L,H,k11,k22,k12,eta;

	if(i1==i2)
		return 0;
	// Errors E = f(x) - y; non-bound samples use the cached value (21a).
	alph1=alph[i1];
	y1=target[i1];
	if(alph1>0&&alph1<C)
		E1=error_cache[i1];
	else
		E1=learned_func(i1)-y1;
	alph2=alph[i2];
	y2=target[i2];
	if(alph2>0&&alph2<C)
		E2=error_cache[i2];
	else
		E2=learned_func(i2)-y2;

	s=y1*y2;
	// Feasible segment [L, H] for the new alph2 (22a).
	if(y1==y2)
	{
		double gamma=alph1+alph2;
		if(gamma>C)
		{
			L=gamma-C;
			H=C;
		}
		else
		{
			L=0;
			H=gamma;
		}
	}
	else
	{
		double gamma=alph1-alph2;
		if(gamma>0)
		{
			L=0;
			H=C-gamma;
		}
		else
		{
			L=-gamma;
			H=C;
		}
	}
	if(L==H)
		return 0;

	// Second derivative of the objective along the constraint line (22b):
	// eta = 2*K12 - K11 - K22 (= -||phi(x1) - phi(x2)||^2, so <= 0 for a valid kernel).
	k11=kernel_func(i1,i1);
	k12=kernel_func(i1,i2);
	k22=kernel_func(i2,i2);
	eta=2*k12-k11-k22;

	if(eta<0)
	{
		// Unconstrained optimum along the line, clipped to [L, H].
		a2=alph2+y2*(E2-E1)/eta;
		if(a2<L)
			a2=L;
		else if(a2>H)
			a2=H;
	}
	else
	{
		// Degenerate case: compare the objective at both ends of the segment (22d).
		double c1=eta/2;
		double c2=y2*(E1-E2)-eta*alph2;
		double lobj=c1*L*L+c2*L;
		double hobj=c1*H*H+c2*H;

		if(lobj>hobj+eps)
			a2=L;
		else if(lobj<hobj-eps)
			a2=H;
		else
			a2=alph2;
	}
	// Ignore steps that are too small to matter.
	if(fabs(a2-alph2)<eps*(a2+alph2+eps))
		return 0;
	// Keep y1*a1 + y2*a2 constant, then push a1 back into [0, C] if it drifted.
	a1=alph1-s*(a2-alph2);
	if(a1<0)
	{
		a2+=s*a1;
		a1=0;
	}
	else if(a1>C)
	{
		double t=a1-C;
		a2+=s*t;
		a1=C;
	}
	// Update the threshold b (23a): prefer the value that makes a non-bound
	// sample satisfy KKT exactly; otherwise average the two candidates.
	{
		double b1,b2,bnew;
		if(a1>0&&a1<C)
			bnew=b+E1+y1*(a1-alph1)*k11+y2*(a2-alph2)*k12;
		else
		{
			if(a2>0&&a2<C)
				bnew=b+E2+y1*(a1-alph1)*k12+y2*(a2-alph2)*k22;
			else
			{
				b1=b+E1+y1*(a1-alph1)*k11+y2*(a2-alph2)*k12;
				b2=b+E2+y1*(a1-alph1)*k12+y2*(a2-alph2)*k22;
				bnew=(b1+b2)/2;
			}
		}
		delta_b=bnew-b;
		b=bnew;
	}
	// With the linear kernel, update the weight vector w (23c).
	if(is_linear_kernel)
	{
		double t1=y1*(a1-alph1);
		double t2=y2*(a2-alph2);
		if(is_sparse_data&&is_binary)
		{
			int p1,num1,p2,num2;
			num1=(int)sparse_binary_points[i1].id.size();
			for(p1=0;p1<num1;p1++)
				w[sparse_binary_points[i1].id[p1]]+=t1;
			num2=(int)sparse_binary_points[i2].id.size();
			for(p2=0;p2<num2;p2++)
				w[sparse_binary_points[i2].id[p2]]+=t2;
		}
		else if(is_sparse_data&&!is_binary)
		{
			int p1,num1,p2,num2;
			num1=(int)sparse_points[i1].id.size();
			for(p1=0;p1<num1;p1++)
				w[sparse_points[i1].id[p1]]+=t1*sparse_points[i1].val[p1];
			num2=(int)sparse_points[i2].id.size();
			for(p2=0;p2<num2;p2++)
				w[sparse_points[i2].id[p2]]+=t2*sparse_points[i2].val[p2];
		}
		else
			for(int i=0;i<d;i++)
				w[i]+=dense_points[i1][i]*t1+dense_points[i2][i]*t2;
	}
	// Update the cached errors of non-bound samples (24a); the two optimized
	// samples are reset to 0.
	{
		double t1=y1*(a1-alph1);
		double t2=y2*(a2-alph2);
		for(int i=0;i<end_support_i;i++)
			if(0<alph[i]&&alph[i]<C)
				error_cache[i]+=t1*kernel_func(i1,i)+t2*kernel_func(i2,i)-delta_b;
		error_cache[i1]=0;
		error_cache[i2]=0;
	}
	alph[i1]=a1;
	alph[i2]=a2;
	return 1;
}

double learned_func_linear_sparse_binary(int k)
{
	double s=0;
	for(int i=0;i<(int)sparse_binary_points[k].id.size();i++)
		s+=w[sparse_binary_points[k].id[i]];
	s-=b;
	return s;
}
double learned_func_linear_sparse_nobinary(int k)
{
	double s=0;
	for(int i=0;i<(int)sparse_points[k].id.size();i++)
	{
		int j=sparse_points[k].id[i];
		double v=sparse_points[k].val[i];
		s+=w[j]*v;
	}
	s-=b;
	return s;
}
double learned_func_linear_dense(int k)
{
	double s=0;
	for(int i=0;i<d;i++)
		s+=w[i]*dense_points[k][i];
	s-=b;
	return s;
}
/*
 * Kernel decision value for sample k:
 *   sum over i in [0, end_support_i) with alph[i] > 0 of
 *   alph[i] * target[i] * K(x_i, x_k), minus the threshold b.
 */
double learned_func_nonlinear(int k)
{
	double total=0;
	for(int sv=0;sv<end_support_i;++sv)
	{
		if(!(alph[sv]>0))
			continue;   // only samples with a positive multiplier contribute
		total+=alph[sv]*target[sv]*kernel_func(sv,k);
	}
	return total-b;
}
double dot_product_sparse_binary(int i1,int i2)
{
	int p1=0,p2=0,dot=0;
	int num1=(int)sparse_binary_points[i1].id.size();
	int num2=(int)sparse_binary_points[i2].id.size();
	while(p1<num1&&p2<num2)
	{
		int a1=(int)sparse_binary_points[i1].id[p1];
		int a2=(int)sparse_binary_points[i2].id[p2];
		if(a1==a2)
		{
			dot++;p1++;p2++;
		}
		else if(a1>a2)
			p2++;
		else
			p1++;
	}
	return (double)dot;
}
double dot_product_sparse_nonbinary(int i1,int i2)
{
	int p1=0,p2=0;
	double dot=0;
	int num1=(int)sparse_points[i1].id.size();
	int num2=(int)sparse_points[i2].id.size();
	while(p1<num1&&p2<num2)
	{
		int a1=sparse_points[i1].id[p1];
		int a2=sparse_points[i2].id[p2];
		if(a1==a2)
		{
			dot+=sparse_points[i1].val[p1]*sparse_points[i2].val[p2];
			p1++;
			p2++;
		}
		else if(a1>a2)
			p2++;
		else
			p1++;
	}
	return (double)dot;
}
double dot_product_dense(int i1,int i2)
{
	double dot=0;
	for(int i=0;i<d;i++)
		dot+=dense_points[i1][i]*dense_points[i2][i];
	return dot;
}
/*
 * Gaussian (RBF) kernel between samples i1 and i2:
 *   exp(-||x1 - x2||^2 / two_sigma_squared),
 * with ||x1 - x2||^2 expanded as x1.x1 + x2.x2 - 2*x1.x2 so the cached self
 * dot products can be reused.
 */
double rbf_kernel(int i1,int i2)
{
	const double cross=dot_product_func(i1,i2);
	const double self_terms=precomputed_self_dot_product[i1]+precomputed_self_dot_product[i2];
	const double sq_dist=cross*(-2)+self_terms;
	return exp(-sq_dist/two_sigma_squared);
}
/*
 * read_data: parse whitespace-separated samples from `is`, one per line, and
 * append them to `target` and to the sample storage selected by
 * is_sparse_data / is_binary.
 *
 * Line formats (the class label is always the last number on the line):
 *   sparse binary    : idx idx ... label           (1-based feature indices)
 *   sparse non-binary: idx val idx val ... label   (1-based feature indices)
 *   dense            : v1 v2 ... vd label          (exactly d values)
 * Sparse indices must lie in [1, d]; they are stored 0-based.
 * NOTE(review): the sparse dot products assume indices are sorted ascending.
 *
 * Blank lines are skipped. Returns the number of samples read, or 0 on a
 * format error (reported on cout when INFO is defined).
 */
int read_data(istream& is)
{
	string s;
	int n_points=0;   // samples successfully parsed
	for(int line_no=1;getline(is,s,'\n');line_no++)
	{
		istrstream line(s.c_str());
		vector<double> v;
		double t;
		while(line>>t)
			v.push_back(t);
		// A blank line has no label; v.back() on an empty vector is undefined.
		if(v.empty())
			continue;
		target.push_back((int)v.back());
		v.pop_back();
		int n=(int)v.size();
		if(is_sparse_data&&is_binary)
		{
			sparse_binary_vector x;
			for(int i=0;i<n;i++)
			{
				if(v[i]<1||v[i]>d)
				{
#ifdef INFO
					cout<<"error:line"<<line_no<<":attribute_index"<<int(v[i])<<"out of range."<<endl;
#endif
					return 0;
				}
				x.id.push_back(int(v[i])-1);
			}
			// Store the sample so index k lines up with target[k].
			sparse_binary_points.push_back(x);
		}
		else if(is_sparse_data&&!is_binary)
		{
			// Entries come in (index, value) pairs; an odd count would read past v.
			if(n%2!=0)
			{
#ifdef INFO
				cout<<"data file error:line "<<line_no<<": attribute index without a value."<<endl;
#endif
				return 0;
			}
			sparse_vector x;
			for(int i=0;i<n;i+=2)
			{
				if(v[i]<1||v[i]>d)
				{
#ifdef INFO
					cout<<"data file error:line"<<line_no<<":attribute index "<<int(v[i])<<" out of range."<<endl;
#endif
					return 0;
				}
				x.id.push_back(int(v[i])-1);
				x.val.push_back(v[i+1]);
			}
			sparse_points.push_back(x);
		}
		else
		{
			if(n!=d)
			{
#ifdef INFO
				cout<<"data file error:line "<<line_no<<" has "<<n<<" attributes;should be d="<<d<<endl;
#endif
				return 0;
			}
			dense_points.push_back(v);
		}
		n_points++;
	}
	return n_points;
}
/*
 * write_svm: serialize the trained model in the format read_svm() expects:
 *   d, is_sparse_data, is_binary, is_linear_kernel, b, then
 *   linear kernel: the d weights of w;
 *   otherwise    : two_sigma_squared, the number of support vectors
 *                  (alph[i] > 0), their multipliers, and one line per
 *                  support vector in read_data() format (features, label).
 *
 * Doubles are written with 17 significant digits so the model round-trips
 * exactly; the stream's previous precision is restored afterwards.
 */
void write_svm(ostream& os)
{
	const std::streamsize old_precision=os.precision(17);
	os<<d<<'\n';
	os<<is_sparse_data<<'\n';
	os<<is_binary<<'\n';
	os<<is_linear_kernel<<'\n';
	os<<b<<'\n';
	if(is_linear_kernel)
	{
		for(int i=0;i<d;i++)
			os<<w[i]<<'\n';
	}
	else
	{
		os<<two_sigma_squared<<'\n';
		int n_support_vectors=0;
		for(int i=0;i<end_support_i;i++)
			if(alph[i]>0)
				n_support_vectors++;
		os<<n_support_vectors<<'\n';
		for(int i=0;i<end_support_i;i++)
			if(alph[i]>0)
				os<<alph[i]<<'\n';
		for(int i=0;i<end_support_i;i++)
		{
			if(!(alph[i]>0))
				continue;
			// Sparse indices are written 1-based, matching read_data().
			if(is_sparse_data&&is_binary)
			{
				for(int j=0;j<(int)sparse_binary_points[i].id.size();j++)
					os<<(sparse_binary_points[i].id[j]+1)<<' ';
			}
			else if(is_sparse_data&&!is_binary)
			{
				for(int j=0;j<(int)sparse_points[i].id.size();j++)
					os<<(sparse_points[i].id[j]+1)<<' '<<sparse_points[i].val[j]<<' ';
			}
			else
			{
				for(int j=0;j<d;j++)
					os<<dense_points[i][j]<<' ';
			}
			os<<target[i]<<'\n';
		}
	}
	os.precision(old_precision);
	os.flush();
}
/*
 * read_svm: read a model in the format written by write_svm().
 *
 * Sets d, is_sparse_data, is_binary, is_linear_kernel and b. A linear model
 * then supplies w. A non-linear model supplies two_sigma_squared, the support
 * vector multipliers (into alph) and the support vectors themselves, which
 * read_data() appends to the sample storage and target.
 *
 * Returns the number of support vectors read. Returns 0 for a linear model
 * (it stores none) and when the stream is truncated or inconsistent.
 */
int read_svm(istream& is)
{
	is>>d;
	is>>is_sparse_data;
	is>>is_binary;
	is>>is_linear_kernel;
	is>>b;
	if(!is||d<=0)
		return 0;
	if(is_linear_kernel)
	{
		w.resize(d);
		for(int i=0;i<d;i++)
			is>>w[i];
		return 0;   // linear models store no support vectors
	}
	is>>two_sigma_squared;
	int n_support_vectors=0;
	is>>n_support_vectors;
	if(!is||n_support_vectors<0)
		return 0;
	alph.resize(n_support_vectors,0);
	for(int i=0;i<n_support_vectors;i++)
		is>>alph[i];
	if(!is)
		return 0;
	string dummy_line_to_skip_newline;
	getline(is,dummy_line_to_skip_newline,'\n');
	// Exactly one support vector line per multiplier; otherwise alph and
	// target would be out of step and later loops would index past alph.
	const int n=read_data(is);
	if(n!=n_support_vectors)
		return 0;
	return n;
}
/*
 * Fraction of samples in [first_test_i, N) whose predicted class
 * (learned_func(i) > 0) differs from the labelled class (target[i] > 0).
 */
double error_rate()
{
	int n_error=0;
	int n_total=0;
	int i=first_test_i;
	while(i<N)
	{
		const bool predicted_positive=learned_func(i)>0;
		const bool labelled_positive=target[i]>0;
		if(predicted_positive!=labelled_positive)
			++n_error;
		++n_total;
		++i;
	}
	return double(n_error)/double(n_total);
}
/*
 * smo: load data, train an SVM with Platt's SMO algorithm and save the model.
 * When is_test_only is set, it instead loads a saved model and evaluates it.
 *
 * Hyper-parameters and mode flags are fixed in the "parameters" section.
 * data_file_name: sample file (format: see read_data()); "" -> "svm.data".
 * svm_file_name : model written after training / read in test mode;
 *                 "" -> "svm.model".
 *
 * Returns 0 on success, 1 if a file cannot be opened, 2 if the data file
 * yields no samples.
 */
int smo(string data_file_name,string svm_file_name)
{
	int numChanged;
	int examineAll;

	// Parameters (29d)
	N=0;                   // total number of samples
	d=2;                   // dimension of the sample space
	C=0.01;                // penalty factor
	tolerance=0.001;       // tolerance for the KKT conditions
	eps=0.001;
	two_sigma_squared=2;   // RBF kernel parameter 2*sigma^2
	// Honour the caller's file names; fall back to the historical defaults.
	if(data_file_name.empty())
		data_file_name="svm.data";
	if(svm_file_name.empty())
		svm_file_name="svm.model";
	is_linear_kernel=false;  // linear kernel?
	is_sparse_data=false;    // sparse data?
	is_binary=false;         // binary (index-only) sparse data?
	is_test_only=false;

	// read_svm()/read_data() append to these, so start empty; leftovers from
	// an earlier call or load_svm() would misalign sample indices with N.
	target.clear();
	sparse_binary_points.clear();
	sparse_points.clear();
	dense_points.clear();

	// Read the data (31c)
	{
		int n;
		if(is_test_only)
		{
			ifstream svm_file(svm_file_name.c_str());
			if(!svm_file.is_open())
				return 1;
			end_support_i=first_test_i=n=read_svm(svm_file);
			N+=n;
		}
		if(N>0)
		{
			target.reserve(N);
			if(is_sparse_data&&is_binary)
				sparse_binary_points.reserve(N);
			else if(is_sparse_data&&!is_binary)
				sparse_points.reserve(N);
			else
				dense_points.reserve(N);
		}
		ifstream data_file(data_file_name.c_str());
		if(!data_file.is_open())
			return 1;
		n=read_data(data_file);
		if(n<=0)
			return 2;
		if(is_test_only)
		{
			N=first_test_i+n;
		}
		else
		{
			N=n;
			first_test_i=0;
			end_support_i=N;
		}
	}
	if(!is_test_only)
	{
		// assign() rather than resize(): resize() keeps existing elements,
		// i.e. multipliers/errors/weights from a previous run.
		alph.assign(end_support_i,0.0);
		b=0;
		error_cache.assign(N,0.0);
		if(is_linear_kernel)
			w.assign(d,0.0);
	}

	// Select the decision, dot-product and kernel functions (26a).
	if(is_linear_kernel)
	{
		if(!is_sparse_data)
			learned_func=learned_func_linear_dense;
		else if(is_binary)
			learned_func=learned_func_linear_sparse_binary;
		else
			learned_func=learned_func_linear_sparse_nobinary;
	}
	else
		learned_func=learned_func_nonlinear;
	if(!is_sparse_data)
		dot_product_func=dot_product_dense;
	else if(is_binary)
		dot_product_func=dot_product_sparse_binary;
	else
		dot_product_func=dot_product_sparse_nonbinary;
	if(is_linear_kernel)
		kernel_func=dot_product_func;
	else
	{
		kernel_func=rbf_kernel;
		// rbf_kernel() needs x_i . x_i for every sample.
		precomputed_self_dot_product.resize(N);
		for(int i=0;i<N;i++)
			precomputed_self_dot_product[i]=dot_product_func(i,i);
	}

	if(!is_test_only)
	{
		// Outer loop: alternate full passes with passes over non-bound
		// samples until a full pass changes nothing.
		numChanged=0;
		examineAll=1;
		while(numChanged>0||examineAll)
		{
			numChanged=0;
			if(examineAll)
			{
				for(int k=0;k<N;k++)
					numChanged+=examineExample(k);
			}
			else
			{
				for(int k=0;k<N;k++)
					if(alph[k]!=0&&alph[k]!=C)
						numChanged+=examineExample(k);
			}
			if(examineAll==1)
				examineAll=0;
			else if(numChanged==0)
				examineAll=1;
			// Diagnostics (36d) would go here.
		}
		// Save the model (36a)
		if(!svm_file_name.empty())
		{
			ofstream svm_file(svm_file_name.c_str());
			write_svm(svm_file);
		}
#ifdef INFO
		cout<<"threshold="<<b<<endl;
#endif
	}
#ifdef INFO
	cout<<"训练完毕,错误率为: "<<error_rate()<<endl;
#endif
	// (36c: writing per-sample classifications to an output file is not implemented.)
	return 0;
}
/*
 * load_svm: load a model written by write_svm() so that predict_func() can
 * use it. It does not select learned_func/kernel_func; predict_func() does
 * not use them.
 *
 * Returns 0 on success, 1 if the file cannot be opened, and 2 if read_svm()
 * reports no support vectors. The latter also covers linear-kernel models,
 * which store none.
 */
int load_svm(string svm_file_name)
{
	int n;
	N=0;
	d=2;              // dimension of the sample space (overwritten by the file)
	C=0.01;           // penalty factor
	tolerance=0.001;  // tolerance for the KKT conditions
	eps=0.001;
	ifstream svm_file(svm_file_name.c_str());
	if(!svm_file.is_open())
		return 1;
	// read_data() appends, so drop samples from an earlier train/load; otherwise
	// support vector i would not line up with alph[i] in predict_func().
	target.clear();
	sparse_binary_points.clear();
	sparse_points.clear();
	dense_points.clear();
	end_support_i=first_test_i=n=read_svm(svm_file);
	if(n<=0)
		return 2;
	N+=n;
	return 0;
}
double predict_func(const vector<double>& vx)
{
	double s=0;
	for(int i=0;i<end_support_i;i++)
		if(alph[i]>0)
			s+=alph[i]*target[i]*kernel(dense_points[i],vx);
	s-=b;
	return s;
}
double kernel(const vector<double>& vx1,const vector<double>& vx2)
{
	if(vx1.size()!=vx2.size())
		return 0;
	double dot=0;
	for(int j=0;j<d;j++)
		dot+=(vx1[j]-vx2[j])*(vx1[j]-vx2[j]);
	dot/=2;
	return -dot;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -