⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 pr_exerciseview.cpp

📁 模式识别中的K近邻法和快速K近邻法的VC++实现
💻 CPP
字号:
// PR_exerciseView.cpp : implementation of the CPR_exerciseView class
//
// With precompiled headers MSVC skips everything that precedes "stdafx.h",
// so it must be the first #include; headers placed above it are ignored.
#include "stdafx.h"

#include <limits.h>
#include <math.h>
#include <stdlib.h>
#include <time.h>
#include <vector>

#include "PR_exercise.h"

#include "PR_exerciseDoc.h"
#include "PR_exerciseView.h"

#include "TestSet.h"
#include "KNearestCls.h"
#include "FastKNearestCls.h"
#include "sampleSet.h"

#ifdef _DEBUG
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif

#define MAXDIMENTION 20
#define MAXDATANUM 200
/////////////////////////////////////////////////////////////////////////////
// CPR_exerciseView

// MFC run-time class support: lets MFC create this view dynamically
// (presumably through the document template registered by the application).
IMPLEMENT_DYNCREATE(CPR_exerciseView, CEditView)

// Command routing: the three menu commands below go to the data-generation and
// classification handlers of this view; printing is delegated to CEditView.
BEGIN_MESSAGE_MAP(CPR_exerciseView, CEditView)
	//{{AFX_MSG_MAP(CPR_exerciseView)
	ON_COMMAND(ID_DATA_CREATE, OnDataCreate)
	ON_COMMAND(ID_KNEAREST, OnKnearest)
	ON_COMMAND(ID_FAST_KNEAREST, OnFastKnearest)
	//}}AFX_MSG_MAP
	// Standard printing commands (Print and Print-Direct share CEditView::OnFilePrint)
	ON_COMMAND(ID_FILE_PRINT, CEditView::OnFilePrint)
	ON_COMMAND(ID_FILE_PRINT_DIRECT, CEditView::OnFilePrint)
	ON_COMMAND(ID_FILE_PRINT_PREVIEW, CEditView::OnFilePrintPreview)
END_MESSAGE_MAP()

/////////////////////////////////////////////////////////////////////////////
// CPR_exerciseView construction/destruction


// Construct an empty view: no training data and no classifier state yet.
CPR_exerciseView::CPR_exerciseView()
{
	// Training-set description, filled in by OnDataCreate.
	classNum = 0;
	dimen = 0;
	dataTotal = 0;
	dataSamples = "";

	// Classifier state, set by the K-NN handlers.
	nearestClass = 0;
	kNearestNum = 0;

	// The destructor calls delete[] on testSample whenever it is non-null, so
	// it must start out null; otherwise destroying a view that never generated
	// a test sample would free an indeterminate pointer.
	testSample = NULL;
}

// Destroy the view, releasing the most recently generated test sample.
CPR_exerciseView::~CPR_exerciseView()
{
	// delete[] on a null pointer is a no-op, so no explicit check is needed.
	delete[] testSample;
	testSample = 0;
}

// Adjust the window styles before creation so that the edit control wraps
// long lines instead of scrolling horizontally.
BOOL CPR_exerciseView::PreCreateWindow(CREATESTRUCT& cs)
{
	// Let CEditView fill in its default window class and styles first.
	const BOOL baseResult = CEditView::PreCreateWindow(cs);

	// Clearing horizontal auto-scroll and the horizontal scrollbar enables
	// word-wrapping in the edit control.
	const LONG horizontalScrollBits = ES_AUTOHSCROLL | WS_HSCROLL;
	cs.style &= ~horizontalScrollBits;

	return baseResult;
}

/////////////////////////////////////////////////////////////////////////////
// CPR_exerciseView drawing

void CPR_exerciseView::OnDraw(CDC* pDC)
{
	CPR_exerciseDoc* pDoc = GetDocument();
	ASSERT_VALID(pDoc);
	// TODO: add draw code for native data here

}

/////////////////////////////////////////////////////////////////////////////
// CPR_exerciseView printing

// Print preparation is handled entirely by CEditView.
BOOL CPR_exerciseView::OnPreparePrinting(CPrintInfo* pInfo)
{
	const BOOL prepared = CEditView::OnPreparePrinting(pInfo);
	return prepared;
}

// This view allocates nothing extra for printing; forward to CEditView.
void CPR_exerciseView::OnBeginPrinting(CDC* pDC, CPrintInfo* pInfo)
{
	CEditView::OnBeginPrinting(pDC, pInfo);
}

// Nothing of our own to clean up after printing; forward to CEditView.
void CPR_exerciseView::OnEndPrinting(CDC* pDC, CPrintInfo* pInfo)
{
	CEditView::OnEndPrinting(pDC, pInfo);
}

/////////////////////////////////////////////////////////////////////////////
// CPR_exerciseView diagnostics

#ifdef _DEBUG
// Debug-only validity check: this class adds no members of its own to verify.
void CPR_exerciseView::AssertValid() const
{
	CEditView::AssertValid();
}

// Debug-only dump: only the base-class state is written.
void CPR_exerciseView::Dump(CDumpContext& dc) const
{
	CEditView::Dump(dc);
}

// Debug version of GetDocument (the non-debug version is inline): checks that
// the attached document really is a CPR_exerciseDoc before downcasting.
CPR_exerciseDoc* CPR_exerciseView::GetDocument()
{
	CDocument* const pAttached = m_pDocument;
	ASSERT(pAttached->IsKindOf(RUNTIME_CLASS(CPR_exerciseDoc)));
	return static_cast<CPR_exerciseDoc*>(pAttached);
}
#endif //_DEBUG

/////////////////////////////////////////////////////////////////////////////
// CPR_exerciseView message handlers

// Generate the training data set according to the user's settings.
//
// Shows the CTestSet dialog. On OK it does the following:
// - Records the class count, the number of samples per class and the
//   dimension.
// - Fills `samples` so that class i occupies rows
//   i*m_dataTotal .. (i+1)*m_dataTotal-1.
// - Sets every feature of a class-i sample to
//   float(rand())/RAND_MAX + 0.5 + 5*i, i.e. a value in [0.5+5*i, 1.5+5*i].
// - Renders the same values as tab-separated text, one sample per line, into
//   dataSamples and shows it in a message box.
// NOTE(review): nothing here checks that m_clsNum*m_dataTotal rows and m_dim
// columns fit into `samples` (declared in the header). Confirm the dialog
// enforces those limits.
void CPR_exerciseView::OnDataCreate() 
{
	CTestSet ts;
	if(IDOK != ts.DoModal())
		return;

	classNum = ts.m_clsNum;
	dataTotal = ts.m_dataTotal;
	dimen = ts.m_dim;

	// Rebuild the text dump from scratch. Otherwise every new data set would be
	// appended to, and displayed together with, all earlier ones.
	dataSamples = "";

	// _gcvt with 9 significant digits can produce e.g. "10.7345676" (10 chars)
	// plus the terminating NUL, or a longer exponent form, so 10 bytes overflow.
	char buffer[32];

	srand( (unsigned)time( NULL ) );

	for(int cls = 0; cls < ts.m_clsNum; cls++)
	{
		for(int j = 0; j < ts.m_dataTotal; j++)
		{
			const int row = j + cls * ts.m_dataTotal;
			for(int k = 0; k < ts.m_dim; k++)
			{
				samples[row][k] = float(rand())/RAND_MAX + 0.5 + (5 * cls);
				_gcvt(samples[row][k], 9, buffer);
				dataSamples += buffer;
				dataSamples += '\t';
			}
			dataSamples += '\n';
		}
	}
	MessageBox(dataSamples);
}

//K近邻分类器实现
void CPR_exerciseView::OnKnearest() 
{
	CKNearestCls kNearest;
	
	if(IDOK == kNearest.DoModal())
	{
		kNearestNum = kNearest.m_kNearestNum;

		int i,j,k,total,refusePoint;

		refusePoint = kNearest.m_refusePoint>((kNearest.m_kNearestNum+1)/classNum)?
			kNearest.m_refusePoint:((kNearest.m_kNearestNum+1)/classNum);
		
		int count = 0;
		float distance = 0.0;

		float* kNearestDist = new float[kNearestNum];
		for (i = 0; i < kNearestNum; i++)
			kNearestDist[i] = float(INT_MAX);

		int* nearestCls = new int[kNearestNum];
		for (i = 0; i < kNearestNum; i++)
			nearestCls[i] = -1;

		//随机生成测试样本
		GenerateTestSample();
		
		//计算K近邻值
		total = classNum * dataTotal;
		for (i = 0 ; i < total; i++)
		{
			for (j = 0 ; j < dimen; j++)
				//distance += fabs(samples[i][j] - testSample[j]);
				distance += (samples[i][j] - testSample[j])*(samples[i][j] - testSample[j]);

			for (j = 0 ; j < kNearestNum; j++)
			{

				if(distance < kNearestDist[j])
				{
					for(k = (kNearestNum - 1); k > j; k--)
					{
						kNearestDist[k] = kNearestDist[k-1];
						nearestCls[k] = nearestCls[k-1];
					}
					kNearestDist[j] = distance;
					nearestCls[j] = i/dataTotal;
					distance = 0.0;
					break;
				}				
			}
			distance = 0.0;
		}

		//输出最终分类结果
		OutputClassifyResult(nearestCls,refusePoint);
	
		//回收内存空间
		if(kNearestDist)
		{
			delete[] kNearestDist;
			kNearestDist = 0;
		}
		if(nearestCls)
		{
			delete[] nearestCls;
			nearestCls = 0;
		}

	}
}

void CPR_exerciseView::OnFastKnearest() 
{
	// TODO: Add your command handler code here
	
	CFastKNearestCls fastKNearest;
	
	if(IDOK == fastKNearest.DoModal())
	{
		kNearestNum = fastKNearest.m_nearestNum;

		int i,j,k,m,n,refusePoint;
		float dis = 0.0;

		//判断用户设置的拒绝参数,如果小于近邻数与训练集类别数的比值,
		//则默认为近邻数与训练集类别数的比值。
		refusePoint = fastKNearest.m_refusePoint>((fastKNearest.m_nearestNum+1)/classNum)?
			fastKNearest.m_refusePoint:((fastKNearest.m_nearestNum+1)/classNum);
		
		//当前x的K个最近邻所属的类别
		int* pCurrentlyNearestCls = new int[kNearestNum]; 
		for (i = 0; i < kNearestNum; i++)
			pCurrentlyNearestCls[i] = -1;
		
		//当前x与K个最近邻的距离:B,按升序存储于数组中。
		float* kNearestDist = new float[kNearestNum];
		for (i = 0; i < kNearestNum; i++)
			kNearestDist[i] = float(INT_MAX);

		//x与各类样本均值的距离:D(x,Mp)
		float* disWithCls = new float[classNum];
		for (i = 0; i < classNum; i++)
			disWithCls[i] = 0.0;

		//各类样本集中样本与它的均值点的最远距离
		float* distance = new float[classNum];
		for (i = 0; i < classNum; i++)
			distance[i] = 0.0;

		//当前x的K个最近邻
		/*float **pCurrentlyNearest = new float[][dimen];  
		for(i = 0; i < kNearestNum; i++)
			for(j = 0; j < dimen; j++)
				pCurrentlyNearest[i][j] = float (INT_MAX);

		//各类样本集的均值点
		float** means;
		for (i = 0; i < classNum; i++)
			for (j = 0; j < dimen; j++)
				means[i][j] = float(INT_MAX);

	//各类样本集中样本与它的均值点的距离
		float** allDistance;
		for (i = 0; i < classNum; i++)
			for (j = 0; j < dataTotal; j++)
				allDistance[i][j] = float(INT_MAX);


		SamplesDecompose(distance,means,allDistance);*/

		//当前x的K个最近邻
		float pCurrentlyNearest[MAXDIMENTION][MAXDIMENTION];  
		for(i = 0; i < kNearestNum; i++)
			for(j = 0; j < dimen; j++)
				pCurrentlyNearest[i][j] = float (INT_MAX);

		//各类样本集的均值点
		float means[MAXDIMENTION][MAXDIMENTION];
		for (i = 0; i < classNum; i++)
			for (j = 0; j < dimen; j++)
				means[i][j] = float(INT_MAX);

	//各类样本集中样本与它的均值点的距离
		float allDistance[MAXDIMENTION][MAXDATANUM];
		for (i = 0; i < classNum; i++)
			for (j = 0; j < dataTotal; j++)
				allDistance[i][j] = float(INT_MAX);

		//临时数组用于存储各维参数值的总和,用于计算各类的均值
		float* tempTotal = new float[dataTotal];
		for(i = 0; i < dataTotal; i++)
			tempTotal[i] = 0;

		//对每一个类别,分别计算它们的样本均值、类中样本到均值点的最大距离
		//以及各个样本到均值点的距离。
		for(i = 0; i < classNum; i++)
		{
			for(j = i*dataTotal; j < (i+1)*dataTotal; j++)
				for(k = 0; k < dimen; k++)
					tempTotal[k] += samples[j][k];
			//计算样本均值:Mp
			for(k = 0; k < dimen; k++)
				means[i][k] = tempTotal[k]/dataTotal;
			//计算各个样本到均值点的距离:D(xi,Mp),并求出最大距离:Rp
			for(j = i*dataTotal; j < (i+1)*dataTotal; j++)
			{
				for(k = 0; k < dimen; k++)
					dis += (samples[j][k]-means[i][k])*(samples[j][k]-means[i][k]);
					
				distance[i] = distance[i]<dis? dis:distance[i];
				allDistance[i][j - i*dataTotal ] = dis;
				dis = 0.0;
			}
					
			for(j = 0; j < dataTotal; j++)
				tempTotal[j] = 0;
		}
		//随机生成一个测试样本
		GenerateTestSample();

		//计算测试样本到各个类别的均值点的距离:D(x,Mp)
		for(i = 0; i <classNum; i++)
			for(j = 0; j < dimen; j++)
				disWithCls[i] += (testSample[j]-means[i][j])*
					(testSample[j]-means[i][j]);

		//对所有的类利用规则1检验
		for(i = 0; i < classNum; i++)
			//if D(x,Mp)<(最大的B + Rp)
			if(disWithCls[i] < (distance[i] + kNearestDist[kNearestNum-1]))
			{	
				//对该类中的全部样本利用规则2检验
			 	for(j = i*dataTotal; j < (i+1)*dataTotal; j++)
					//if D(x,Mp)<( D(xi,Mp)+ Rp)
					if(disWithCls[i] < (allDistance[i][j - i*dataTotal] 
						+ kNearestDist[kNearestNum-1]))
					{
						dis = 0.0;
						//计算D(x,xi)
						for(k = 0; k < dimen; k++)
							dis += (samples[j][k]-testSample[k])*
							(samples[j][k]-testSample[k]);
				
						for (m = 0 ; m < kNearestNum; m++)
						{
							//如果D(x,xi)<B
							if(dis < kNearestDist[m])
							{
								for(n = (kNearestNum - 1); n > m; n--)
								{
									kNearestDist[n] = kNearestDist[n-1];
									pCurrentlyNearestCls[n] = pCurrentlyNearestCls[n-1];
								}
								kNearestDist[m] = dis;
								pCurrentlyNearestCls[m] = i;
								dis = 0.0;
								break;
							}
						}
					}
			}
		//输出分类结果	
		OutputClassifyResult(pCurrentlyNearestCls,refusePoint);

		//回收内存空间
		if(pCurrentlyNearestCls)
		{
			delete[] pCurrentlyNearestCls;
			pCurrentlyNearestCls = 0;
		}
		if(kNearestDist)
		{
			delete[] kNearestDist;
			kNearestDist = 0;
		}
		if(disWithCls)
		{
			delete[] disWithCls;
			disWithCls = 0;
		}
		if(distance)
		{
			delete[] distance;
			distance = 0;
		}
		if(tempTotal)
		{
			delete[] tempTotal;
			tempTotal = 0;
		}

	}
	
}

// Draw a random test sample and show it in a message box.
//
// Allocates a new array of `dimen` floats and stores it in testSample. Each
// feature is float(rand())/RAND_MAX (in [0,1]) plus a random integer offset
// in [0, 5*(classNum-1)). With a single class the offset is 0.
// Any array previously held in testSample is not released here. The caller
// is responsible for delete[]-ing the buffer when it is no longer needed.
void CPR_exerciseView::GenerateTestSample()
{
	CString str = "随机生成测试样本为:\n";
	// Room for _gcvt's 9 significant digits plus sign, point, exponent and NUL
	// (a 10-byte buffer overflowed on values such as "10.7345676").
	char buffer[32];

	testSample = new float[dimen];

	srand( (unsigned)time( NULL ) );

	// rand() % 0 is undefined behaviour, so only draw an offset when the range
	// is non-empty (classNum > 1).
	const int offsetRange = 5 * (classNum - 1);

	for (int i = 0 ; i < dimen; i++)
	{
		const float fraction = float(rand())/RAND_MAX;
		const int offset = offsetRange > 0 ? rand() % offsetRange : 0;
		testSample[i] = fraction + offset;

		_gcvt(testSample[i], 9, buffer);
		str += buffer;
		str += '\t';
	}
	MessageBox(str);
}

