⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 textcat.cpp

📁 knn和Naive Bayes算法实现,两个实现在一起,是数据挖掘和机器学习中的内容.
💻 CPP
字号:
#include "stdafx.h"

//#ifdef _DEBUG
//#define new DEBUG_NEW
//#endif
#pragma warning( disable : 4786 )

#include "TextCat.h"
#include "textprocess.h"
#include "knn.h"
#include "naivebayes.h"
#include "doclist.h"


// Scores a prediction file against an answer file; defined after _tmain.
// std::string is spelled out because `using namespace std;` only appears further
// down, so this declaration must not depend on some header leaking `string`.
void EvalPred( std::string, std::string, double&, double&, double&, double& );

// Global MFC application object (MFC console-application convention); MFC itself
// is initialized by the AfxWinInit call in _tmain.
CWinApp theApp;

using namespace std;

// Console entry point.
//  1. Runs CTextProcess::ProcessCorpus over .\train\ and .\test\ (which presumably
//     writes the *_vector.txt files read below — confirm in textprocess).
//  2. Trains one classifier on the training vectors/labels and writes its
//     predictions for the test vectors to .\result.list.
//  3. Scores .\result.list against .\test.doc.label and prints micro/macro F1.
// Returns 0 on success, 1 if MFC fails to initialize.
int _tmain(int argc, TCHAR* argv[], TCHAR* envp[])
{
	// Initialize MFC; stop here rather than running the pipeline on top of a
	// failed framework initialization.
	if (!AfxWinInit(::GetModuleHandle(NULL), NULL, ::GetCommandLine(), 0))
	{
		_tprintf(_T("Fatal Error: MFC initialization failed\n"));
		return 1;
	}

	// Classifier selection: true = kNN on TF-IDF vectors,
	// false = Naive Bayes on TF vectors.
	const bool bUseKNN = true;

	// Number of neighbours used by kNN.
	const int iKNeighbors = 2;

	CTextProcess TextProcess;

	// Process training docs and testing docs.
	// NOTE(review): the meaning of the trailing 1000 is defined by
	// CTextProcess::ProcessCorpus (presumably a feature-count limit) — confirm there.
	TextProcess.ProcessCorpus( ".\\train\\", ".\\test\\", ".\\", "train.doc.label", "test.doc.list", 1000 );

	string sTestCatFile = ".\\test.doc.label";
	string sPredFile = ".\\result.list";
	string sTrainCatFile = ".\\train.doc.label";

	{	// own scope so the training data is released before evaluation
		string sTrainVectorFile = bUseKNN ? ".\\tfidf_train_vector.txt" : ".\\tf_train_vector.txt";
		string sTestVectorFile  = bUseKNN ? ".\\tfidf_test_vector.txt"  : ".\\tf_test_vector.txt";

		CDocList DocList;
		DocList.ReadVector( sTrainVectorFile );		// training document vectors
		DocList.ReadDocList( sTrainCatFile );		// training document labels

		if ( bUseKNN )
		{
			CKNN knn;
			knn.k = iKNeighbors;
			knn.Train( DocList );
			knn.ClassifyDocs( sTestVectorFile, sPredFile );
		}
		else
		{
			CNaiveBayes NB;
			NB.Train( DocList );
			NB.ClassifyDocs( sTestVectorFile, sPredFile );
		}
	}

	double MicroF1 = 0, MacroF1 = 0, MicroAccuracy = 0, MacroAccuracy = 0;

	EvalPred( sTestCatFile, sPredFile, MicroF1, MacroF1, MicroAccuracy, MacroAccuracy );

	cout << "Evaluation Result\tMicroF1=" << MicroF1 << "\tMacroF1=" <<MacroF1 << endl;

	return 0;
}


namespace
{
	// Returns num/den, or 0 when den is 0, so undefined ratios (e.g. a category
	// that is never predicted, or no true positives at all) become 0 instead of NaN.
	double SafeRatio( double num, double den )
	{
		return den > 0 ? num / den : 0.0;
	}

	// Binary contingency table for one category, or for all categories pooled
	// (micro-averaging). Rows are the answer category, columns the predicted one:
	//
	//	---Predicted Category
	//		P		 N
	//		----------
	//	P	|a	|	d|
	//		|---|----|
	//	N	|b	|	c|
	//		----------
	//
	// a: P->P (true positive)     b: N->P (false positive)
	// c: N->N (true negative)     d: P->N (false negative)
	// Counts are doubles so they stay exact and feed the ratios without casts.
	struct SCatContingency
	{
		double a;
		double b;
		double c;
		double d;

		SCatContingency() : a( 0 ), b( 0 ), c( 0 ), d( 0 ) {}

		// Records one (document, category) decision.
		void Add( bool bInAns, bool bInPred )
		{
			if ( bInAns )
			{
				if ( bInPred ) a += 1; else d += 1;
			}
			else
			{
				if ( bInPred ) b += 1; else c += 1;
			}
		}

		double Precision() const { return SafeRatio( a, a + b ); }
		double Recall() const    { return SafeRatio( a, a + d ); }
		double Accuracy() const  { return SafeRatio( a + c, a + b + c + d ); }

		// Harmonic mean of precision and recall; 0 when both are 0 (no true positives).
		double F1() const
		{
			const double p = Precision();
			const double r = Recall();
			return ( p + r > 0 ) ? 2.0 * p * r / ( p + r ) : 0.0;
		}
	};
}

// Evaluates a prediction file against an answer (gold-label) file.
//
// Both files are loaded with ReadIdSetMap<int> into doc-id -> category-id-set maps
// (the file format is defined by that helper). Every document of the answer file is
// scored against every category occurring in the answer file; a document missing
// from the prediction file counts as predicting no category. Documents that appear
// only in the prediction file, and predicted categories that never occur in the
// answer file, are ignored.
//
// Out parameters:
//	MicroF1, MicroAccuracy : computed from the counts pooled over all categories.
//	MacroF1, MacroAccuracy : unweighted mean of the per-category values.
// Any ratio whose denominator is 0 (no true positives, empty answer file, ...) is
// reported as 0. Only a progress message is printed; reporting the scores is left
// to the caller.
void EvalPred( string sAnsFile, string sPredFile, double& MicroF1, double& MacroF1, double& MicroAccuracy, double& MacroAccuracy )
{
	cout << "Evaluating result... " << endl;

	map<int,set<int> > mapDocId2Ans, mapDocId2Pred;
	ReadIdSetMap<int>( mapDocId2Ans, sAnsFile );
	ReadIdSetMap<int>( mapDocId2Pred, sPredFile );

	map<int,set<int> >::iterator itDoc;
	set<int>::iterator itCat;

	// Category universe: union of all answer labels.
	set<int> setAllCat;
	for ( itDoc = mapDocId2Ans.begin(); itDoc != mapDocId2Ans.end(); ++itDoc )
		setAllCat.insert( itDoc->second.begin(), itDoc->second.end() );

	map<int,SCatContingency> mapCatId2Tab;	// per-category counts (macro averages)
	SCatContingency all;					// pooled counts (micro averages)
	set<int> setNoPred;						// stands in for documents absent from the prediction file

	for ( itDoc = mapDocId2Ans.begin(); itDoc != mapDocId2Ans.end(); ++itDoc )
	{
		const set<int>& setAnsCat = itDoc->second;

		// find() rather than operator[] so the lookup never inserts into the prediction map
		map<int,set<int> >::iterator itPred = mapDocId2Pred.find( itDoc->first );
		const set<int>& setPredCat = ( itPred != mapDocId2Pred.end() ) ? itPred->second : setNoPred;

		for ( itCat = setAllCat.begin(); itCat != setAllCat.end(); ++itCat )
		{
			const int iCatId = *itCat;
			const bool bInAns = setAnsCat.find( iCatId ) != setAnsCat.end();
			const bool bInPred = setPredCat.find( iCatId ) != setPredCat.end();

			mapCatId2Tab[ iCatId ].Add( bInAns, bInPred );
			all.Add( bInAns, bInPred );
		}
	}

	// Macro averages: sum per-category scores in category-id order, then divide
	// by the number of categories (guarded so an empty answer file yields 0).
	MacroF1 = MacroAccuracy = 0;
	map<int,SCatContingency>::iterator itTab;
	for ( itTab = mapCatId2Tab.begin(); itTab != mapCatId2Tab.end(); ++itTab )
	{
		MacroF1 += itTab->second.F1();
		MacroAccuracy += itTab->second.Accuracy();
	}
	if ( !mapCatId2Tab.empty() )
	{
		MacroF1 /= mapCatId2Tab.size();
		MacroAccuracy /= mapCatId2Tab.size();
	}

	// Micro averages: same formulas applied to the pooled table.
	MicroF1 = all.F1();
	MicroAccuracy = all.Accuracy();
}


⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -