// naivebayes.cpp
#include "stdafx.h"
#pragma warning( disable : 4786 )
#include "naivebayes.h"
#include "math.h"
#include <set>
#include <map>
#include <algorithm>
using namespace std;
// Constructs a classifier with both the document counter and the category
// counter set to zero.
CNaiveBayes::CNaiveBayes()
{
    m_CatCnt = 0;
    m_DocCnt = 0;
}
// Empties the four per-category lookup tables before the object goes away.
CNaiveBayes::~CNaiveBayes()
{
    // Per-category word log-probabilities.
    mapCatId_WIdProb.clear();

    // Per-category document and word totals.
    mapCatId_DocCnt.clear();
    mapCatId_WordCnt.clear();

    // Per-category priors.
    mapCatId_CatProb.clear();
}
// Builds the naive Bayes model from the documents listed in DocList.vSDoc.
//
// For every listed document the word entries (wnum, weight) of its vector are
// added to the totals of each category in the document's setDocCat. After the
// pass over the documents the model holds:
//   setWId            - every word id seen (the vocabulary)
//   setCat            - every category id that received at least one word entry
//   mapCatId_DocCnt   - number of documents per category
//   mapCatId_WordCnt  - summed word counts per category
//   mapCatId_CatProb  - prior: documents in the category / m_DocCnt
//   mapCatId_WIdProb  - log10 of the add-one smoothed word probability,
//                       (1 + count(word, cat)) / (words(cat) + |vocabulary|),
//                       for every vocabulary word in every category
//
// The model is rebuilt from scratch on every call. Accumulating across calls
// is not possible because mapCatId_WIdProb stores log-probabilities rather
// than raw counts once training has finished.
//
// @param DocList  training collection; read through vSDoc, mapDocId_Pos and docs.
void CNaiveBayes::Train( CDocList &DocList )
{
    // Drop any previously trained model so raw counts are never added on top
    // of stored log-probabilities.
    setWId.clear();
    setCat.clear();
    mapCatId_CatProb.clear();
    mapCatId_WordCnt.clear();
    mapCatId_DocCnt.clear();
    mapCatId_WIdProb.clear();
    m_DocCnt = 0;

    const size_t lDocTotal = DocList.vSDoc.size();
    for ( size_t l = 0; l < lDocTotal; l++ )
    {
        // Count each training document exactly once. The priors below divide
        // by this value, so it must not grow with the number of words.
        m_DocCnt++;

        const set<int> &setDocCat = DocList.vSDoc[l].setDocCat;
        set<int>::const_iterator itCat;

        long lPos = DocList.mapDocId_Pos[ DocList.vSDoc[l].lDocId ];
        WORDITEM *pWordItem = DocList.docs[lPos].content;
        for ( int i = 0; i < DocList.docs[lPos].dim_content; i++ )
        {
            const int iWId = pWordItem[i].wnum;
            // The weight is truncated to an integral count.
            // NOTE(review): assumes weights are term frequencies - confirm.
            const long lWordCnt = (long)pWordItem[i].weight;
            setWId.insert( iWId );

            for ( itCat = setDocCat.begin(); itCat != setDocCat.end(); ++itCat )
            {
                const int iCatId = *itCat;
                // Registered here (not per document) on purpose: a category
                // only becomes a classification candidate once it has words.
                setCat.insert( iCatId );
                // operator[] value-initialises a missing entry to 0, so a
                // plain += covers both the first and the repeated occurrence.
                mapCatId_WordCnt[ iCatId ] += lWordCnt;
                mapCatId_WIdProb[ iCatId ][ iWId ] += lWordCnt;
            }
        }

        for ( itCat = setDocCat.begin(); itCat != setDocCat.end(); ++itCat )
        {
            mapCatId_DocCnt[ *itCat ] += 1;
        }
    }

    // Category priors.
    map<int,long>::iterator itDocCnt;
    for ( itDocCnt = mapCatId_DocCnt.begin(); itDocCnt != mapCatId_DocCnt.end(); ++itDocCnt )
    {
        mapCatId_CatProb[ itDocCnt->first ] = itDocCnt->second / ( m_DocCnt * 1.0 );
    }

    // Convert the raw counts into smoothed log-probabilities.
    const double dbVocabSize = (double)setWId.size();
    map<int,map<int,double> >::iterator itCatModel;
    for ( itCatModel = mapCatId_WIdProb.begin(); itCatModel != mapCatId_WIdProb.end(); ++itCatModel )
    {
        map<int,double> &mapWIdProb = itCatModel->second;
        const double dbDenominator = mapCatId_WordCnt[ itCatModel->first ] + dbVocabSize;

        set<int>::iterator itWId;
        for ( itWId = setWId.begin(); itWId != setWId.end(); ++itWId )
        {
            // A word never seen in this category is inserted with count 0 and
            // therefore receives the smoothing mass 1 / denominator.
            double &dbEntry = mapWIdProb[ *itWId ];
            dbEntry = log10( ( 1.0 + dbEntry ) / dbDenominator );
        }
    }
}
// Ordering predicate for (category id, score) pairs: returns true when the
// score of x is strictly greater than the score of y, so sorting with it
// yields descending scores.
//
// The adaptor typedefs are spelled out by hand instead of inheriting from
// std::binary_function, which was removed from the standard library in C++17.
struct Pair_More_Than_by_Score
{
    typedef std::pair<int,double> first_argument_type;
    typedef std::pair<int,double> second_argument_type;
    typedef bool result_type;

    // const and by-reference so the predicate can be invoked on const
    // instances without copying its arguments.
    bool operator()( const std::pair<int,double> &x, const std::pair<int,double> &y ) const
    {
        return x.second > y.second;
    }
};
// Returns the id of the best-scoring category for test_doc.
//
// The score of a category is log10(prior) plus, for every word entry of the
// document, (long)weight * stored log-probability of that word in the
// category. Words with no stored entry for the category contribute nothing,
// and the model is not modified by the lookup.
//
// @param test_doc  document vector; read through content and dim_content.
// @return the category id with the highest score; on equal scores the
//         smallest id wins. Returns -1 when no categories are known (nothing
//         has been trained).
//         NOTE(review): assumes -1 is not a valid category id - confirm.
int CNaiveBayes::ClassifyDocs( DOC& test_doc )
{
    bool bHaveBest = false;
    int iBestCatId = -1;
    double dbBestScore = 0.0;
    WORDITEM *pWordItem = test_doc.content;

    set<int>::iterator itCat;
    for ( itCat = setCat.begin(); itCat != setCat.end(); ++itCat )
    {
        const int iCatId = *itCat;
        double dbScore = log10( mapCatId_CatProb[ iCatId ] );

        map<int,map<int,double> >::iterator itCatModel = mapCatId_WIdProb.find( iCatId );
        if ( itCatModel != mapCatId_WIdProb.end() )
        {
            map<int,double> &mapWIdProb = itCatModel->second;
            for ( int i = 0; i < test_doc.dim_content; i++ )
            {
                // find() instead of operator[]: an unknown word must not be
                // inserted into the trained model as a side effect.
                map<int,double>::iterator itProb = mapWIdProb.find( pWordItem[i].wnum );
                if ( itProb == mapWIdProb.end() )
                {
                    continue;
                }
                const long lWordCnt = (long)pWordItem[i].weight;
                dbScore += lWordCnt * itProb->second;
            }
        }

        // Only the maximum is needed, so track it directly instead of
        // collecting and sorting all scores.
        if ( !bHaveBest || dbScore > dbBestScore )
        {
            bHaveBest = true;
            dbBestScore = dbScore;
            iBestCatId = iCatId;
        }
    }
    return iBestCatId;
}
// Classifies every document line of sVectorFile and writes one line per input
// line to sResultFile: "<DocId>\t<category id>" when ReadDoc() reports
// success (a value > 0), or just "<DocId>" when it does not. A running count
// of classified documents is printed to cout.
//
// If either file cannot be opened, an error is written to cerr and the
// function returns without classifying anything.
//
// @param sVectorFile  path of the input file, one document vector per line.
// @param sResultFile  path of the output file (created/overwritten).
void CNaiveBayes::ClassifyDocs( string sVectorFile, string sResultFile )
{
    ofstream ofResult( sResultFile.c_str() );
    ifstream ifVector( sVectorFile.c_str() );
    if ( !ifVector )
    {
        cerr << "cannot open vector file: " << sVectorFile << endl;
        return;
    }
    if ( !ofResult )
    {
        cerr << "cannot open result file: " << sResultFile << endl;
        return;
    }

    string sLine;
    int iCount = 0;
    while ( getline( ifVector, sLine ) )
    {
        DOC test_doc;
        int iParseResult = ReadDoc( sLine, test_doc );
        if ( iParseResult <= 0 )
        {
            // Unparsable line: record the id only. content is left untouched
            // here because it is not known whether ReadDoc() allocated it on
            // failure.
            ofResult << test_doc.DocId << "\n";
            continue;
        }

        const int iCat = ClassifyDocs( test_doc );
        ofResult << test_doc.DocId << "\t" << iCat << "\n";

        // NOTE(review): presumably ReadDoc() malloc()s content - confirm.
        free( test_doc.content );
        iCount++;
        cout << "documents classified: " << iCount << "\r";
    }
    // Both streams are flushed and closed by their destructors.
}
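// (Residue of the code-viewer web page this file was copied from; not part of
//  the program. Keyboard shortcuts of that viewer:
//    copy code: Ctrl + C; search code: Ctrl + F; full screen: F11;
//    toggle theme: Ctrl + Shift + D; show shortcuts: ?;
//    increase font size: Ctrl + =; decrease font size: Ctrl + -)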