// Sample-set decomposition. The training set is split one level deep, into
// one subset per generated class; a better decomposition algorithm is left
// for future work.
//
// For each class p it computes:
//   means[p][k]       - mean Mp of class p, k < dimen
//   distance[p]       - radius Rp = max over the class of D(xi,Mp)
//   allDistance[p][j] - D(xi,Mp) of the j-th sample of class p, j < dataTotal
// Distances are Euclidean (not squared), as the fast K-NN pruning rules
// require. The caller must supply arrays of at least these sizes. Nothing is
// written when there is no training data.
void CPR_exerciseView::SamplesDecompose(float* &distance, float** &means, float** &allDistance)
{
	if(classNum <= 0 || dataTotal <= 0)
		return;

	for(int p = 0; p < classNum; p++)
	{
		const int first = p * dataTotal;

		// Mean of class p, feature by feature. The sums are kept locally; the
		// former version wrote them through an uninitialised pointer.
		for(int k = 0; k < dimen; k++)
		{
			float sum = 0.0f;
			for(int j = first; j < first + dataTotal; j++)
				sum += samples[j][k];
			means[p][k] = sum / dataTotal;
		}

		// Distance of every member to the mean and the largest one (Rp). The
		// radius is reset here so the result does not depend on caller setup.
		distance[p] = 0.0f;
		for(int j = 0; j < dataTotal; j++)
		{
			float sq = 0.0f;
			for(int k = 0; k < dimen; k++)
			{
				const float diff = samples[first + j][k] - means[p][k];
				sq += diff * diff;
			}
			const float d = (float)sqrt(sq);
			allDistance[p][j] = d;	// per-class index, not the global row
			if(d > distance[p])
				distance[p] = d;
		}
	}
}

// Report the K-NN vote and the final decision.
//
// nearestCls holds kNearestNum 0-based class indices. Entries outside
// [0, classNum) are ignored, e.g. the -1 placeholder of an unfilled
// neighbour slot.
//
// The method first shows how many of the K neighbours fall into each class.
// It then picks the class with the most votes (the lowest index wins ties)
// and stores it 1-based in nearestClass. The sample is assigned to that
// class only if its vote count exceeds refusePoint; otherwise it goes to the
// rejection class.
void CPR_exerciseView::OutputClassifyResult(int* &nearestCls, int refusePoint)
{
	// Large enough for any int in base 10; _itoa documents 33 characters plus
	// NUL as its worst case (base 2).
	char buffer[34];

	// Per-class neighbour counts.
	std::vector<int> neighbors(classNum > 0 ? classNum : 0, 0);
	for(int i = 0; i < kNearestNum; i++)
	{
		const int cls = nearestCls[i];
		if(cls >= 0 && cls < classNum)
			neighbors[cls]++;
	}

	CString str = "K个近邻中属于各个类的样本数为:\n";
	for(int i = 0; i < classNum; i++)
	{
		_itoa(neighbors[i], buffer, 10);	// decimal output
		str += buffer;
		str += '\t';
	}
	MessageBox(str);

	// Majority vote; strict '<' keeps the first (lowest-index) class on ties.
	int count = 0;
	for(int i = 0; i < classNum; i++)
	{
		if(count < neighbors[i])
		{
			count = neighbors[i];
			nearestClass = i + 1;
		}
	}

	if(count > refusePoint)
	{
		_itoa(nearestClass, buffer, 10);
		str = "该样本被分类到第 ";
		str += buffer; 
		str += " 类中。";
		MessageBox(str);
	}
	else
	{
		str = "该样本被分类到拒绝类中。";
		MessageBox(str);
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -