📄 classificationknn.cpp
字号:
/* CKnn类实现K近邻分类器算法(KNN),其中“paramK”是指算法中的参数K,默认值为1(最近邻分类器)。
使用方法如下:创建一个对象,然后设置paramK参数,调用train方法训练分类器,调用test方法进行分类。 */
#include "stdlib.h"
#include <string.h>
#include <math.h>
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
using namespace std;
#include "global.h"
/* Default constructor. K starts at 1, i.e. the plain nearest-neighbour rule;
   the per-class result arrays start empty and are allocated by test(). */
CKnn::CKnn()
{
	this->resultRight = NULL;
	this->resultTotal = NULL;
	this->paramK = 1; // K = 1: nearest-neighbour classifier
}
/* Destructor: releases the per-class result arrays allocated by test().
   delete[] on a NULL pointer does nothing, so no explicit NULL test is needed.
   NOTE(review): the class owns these raw arrays; unless the header disables
   copying, copying a CKnn object would delete them twice. */
CKnn::~CKnn()
{
	delete[] this->resultRight;
	delete[] this->resultTotal;
}
/* Train the classifier: read the training samples from fileName and normalize them.
   Returns false when sdata.readFile() reports a failure (e.g. malformed data),
   true otherwise. */
bool CKnn::train(char* fileName)
{
	// Load the training data; give up if the file could not be parsed.
	if (!this->sdata.readFile(fileName))
		return false;
	this->sdata.normalize();
	return true;
}
/* 使用分类器进行分类,输入测试样本的文件名称,返回true表示程序执行正常,false表示程序执行错误。 */
bool CKnn::test(char* fileName)
{
bool ret = true;
ifstream ifs(fileName);
int bufSize = 1024;
char* buf=new char[bufSize];
char* temp=new char[bufSize];
int numClass = this->sdata.numClass;
int numFeature = this->sdata.numFeature;
int numSample = this->sdata.numSample;
DOUBLE* sampleValue=new DOUBLE[numFeature];
int* value = new int[numClass];
int n=0;
int i, j;
int cnum;
int max, index;
DOUBLE distence, tempValue;
DOUBLE* minValue = new DOUBLE[this->paramK];
int* className = new int[this->paramK];
if (NULL==this->resultTotal)
this->resultTotal = new DOUBLE[numClass];
if (NULL==this->resultRight)
this->resultRight = new DOUBLE[numClass];
for (i=0; i<numClass; i++)
{
this->resultTotal[i] = 0;
this->resultRight[i] = 0;
}
while (ifs.good()) {
ifs.getline(buf, bufSize);
n=0;
while (*(buf+n)!=',' && n<bufSize)
{
*(temp+n)=*(buf+n);
n++;
}
*(temp+n)=0;
cnum=this->sdata.searchClassName(temp);
if (cnum>-1)
{
this->resultTotal[cnum] += 1;
for (j=0; j<numFeature; j++)
{
n++;
int n2=n;
while (*(buf+n)!=',' && n<bufSize)
{
*(temp+n-n2)=*(buf+n);
n++;
}
*(temp+n-n2)=0;
sampleValue[j]=atof(temp);
if (this->sdata.maxValue[j] > this->sdata.minValue[j])
{
*(sampleValue+j) = (*(sampleValue+j) - this->sdata.minValue[j])/(this->sdata.maxValue[j] - this->sdata.minValue[j]);
}
else
*(sampleValue+j) = 0.5;
}
for (i=0; i<this->paramK; i++)
{
distence = 0;
for (j=0; j<numFeature; j++)
{
tempValue = sampleValue[j]-this->sdata.xdata[i*numFeature+j];
distence += tempValue * tempValue;
}
minValue[i] = distence;
className[i] = this->sdata.ydata[i];
}
//从小到大排序
for (i=0; i<this->paramK-1; i++)
{
for (j=this->paramK-1; j>i; j--)
{
if (minValue[j] < minValue[j-1])
{
tempValue = minValue[j];
minValue[j] = minValue[j-1];
minValue[j-1] = tempValue;
n = className[j];
className[j] = className[j-1];
className[j-1] = n;
}
}
}
for (i=this->paramK; i<numSample; i++)
{
distence = 0;
for (j=0; j<numFeature; j++)
{
tempValue = sampleValue[j]-this->sdata.xdata[i*numFeature+j];
distence += tempValue * tempValue;
}
if (distence < minValue[this->paramK-1])
{
minValue[this->paramK-1] = distence;
className[this->paramK-1] = this->sdata.ydata[i];
//重新排序
for (j=this->paramK-1; j>0; j--)
{
if (minValue[j] < minValue[j-1])
{
tempValue = minValue[j];
minValue[j] = minValue[j-1];
minValue[j-1] = tempValue;
n = className[j];
className[j] = className[j-1];
className[j-1] = n;
}
}
}
}
for (i=0; i<numClass; i++)
value[i] = 0;
for (i=0; i<this->paramK; i++)
value[className[i]] ++;
max = 0;
for (i=0; i<numClass; i++)
{
if (value[i]>max)
{
max = value[i];
index = i;
}
}
if (index == cnum)
this->resultRight[cnum] += 1;
}
}
return ret;
}
/* Estimate the classifier's accuracy with 4-fold cross-validation on the training data.
   If no training data is loaded (sdata.ydata == NULL), fileName is loaded with train();
   otherwise fileName is ignored and the already-loaded data is used.
   The samples are split, in stored order, into 4 contiguous folds of nearly equal size
   that together cover every sample. Each fold is classified in turn using the other
   folds as the training set, with the same K-nearest-neighbour majority vote as test().
   Returns the fraction of samples classified correctly, or 0 if the data could not be
   loaded, there are no samples, or paramK < 1. */
DOUBLE CKnn::crossValidation(char* fileName)
{
	const int fold = 4;
	if (this->sdata.ydata == NULL && (!this->train(fileName) || this->sdata.ydata == NULL))
		return 0;

	const int numClass = this->sdata.numClass;
	const int numFeature = this->sdata.numFeature;
	const int numTotalSample = this->sdata.numSample;
	const int k = this->paramK;
	if (k < 1 || numTotalSample < 1)
		return 0;

	vector<DOUBLE> sampleValue(numFeature);
	vector<DOUBLE> minValue(k);   // squared distances of the nearest samples, ascending
	vector<int> className(k);     // class indices of those samples
	vector<int> votes(numClass);
	int numCorrect = 0;

	for (int testFold = 0; testFold < fold; testFold++)
	{
		// These boundaries spread the numTotalSample % fold leftover samples over the
		// folds, so every sample is tested exactly once.
		const int testBegin = testFold * numTotalSample / fold;
		const int testEnd = (testFold + 1) * numTotalSample / fold;
		for (int t = testBegin; t < testEnd; t++)
		{
			for (int j = 0; j < numFeature; j++)
				sampleValue[j] = this->sdata.xdata[t * numFeature + j];
			const int cnum = this->sdata.ydata[t];

			// Collect up to k nearest training samples (all folds except the test fold).
			// On equal distance the earlier sample stays first.
			int filled = 0;
			for (int s = 0; s < numTotalSample; s++)
			{
				if (s >= testBegin && s < testEnd)
					continue; // part of the test fold, not training data
				DOUBLE dist = 0;
				for (int j = 0; j < numFeature; j++)
				{
					DOUBLE diff = sampleValue[j] - this->sdata.xdata[s * numFeature + j];
					dist += diff * diff;
				}
				int slot;
				if (filled < k)
					slot = filled++;
				else if (dist < minValue[k - 1])
					slot = k - 1;
				else
					continue;
				while (slot > 0 && dist < minValue[slot - 1])
				{
					minValue[slot] = minValue[slot - 1];
					className[slot] = className[slot - 1];
					slot--;
				}
				minValue[slot] = dist;
				className[slot] = this->sdata.ydata[s];
			}

			// Majority vote among the neighbours actually found; lowest class index wins ties.
			for (int c = 0; c < numClass; c++)
				votes[c] = 0;
			for (int i = 0; i < filled; i++)
				votes[className[i]]++;
			int best = -1;
			int bestVotes = 0;
			for (int c = 0; c < numClass; c++)
			{
				if (votes[c] > bestVotes)
				{
					bestVotes = votes[c];
					best = c;
				}
			}
			if (best == cnum)
				numCorrect++;
		}
	}
	return static_cast<DOUBLE>(numCorrect) / static_cast<DOUBLE>(numTotalSample);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -