⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 classificationsvm.cpp

📁 该程序包实现了几个常用的模式识别分类器算法
💻 CPP
字号:
/* CSvm类实现支持向量机算法(SVM),调用了libSVM开源程序。
	其中“paramGamma”和“paramC”是算法中的两个超参数。
   使用方法如下:创建一个对象,然后设置超参数,调用train方法训练分类器,调用test方法进行分类。 */
#include "stdlib.h"
#include <string.h>
#include <fstream>
#include <math.h>
#include <iostream>
using namespace std;
#include "global.h"
#define Malloc(type,n) (type *)malloc((n)*sizeof(type))

CSvm::CSvm()
{
	paramGamma = 1;
	paramC = 1;
	resultTotal = NULL;
	resultRight = NULL;
	x_space = NULL;
}

CSvm::~CSvm()
{
	if (this->resultTotal != NULL)
		delete[] this->resultTotal;
	if (this->resultRight != NULL)
		delete[] this->resultRight;
	svm_destroy_model(model);
	if (this->x_space != NULL)
		free(this->x_space);
}

/* 使用分类器进行分类,输入测试样本的文件名称,返回true表示程序执行正常,false表示程序执行错误。 */
bool CSvm::train(char* fileName)
{
	bool ret = true;
	ret = this->sdata.readFile(fileName);//读入训练数据
	if (!ret)
		return false;//如果数据格式不正确,退出程序。
	this->sdata.normalize();
	int n=0;
	int i, j, k, index;
	struct svm_parameter param;
	struct svm_problem prob;
	const char *error_msg;
	int numFeature = this->sdata.numFeature;
	int numClass = this->sdata.numClass;

	//设置svm参数
	param.svm_type = C_SVC;
	param.kernel_type = RBF;
	param.degree = 3;
	param.gamma = this->paramGamma/((double)numFeature);//150.0/(double)numFeature;	// 1/k letter:130 digital:20 satimage:250 iris:20 wave21:200 wave40:150
	param.coef0 = 0;
	param.nu = 0.5;
	param.cache_size = 200;
	param.C = this->paramC;
	param.eps = 1e-3;
	param.p = 0.1;
	param.shrinking = 1;
	param.probability = 0;
	param.nr_weight = 0;
	param.weight_label = NULL;
	param.weight = NULL;
	//设置svm的训练数据
	prob.l = this->sdata.numSample;
	prob.y = Malloc(double, prob.l);
	prob.x = Malloc(struct svm_node *, prob.l);
	this->x_space = Malloc(struct svm_node, prob.l*(numFeature+1));
	for(i=0;i<prob.l;i++)
	{
		index = i*(numFeature+1);
		prob.x[i] = &(this->x_space[index]);
		prob.y[i] = (double)this->sdata.ydata[i];
		for (j=0; j<numFeature; j++)
		{
			this->x_space[index+j].index = j+1;
			this->x_space[index+j].value = this->sdata.xdata[i*numFeature+j];
		}
		this->x_space[index+numFeature].index = -1;
	}
	//开始训练
	error_msg = svm_check_parameter(&prob,&param);
	if(error_msg)
	{
		cout<<"Error:"<<error_msg<<endl;
	}
	else
	{
		this->model = svm_train(&prob, &param);
	}
	svm_destroy_param(&param);
	free(prob.y);
	free(prob.x);
	return ret;
}

/* 使用分类器进行分类,输入测试样本的文件名称,返回true表示程序执行正常,false表示程序执行错误。 */
bool CSvm::test(char* fileName)
{
	bool ret = true;
	ifstream ifs(fileName);
	int bufSize = 1024;
	char* buf=new char[bufSize];
	char* temp=new char[bufSize];
	int numFeature = this->sdata.numFeature;
	int numClass = this->sdata.numClass;
	int n=0;
	int i, j;
	int cnum;
	int v;
	struct svm_node *x;
	if (NULL==this->resultTotal)
		this->resultTotal = new DOUBLE[numClass];
	if (NULL==this->resultRight)
		this->resultRight = new DOUBLE[numClass];
	for (i=0; i<numClass; i++)
	{
		this->resultTotal[i] = 0;
		this->resultRight[i] = 0;
	}
	x = Malloc(struct svm_node, numFeature+1);
	while (ifs.good()) {
		ifs.getline(buf, bufSize);
		n=0;
		while (*(buf+n)!=',' && n<bufSize)
		{
			*(temp+n)=*(buf+n);
			n++;
		}
		*(temp+n)=0;
		cnum=this->sdata.searchClassName(temp);
		if (cnum>-1)
		{
			this->resultTotal[cnum] += 1;
			for (j=0; j<numFeature; j++)
			{
				n++;
				int n2=n;
				while (*(buf+n)!=',' && *(buf+n)!=0 && n<bufSize)
				{
					*(temp+n-n2)=*(buf+n);
					n++;
				}
				*(temp+n-n2)=0;
				x[j].index = j+1;
				x[j].value = atof(temp);
				if (this->sdata.maxValue[j] > this->sdata.minValue[j])
				{
					x[j].value = (x[j].value - this->sdata.minValue[j])/(this->sdata.maxValue[j] - this->sdata.minValue[j]);
				}
				else
					x[j].value = 0.5;
			}
			x[numFeature].index = -1;
			//测试
			v = (int)svm_predict(this->model, x);
			if (v == cnum)
				this->resultRight[cnum] += 1;
		}
	}
	return ret;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -