📄 textcat.cpp
字号:
#include "stdafx.h"
//#ifdef _DEBUG
//#define new DEBUG_NEW
//#endif
#pragma warning( disable : 4786 )
#include "TextCat.h"
#include "textprocess.h"
#include "knn.h"
#include "naivebayes.h"
#include "doclist.h"
// Forward declaration of the evaluator defined after _tmain.
// Arguments: reference-label file, prediction file, then the outputs
// MicroF1, MacroF1, MicroAccuracy, MacroAccuracy (written by reference).
// std:: is spelled out because this line precedes `using namespace std;`.
void EvalPred( std::string, std::string, double&, double&, double&, double& );
// Global MFC application object; presumably needed by the AfxWinInit call in
// _tmain (standard MFC console-application boilerplate).
CWinApp theApp;
using namespace std;
// Program entry point (MFC console application).
//
// Pipeline:
//   1. Preprocess the corpora in .\train\ and .\test\ via
//      CTextProcess::ProcessCorpus, writing its output into .\.
//   2. Train one classifier on the training vectors + labels and write its
//      predictions for the test vectors to .\result.list.
//   3. Score .\result.list against .\test.doc.label with EvalPred.
//
// Returns 0 on success, 1 if MFC initialisation fails.
int _tmain(int argc, TCHAR* argv[], TCHAR* envp[])
{
    // Initialise MFC. Stop right away on failure instead of running the whole
    // pipeline on an uninitialised framework.
    if (!AfxWinInit(::GetModuleHandle(NULL), NULL, ::GetCommandLine(), 0))
    {
        _tprintf(_T("Fatal Error: MFC initialization failed\n"));
        return 1;
    }

    // Process training and test documents. The vector files read below are
    // presumably produced by this call; the meaning of 1000 is defined by
    // CTextProcess::ProcessCorpus (TODO confirm, likely a feature-count limit).
    CTextProcess TextProcess;
    TextProcess.ProcessCorpus( ".\\train\\", ".\\test\\", ".\\", "train.doc.label", "test.doc.list", 1000 );

    string sTestCatFile = ".\\test.doc.label";
    string sPredFile = ".\\result.list";
    string sTrainCatFile = ".\\train.doc.label";

    // Classifier selection: true  -> kNN on the tf-idf vector files,
    //                       false -> Naive Bayes on the tf vector files.
    const bool bUseKNN = true;
    if ( bUseKNN ) {
        string sTrainVectorFile = ".\\tfidf_train_vector.txt";
        string sTestVectorFile = ".\\tfidf_test_vector.txt";
        CDocList DocList;
        DocList.ReadVector( sTrainVectorFile );   // training document vectors
        DocList.ReadDocList( sTrainCatFile );     // training document labels
        CKNN knn;
        knn.k = 2;                                // presumably the neighbour count
        knn.Train( DocList );
        knn.ClassifyDocs( sTestVectorFile, sPredFile );
    }
    else {
        string sTrainVectorFile = ".\\tf_train_vector.txt";
        string sTestVectorFile = ".\\tf_test_vector.txt";
        CDocList DocList;
        DocList.ReadVector( sTrainVectorFile );   // training document vectors
        DocList.ReadDocList( sTrainCatFile );     // training document labels
        CNaiveBayes NB;
        NB.Train( DocList );
        NB.ClassifyDocs( sTestVectorFile, sPredFile );
    }

    double MicroF1 = 0, MacroF1 = 0, MicroAccuracy = 0, MacroAccuracy = 0;
    // EvalPred already prints the micro/macro F1 line itself; report the
    // accuracies it computes here instead of repeating that line.
    EvalPred( sTestCatFile, sPredFile, MicroF1, MacroF1, MicroAccuracy, MacroAccuracy );
    cout << "Evaluation Result\tMicroAccuracy=" << MicroAccuracy << "\tMacroAccuracy=" << MacroAccuracy << endl;
    return 0;
}
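// Per-category confusion matrix used by EvalPred
// (rows = actual category, columns = predicted category):
//
//                  Predicted
//                  P       N
//              +-------+-------+
//   Actual   P |   a   |   d   |
//              +-------+-------+
//            N |   b   |   c   |
//              +-------+-------+
//
// a: P->P (true positive)    b: N->P (false positive)
// c: N->N (true negative)    d: P->N (false negative)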
namespace
{
    // Confusion-matrix counts for one category (actual -> predicted):
    //   a = P->P, b = N->P, c = N->N, d = P->N.
    struct CatConfusion
    {
        double a, b, c, d;
        CatConfusion() : a( 0 ), b( 0 ), c( 0 ), d( 0 ) {}
    };

    // num / den, or 0 when den is 0, so empty categories or runs without
    // positives yield 0 instead of NaN.
    double SafeRatio( double num, double den )
    {
        return den == 0 ? 0.0 : num / den;
    }

    // Harmonic mean of precision p and recall r; 0 when both are 0.
    double F1Score( double p, double r )
    {
        return SafeRatio( 2.0 * p * r, p + r );
    }
}

// Scores multi-label predictions against reference labels.
//
// sAnsFile  - reference labels, loaded with ReadIdSetMap<int> into a map from
//             document id to its set of category ids.
// sPredFile - predicted labels, loaded the same way.
// MicroF1, MacroF1, MicroAccuracy, MacroAccuracy - outputs.
//
// Only documents listed in sAnsFile are scored; a document missing from
// sPredFile counts as predicting no category. The category universe is the
// union of categories in sAnsFile, so predicted categories that never occur
// in the answers are not counted. For every (document, category) pair one
// cell (a/b/c/d) of that category's confusion matrix is incremented.
// Macro scores average per-category F1/accuracy; micro scores use the counts
// summed over all categories. Any ratio with a zero denominator is 0.
// The micro/macro F1 line is also written to cout.
void EvalPred( string sAnsFile, string sPredFile, double& MicroF1, double& MacroF1, double& MicroAccuracy, double& MacroAccuracy )
{
    cout << "Evaluating result... " << endl;

    map<int, set<int> > mapDocId2Ans, mapDocId2Pred;
    ReadIdSetMap<int>( mapDocId2Ans, sAnsFile );
    ReadIdSetMap<int>( mapDocId2Pred, sPredFile );

    map<int, set<int> >::const_iterator itDoc;
    set<int>::const_iterator itCat;

    // Category universe: every category occurring in the reference labels.
    set<int> setAllCat;
    for ( itDoc = mapDocId2Ans.begin(); itDoc != mapDocId2Ans.end(); ++itDoc )
        setAllCat.insert( itDoc->second.begin(), itDoc->second.end() );

    map<int, CatConfusion> mapCatId2Tab;
    CatConfusion all;              // counts summed over all categories (micro)
    const set<int> emptySet;       // stands in for documents with no prediction

    for ( itDoc = mapDocId2Ans.begin(); itDoc != mapDocId2Ans.end(); ++itDoc )
    {
        const set<int>& setAnsCat = itDoc->second;
        // Look up without operator[] so the prediction map is not modified.
        map<int, set<int> >::const_iterator itPred = mapDocId2Pred.find( itDoc->first );
        const set<int>& setPredCat = ( itPred != mapDocId2Pred.end() ) ? itPred->second : emptySet;

        for ( itCat = setAllCat.begin(); itCat != setAllCat.end(); ++itCat )
        {
            const int iCatId = *itCat;
            const bool bActual = setAnsCat.count( iCatId ) > 0;
            const bool bPred = setPredCat.count( iCatId ) > 0;
            CatConfusion& tab = mapCatId2Tab[ iCatId ];
            if ( bActual && bPred )  { tab.a += 1; all.a += 1; }   // P->P
            else if ( bActual )      { tab.d += 1; all.d += 1; }   // P->N
            else if ( bPred )        { tab.b += 1; all.b += 1; }   // N->P
            else                     { tab.c += 1; all.c += 1; }   // N->N
        }
    }

    // Macro averages: mean of per-category F1 and accuracy.
    double dMacroF1Sum = 0, dMacroAccSum = 0;
    for ( itCat = setAllCat.begin(); itCat != setAllCat.end(); ++itCat )
    {
        const CatConfusion& t = mapCatId2Tab[ *itCat ];
        const double p = SafeRatio( t.a, t.a + t.b );
        const double r = SafeRatio( t.a, t.a + t.d );
        dMacroF1Sum += F1Score( p, r );
        dMacroAccSum += SafeRatio( t.a + t.c, t.a + t.b + t.c + t.d );
    }
    const double dCatTotal = static_cast<double>( setAllCat.size() );
    MacroF1 = SafeRatio( dMacroF1Sum, dCatTotal );
    MacroAccuracy = SafeRatio( dMacroAccSum, dCatTotal );

    // Micro averages: metrics over the pooled confusion counts.
    const double MicroP = SafeRatio( all.a, all.a + all.b );
    const double MicroR = SafeRatio( all.a, all.a + all.d );
    MicroF1 = F1Score( MicroP, MicroR );
    MicroAccuracy = SafeRatio( all.a + all.c, all.a + all.b + all.c + all.d );

    cout << "Evaluation Result\tMicroF1=" << MicroF1 << "\tMacroF1=" <<MacroF1 << endl;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -