📄 classificationsvm.cpp
字号:
/* CSvm类实现支持向量机算法(SVM),调用了libSVM开源程序。
其中“paramGamma”和“paramC”是算法中的两个超参数。
使用方法如下:创建一个对象,然后设置超参数,调用train方法训练分类器,调用test方法进行分类。 */
#include "stdlib.h"
#include <string.h>
#include <fstream>
#include <math.h>
#include <iostream>
#include <string>
#include <vector>
using namespace std;
#include "global.h"
#define Malloc(type,n) (type *)malloc((n)*sizeof(type))
/* Creates an untrained classifier: both hyper-parameters start at 1 and every
   owned pointer (result tallies, training-vector storage, libsvm model) starts
   out NULL so that the destructor can tell what has actually been allocated. */
CSvm::CSvm()
{
    paramGamma = 1;
    paramC = 1;
    resultTotal = NULL;
    resultRight = NULL;
    x_space = NULL;
    // The destructor hands `model` to libsvm, so it must never be left indeterminate.
    model = NULL;
}

/* Releases the result tallies, the libsvm model (if one was trained) and the
   training-vector storage. */
CSvm::~CSvm()
{
    // delete[] and free() accept NULL, so no guards are needed for these.
    delete[] this->resultTotal;
    delete[] this->resultRight;

    // Destroy the model before x_space, keeping the original release order:
    // the model is released while the training vectors it was built from still exist.
    if (this->model != NULL)
        svm_destroy_model(this->model);

    free(this->x_space);
}
/* 使用分类器进行分类,输入测试样本的文件名称,返回true表示程序执行正常,false表示程序执行错误。 */
bool CSvm::train(char* fileName)
{
bool ret = true;
ret = this->sdata.readFile(fileName);//读入训练数据
if (!ret)
return false;//如果数据格式不正确,退出程序。
this->sdata.normalize();
int n=0;
int i, j, k, index;
struct svm_parameter param;
struct svm_problem prob;
const char *error_msg;
int numFeature = this->sdata.numFeature;
int numClass = this->sdata.numClass;
//设置svm参数
param.svm_type = C_SVC;
param.kernel_type = RBF;
param.degree = 3;
param.gamma = this->paramGamma/((double)numFeature);//150.0/(double)numFeature; // 1/k letter:130 digital:20 satimage:250 iris:20 wave21:200 wave40:150
param.coef0 = 0;
param.nu = 0.5;
param.cache_size = 200;
param.C = this->paramC;
param.eps = 1e-3;
param.p = 0.1;
param.shrinking = 1;
param.probability = 0;
param.nr_weight = 0;
param.weight_label = NULL;
param.weight = NULL;
//设置svm的训练数据
prob.l = this->sdata.numSample;
prob.y = Malloc(double, prob.l);
prob.x = Malloc(struct svm_node *, prob.l);
this->x_space = Malloc(struct svm_node, prob.l*(numFeature+1));
for(i=0;i<prob.l;i++)
{
index = i*(numFeature+1);
prob.x[i] = &(this->x_space[index]);
prob.y[i] = (double)this->sdata.ydata[i];
for (j=0; j<numFeature; j++)
{
this->x_space[index+j].index = j+1;
this->x_space[index+j].value = this->sdata.xdata[i*numFeature+j];
}
this->x_space[index+numFeature].index = -1;
}
//开始训练
error_msg = svm_check_parameter(&prob,¶m);
if(error_msg)
{
cout<<"Error:"<<error_msg<<endl;
}
else
{
this->model = svm_train(&prob, ¶m);
}
svm_destroy_param(¶m);
free(prob.y);
free(prob.x);
return ret;
}
/* 使用分类器进行分类,输入测试样本的文件名称,返回true表示程序执行正常,false表示程序执行错误。 */
bool CSvm::test(char* fileName)
{
bool ret = true;
ifstream ifs(fileName);
int bufSize = 1024;
char* buf=new char[bufSize];
char* temp=new char[bufSize];
int numFeature = this->sdata.numFeature;
int numClass = this->sdata.numClass;
int n=0;
int i, j;
int cnum;
int v;
struct svm_node *x;
if (NULL==this->resultTotal)
this->resultTotal = new DOUBLE[numClass];
if (NULL==this->resultRight)
this->resultRight = new DOUBLE[numClass];
for (i=0; i<numClass; i++)
{
this->resultTotal[i] = 0;
this->resultRight[i] = 0;
}
x = Malloc(struct svm_node, numFeature+1);
while (ifs.good()) {
ifs.getline(buf, bufSize);
n=0;
while (*(buf+n)!=',' && n<bufSize)
{
*(temp+n)=*(buf+n);
n++;
}
*(temp+n)=0;
cnum=this->sdata.searchClassName(temp);
if (cnum>-1)
{
this->resultTotal[cnum] += 1;
for (j=0; j<numFeature; j++)
{
n++;
int n2=n;
while (*(buf+n)!=',' && *(buf+n)!=0 && n<bufSize)
{
*(temp+n-n2)=*(buf+n);
n++;
}
*(temp+n-n2)=0;
x[j].index = j+1;
x[j].value = atof(temp);
if (this->sdata.maxValue[j] > this->sdata.minValue[j])
{
x[j].value = (x[j].value - this->sdata.minValue[j])/(this->sdata.maxValue[j] - this->sdata.minValue[j]);
}
else
x[j].value = 0.5;
}
x[numFeature].index = -1;
//测试
v = (int)svm_predict(this->model, x);
if (v == cnum)
this->resultRight[cnum] += 1;
}
}
return ret;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -