⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 knn.cpp

📁 knn和Naive Bayes算法实现，两个实现在一起，是数据挖掘和机器学习中的内容。
💻 CPP
字号:
#include "stdafx.h"

#pragma warning( disable : 4786 )

#include "windows.h"
#include "knn.h"
#include "math.h"
#include "assert.h"

#include <algorithm>
#include <iostream>
#include <utility>


// Builds an untrained classifier: no training documents or labels are held yet,
// and k defaults to 1 (a single nearest neighbour decides the category).
CKNN::CKNN()
{
	lTrainTotal = 0;
	k = 1;
	vTargetTrain.clear();
	docs_train = NULL;
}

// Releases the DOC array that Train() obtained from my_malloc. It is released with
// free(), so my_malloc is presumably a malloc wrapper -- confirm.
// Only the array itself is released. Its DOC entries were copied from the CDocList
// passed to Train(), and whatever their content pointers reference is not freed here.
// NOTE(review): if CKNN can be copied (header not visible), two copies would free
// the same array; consider disabling copy construction/assignment.
CKNN::~CKNN()
{
	free(docs_train);	// free(NULL) is a no-op, so an untrained model needs no special case
	vTargetTrain.clear();
}

// Stores the training set held in DocList as the classifier's model.
// For every entry of DocList.vSDoc, the matching DOC record is located through
// DocList.mapDocId_Pos and copied into docs_train. Its category set goes into the
// same slot of vTargetTrain. lTrainTotal ends up equal to the number of entries.
// Calling Train() again replaces the previous model; the old array is released
// first so retraining does not leak it.
void CKNN::Train( CDocList& DocList )
{
	if (docs_train != NULL)
	{
		free(docs_train);
		docs_train = NULL;
	}

	size_t lDocTotal = DocList.vSDoc.size();
	docs_train = (DOC *)my_malloc(sizeof(DOC) * lDocTotal );
	vTargetTrain.resize( lDocTotal );

	lTrainTotal = 0;
	for (size_t l=0; l<lDocTotal; l++)
	{
		// NOTE(review): operator[] inserts a default position when the id is absent
		// from mapDocId_Pos, silently pairing the label with the wrong document.
		long lPos = DocList.mapDocId_Pos[ DocList.vSDoc[l].lDocId ];
		// Presumably a shallow copy: the content pointer stays shared with DocList.
		docs_train[l] = DocList.docs[lPos];
		vTargetTrain[l] = DocList.vSDoc[l].setDocCat;
		lTrainTotal++;
	}
}

// One candidate neighbour: its similarity to the test document and the
// categories its training document is labelled with.
struct K_NODE {
	float fSim;
	std::set<int> setDocCat;
};

// Strict-weak-ordering predicate that orders K_NODEs by descending similarity.
// Arguments are taken by const reference so the std::set members are not copied
// on every comparison. It does not derive from std::binary_function, which was
// deprecated in C++11 and removed in C++17.
struct CKNode_More_Than_by_Sim
{
	bool operator()(const K_NODE& x, const K_NODE& y) const { return x.fSim > y.fSim; }
};

int CKNN::ClassifyDocs( DOC& test_doc )
{  
	assert((k>0) && (lTrainTotal>0));

	vector<K_NODE> vecKNode;
	float fSim;
	K_NODE knode;

	for (long lTrainIndex = 0;lTrainIndex<lTrainTotal;lTrainIndex++)
	{
		fSim = CalcSim(docs_train[lTrainIndex].content, test_doc.content);

		knode.fSim = fSim;
		knode.setDocCat = vTargetTrain[lTrainIndex];
		//knode.iIndex = lTrainIndex;

		vecKNode.push_back(knode);

		sort(vecKNode.begin(),vecKNode.end(),CKNode_More_Than_by_Sim());

		if (vecKNode.size()>k)
			vecKNode.erase(vecKNode.begin() + vecKNode.size()-1);
	}

	//for(int i=0;i<vecKNode.size();i++)
	//{
	//	cout << vecKNode[i].iCat << ":" << vecKNode[i].fSim << " ";
	//}
	//cout << endl;

	map<int,float> mapCat2Score;
	map<int,float>::iterator it;

	for (unsigned int i=0;i<vecKNode.size();i++)
	{
		set<int> setCat = vecKNode[i].setDocCat;

		set<int>::iterator itSet;

		for ( itSet=setCat.begin(); itSet!=setCat.end(); itSet++ )
		{
			int iCurCat = *itSet;

			if (mapCat2Score.find(iCurCat)==mapCat2Score.end())
				mapCat2Score[iCurCat] = 1;
			else
				mapCat2Score[iCurCat] ++;
		}
	}

	int iTrueCat = -1;
	float fScore = 0.0;
	for (it = mapCat2Score.begin();it!=mapCat2Score.end();it++)
	{
		if (iTrueCat==-1)
			iTrueCat = it->first;

		if (it->second>fScore) {
			fScore = it->second;
			iTrueCat = it->first;
		}
	}

	mapCat2Score.clear();
	vecKNode.clear();

	return iTrueCat;
}

// Classifies every line of sVectorFile and writes one result line per input
// line to sResultFile:
//   "<DocId>\t<category>"  when ReadDoc() parsed the line (returned > 0),
//   "<DocId>"              otherwise.
// A running count of classified documents is printed to cout.
// If either file cannot be opened, a message goes to cerr and nothing is processed.
void CKNN::ClassifyDocs( string sVectorFile, string sResultFile )
{
	ofstream ofResult( sResultFile.c_str() );
	ifstream ifVector( sVectorFile.c_str() );
	if ( !ifVector ) {
		cerr << "cannot open vector file: " << sVectorFile << '\n';
		return;
	}
	if ( !ofResult ) {
		cerr << "cannot open result file: " << sResultFile << '\n';
		return;
	}

	string sLine;
	int iCount = 0;
	while ( getline( ifVector, sLine ) ) {
		DOC test_doc;
		if ( ReadDoc( sLine, test_doc ) <= 0 ) {
			// NOTE(review): content is not freed on this path; confirm whether
			// ReadDoc allocates it when parsing fails.
			ofResult << test_doc.DocId << '\n';
			continue;
		}

		int iCat = ClassifyDocs( test_doc );
		// '\n' rather than endl: flushing after every document is needless I/O;
		// the stream is flushed when it is destroyed.
		ofResult << test_doc.DocId << "\t" << iCat << '\n';

		free( test_doc.content );

		iCount++;
		cout << "documents classified: " << iCount << "\r";
	}
	// ifVector and ofResult are closed (and ofResult flushed) by their destructors.
}

// Cosine similarity of two sparse vectors: <p,q> / (|p| * |q|), using sprod_ss
// for the inner products.
// If either vector has zero norm, the similarity is defined as 0. Dividing by
// zero would yield NaN/inf, and NaN breaks the strict weak ordering that the
// neighbour sort in ClassifyDocs relies on.
float CKNN::CalcSim(WORDITEM *p, WORDITEM *q)
{
	double dNormP2 = sprod_ss(p,p);	// squared norm of p
	double dNormQ2 = sprod_ss(q,q);	// squared norm of q
	if (dNormP2 <= 0 || dNormQ2 <= 0)
		return 0.0f;

	float f = sprod_ss(p,q);
	f /= sqrt(dNormP2);
	f /= sqrt(dNormQ2);

	return f;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -