⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 classificationknn.cpp

📁 该程序包实现了几个常用的模式识别分类器算法
💻 CPP
字号:
/* CKnn类实现K近邻分类器算法(KNN),其中“paramK”是指算法中的参数K,默认值为1(最近邻分类器)。
   使用方法如下:创建一个对象,然后设置paramK参数,调用train方法训练分类器,调用test方法进行分类。 */
#include "stdlib.h"
#include <iostream>
#include <string.h>
#include <fstream>
#include <math.h>
#include <string>
#include <vector>
using namespace std;
#include "global.h"

/* Default constructor.
   Starts with k = 1 (plain nearest-neighbour rule). The per-class result
   counters are not allocated yet; test() allocates them. */
CKnn::CKnn()
{
	this->resultRight = NULL;
	this->resultTotal = NULL;
	this->paramK = 1;
}

/* Destructor: releases the per-class result counters allocated by test().
   NOTE(review): this file shows no copy constructor or copy assignment for
   CKnn. If the header does not define or disable them, copying a CKnn would
   delete these arrays twice. Confirm against the class declaration. */
CKnn::~CKnn()
{
	// delete[] on a null pointer is a no-op, so no NULL check is needed.
	delete[] this->resultRight;
	delete[] this->resultTotal;
}

/* Trains the classifier from the sample file `fileName`.
   The samples are loaded into sdata and then normalised via
   sdata.normalize() (presumably per-feature min-max scaling, given how test()
   rescales its inputs; verify in the data class).
   Returns true on success, or false if the file could not be read or is
   malformed. */
bool CKnn::train(char* fileName)
{
	// Bail out immediately if the data file cannot be parsed.
	if (!this->sdata.readFile(fileName))
		return false;

	this->sdata.normalize();
	return true;
}

/* 使用分类器进行分类,输入测试样本的文件名称,返回true表示程序执行正常,false表示程序执行错误。 */
bool CKnn::test(char* fileName)
{
	bool ret = true;
	ifstream ifs(fileName);
	int bufSize = 1024;
	char* buf=new char[bufSize];
	char* temp=new char[bufSize];
	int numClass = this->sdata.numClass;
	int numFeature = this->sdata.numFeature;
	int numSample = this->sdata.numSample;
	DOUBLE* sampleValue=new DOUBLE[numFeature];
	int* value = new int[numClass];
	int n=0;
	int i, j;
	int cnum;
	int max, index;
	DOUBLE distence, tempValue;
	DOUBLE* minValue = new DOUBLE[this->paramK];
	int* className = new int[this->paramK];

	if (NULL==this->resultTotal)
		this->resultTotal = new DOUBLE[numClass];
	if (NULL==this->resultRight)
		this->resultRight = new DOUBLE[numClass];
	for (i=0; i<numClass; i++)
	{
		this->resultTotal[i] = 0;
		this->resultRight[i] = 0;
	}
	while (ifs.good()) {
		ifs.getline(buf, bufSize);
		n=0;
		while (*(buf+n)!=',' && n<bufSize)
		{
			*(temp+n)=*(buf+n);
			n++;
		}
		*(temp+n)=0;
		cnum=this->sdata.searchClassName(temp);
		if (cnum>-1)
		{
			this->resultTotal[cnum] += 1;
			for (j=0; j<numFeature; j++)
			{
				n++;
				int n2=n;
				while (*(buf+n)!=',' && n<bufSize)
				{
					*(temp+n-n2)=*(buf+n);
					n++;
				}
				*(temp+n-n2)=0;
				sampleValue[j]=atof(temp);
				if (this->sdata.maxValue[j] > this->sdata.minValue[j])
				{
					*(sampleValue+j) = (*(sampleValue+j) - this->sdata.minValue[j])/(this->sdata.maxValue[j] - this->sdata.minValue[j]);
				}
				else
					*(sampleValue+j) = 0.5;
			}
			for (i=0; i<this->paramK; i++)
			{
				distence = 0;
				for (j=0; j<numFeature; j++)
				{
					tempValue = sampleValue[j]-this->sdata.xdata[i*numFeature+j];
					distence += tempValue * tempValue;
				}
				minValue[i] = distence;
				className[i] = this->sdata.ydata[i];
			}
			//从小到大排序
			for (i=0; i<this->paramK-1; i++)
			{
				for (j=this->paramK-1; j>i; j--)
				{
					if (minValue[j] < minValue[j-1])
					{
						tempValue = minValue[j];
						minValue[j] = minValue[j-1];
						minValue[j-1] = tempValue;
						n = className[j];
						className[j] = className[j-1];
						className[j-1] = n;
					}
				}
			}
			for (i=this->paramK; i<numSample; i++)
			{
				distence = 0;
				for (j=0; j<numFeature; j++)
				{
					tempValue = sampleValue[j]-this->sdata.xdata[i*numFeature+j];
					distence += tempValue * tempValue;
				}
				if (distence < minValue[this->paramK-1])
				{
					minValue[this->paramK-1] = distence;
					className[this->paramK-1] = this->sdata.ydata[i];
					//重新排序
					for (j=this->paramK-1; j>0; j--)
					{
						if (minValue[j] < minValue[j-1])
						{
							tempValue = minValue[j];
							minValue[j] = minValue[j-1];
							minValue[j-1] = tempValue;
							n = className[j];
							className[j] = className[j-1];
							className[j-1] = n;
						}
					}
				}
			}
			for (i=0; i<numClass; i++)
				value[i] = 0;
			for (i=0; i<this->paramK; i++)
				value[className[i]] ++;
			max = 0;
			for (i=0; i<numClass; i++)
			{
				if (value[i]>max)
				{
					max = value[i];
					index = i;
				}
			}
			if (index == cnum)
				this->resultRight[cnum] += 1;
		}
	}
	return ret;
}

/* Estimates the accuracy of the k-NN classifier by 4-fold cross-validation on
   the training data.
   If no training data is loaded yet (sdata.ydata is NULL), train(fileName) is
   called first. Otherwise fileName is ignored and the loaded data is used.

   The samples are split, in their stored order, into 4 contiguous folds whose
   sizes differ by at most one, so every sample is held out exactly once. Each
   held-out sample is classified by a majority vote of its paramK nearest
   neighbours among the other folds. Distance is squared Euclidean on the
   stored xdata, and ties go to the lowest class index. When fewer than paramK
   neighbours exist, only the neighbours actually found vote.

   Returns the fraction of held-out samples classified correctly, in [0, 1].
   Returns 0 if training fails, there are no samples, or paramK < 1. */
DOUBLE CKnn::crossValidation(char* fileName)
{
	const int fold = 4;
	if (this->sdata.ydata == NULL && !this->train(fileName))
		return 0;

	const int numClass = this->sdata.numClass;
	const int numFeature = this->sdata.numFeature;
	const int numTotalSample = this->sdata.numSample;
	const int k = this->paramK;
	if (this->sdata.ydata == NULL || numTotalSample <= 0 || numClass <= 0 || k < 1)
		return 0;

	std::vector<DOUBLE> sample(numFeature > 0 ? numFeature : 0);
	std::vector<DOUBLE> bestDist(k);
	std::vector<int> bestClass(k);
	std::vector<int> votes(numClass);
	int evaluated = 0;
	int correct = 0;

	for (int testFold = 0; testFold < fold; testFold++)
	{
		// [testBegin, testEnd) is the held-out fold. The integer division spreads
		// the numTotalSample % fold remainder, so no sample is left out.
		const int testBegin = testFold * numTotalSample / fold;
		const int testEnd = (testFold + 1) * numTotalSample / fold;
		for (int t = testBegin; t < testEnd; t++)
		{
			for (int j = 0; j < numFeature; j++)
				sample[j] = this->sdata.xdata[t * numFeature + j];
			const int truth = this->sdata.ydata[t];

			// Keep the k closest training samples in bestDist[0..filled), sorted
			// ascending. A full list only accepts strictly closer samples, so on
			// equal distances the earlier sample is kept.
			int filled = 0;
			for (int s = 0; s < numTotalSample; s++)
			{
				if (s >= testBegin && s < testEnd)
					continue; // part of the held-out fold
				DOUBLE dist = 0;
				for (int j = 0; j < numFeature; j++)
				{
					const DOUBLE diff = sample[j] - this->sdata.xdata[s * numFeature + j];
					dist += diff * diff;
				}
				int slot;
				if (filled < k)
					slot = filled++;
				else if (dist < bestDist[k - 1])
					slot = k - 1;
				else
					continue;
				while (slot > 0 && dist < bestDist[slot - 1])
				{
					bestDist[slot] = bestDist[slot - 1];
					bestClass[slot] = bestClass[slot - 1];
					--slot;
				}
				bestDist[slot] = dist;
				bestClass[slot] = this->sdata.ydata[s];
			}

			// Majority vote over the neighbours actually found. Unfilled slots must
			// not vote, and out-of-range labels are ignored.
			votes.assign(numClass, 0);
			for (int i = 0; i < filled; i++)
			{
				const int c = bestClass[i];
				if (c >= 0 && c < numClass)
					votes[c]++;
			}
			int predicted = -1;
			int predictedVotes = 0;
			for (int c = 0; c < numClass; c++)
			{
				if (votes[c] > predictedVotes)
				{
					predictedVotes = votes[c];
					predicted = c;
				}
			}
			evaluated++;
			if (predicted == truth)
				correct++;
		}
	}
	if (evaluated == 0)
		return 0;
	return static_cast<DOUBLE>(correct) / static_cast<DOUBLE>(evaluated);
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -