⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 naivebayes.cpp

📁 knn和Naive Bayes算法实现,两个实现在一起,是数据挖掘和机器学习中的内容.
💻 CPP
字号:
#include "stdafx.h"

#pragma warning( disable : 4786 )

#include "naivebayes.h"


#include "math.h"
#include <set>
#include <map>
#include <algorithm>


using namespace std;

// Construct an empty, untrained model: both counters start at zero.
// NOTE(review): m_CatCnt is not updated anywhere visible in this file — confirm
// whether it is still used.
CNaiveBayes::CNaiveBayes()
	: m_DocCnt( 0 ),
	  m_CatCnt( 0 )
{
}

// Destructor.
//
// Nothing to do explicitly. The model containers (setWId, setCat and the
// mapCatId_* maps) are members that release their own storage when the object
// is destroyed. Clearing only some of them by hand, as was done before, was
// redundant and inconsistent.
CNaiveBayes::~CNaiveBayes()
{
}

/**
 * Train the multinomial Naive Bayes model on every document in DocList.
 *
 * Pass 1 (counting). For each entry of DocList.vSDoc and each category id c
 * in its setDocCat:
 *   - mapCatId_DocCnt[c]     += 1               (documents labelled c)
 *   - m_DocCnt               += 1               ((document, category) pairs)
 *   - mapCatId_WordCnt[c]    += weight of each word in the document
 *   - mapCatId_WIdProb[c][w] += weight of word w (raw counts for now)
 * Word weights are truncated to long. Every word id is added to the vocabulary
 * setWId. Every category of a document with at least one word item is added
 * to setCat.
 *
 * Pass 2 (normalising):
 *   - mapCatId_CatProb[c] = mapCatId_DocCnt[c] / m_DocCnt
 *   - for every vocabulary word w, mapCatId_WIdProb[c][w] becomes
 *       log10( (1 + count(w,c)) / (mapCatId_WordCnt[c] + |setWId|) ),
 *     an add-one (Laplace) smoothed log-likelihood.
 *
 * Any previously trained model is discarded first. Once normalised,
 * mapCatId_WIdProb holds log-probabilities, so fresh counts cannot be
 * accumulated onto it.
 *
 * @param DocList  training corpus. vSDoc supplies document ids and their
 *                 categories; docs[mapDocId_Pos[id]] supplies the word items.
 */
void CNaiveBayes::Train( CDocList &DocList )
{
	// Reset all model state so repeated calls do not mix new counts with the
	// log-probabilities left behind by a previous run.
	setWId.clear();
	setCat.clear();
	mapCatId_CatProb.clear();
	mapCatId_WordCnt.clear();
	mapCatId_DocCnt.clear();
	mapCatId_WIdProb.clear();
	m_DocCnt = 0;

	// ---- Pass 1: accumulate raw counts ----
	const size_t lDocTotal = DocList.vSDoc.size();
	for ( size_t l = 0; l < lDocTotal; l++ )
	{
		const set<int> &setDocCat = DocList.vSDoc[l].setDocCat;
		set<int>::const_iterator itCat;

		// NOTE(review): operator[] yields a default position for an id that is
		// missing from mapDocId_Pos — confirm every training id is present.
		const long lPos = DocList.mapDocId_Pos[ DocList.vSDoc[l].lDocId ];
		const WORDITEM *pWordItem = DocList.docs[lPos].content;

		for ( int i = 0; i < DocList.docs[lPos].dim_content; i++ )
		{
			const int iWId = pWordItem[i].wnum;
			const long lWordCnt = static_cast<long>( pWordItem[i].weight );

			setWId.insert( iWId );

			for ( itCat = setDocCat.begin(); itCat != setDocCat.end(); ++itCat )
			{
				const int iCatId = *itCat;

				setCat.insert( iCatId );

				// operator[] value-initialises a missing counter to zero, so no
				// separate "first occurrence" branch is needed.
				mapCatId_WordCnt[ iCatId ] += lWordCnt;
				mapCatId_WIdProb[ iCatId ][ iWId ] += lWordCnt;
			}
		}

		// Document-level counters are bumped once per (document, category)
		// pair. Counting them inside the word loop inflated m_DocCnt and broke
		// the priors.
		for ( itCat = setDocCat.begin(); itCat != setDocCat.end(); ++itCat )
		{
			mapCatId_DocCnt[ *itCat ] += 1;
			m_DocCnt++;
		}
	}

	// ---- Pass 2a: category priors (sum to 1 over all categories) ----
	map<int,long>::const_iterator itDocCnt;
	for ( itDocCnt = mapCatId_DocCnt.begin(); itDocCnt != mapCatId_DocCnt.end(); ++itDocCnt )
	{
		mapCatId_CatProb[ itDocCnt->first ] = itDocCnt->second / ( m_DocCnt * 1.0 );
	}

	// ---- Pass 2b: smoothed log10 word likelihoods ----
	const size_t nVocab = setWId.size();
	map<int, map<int,double> >::iterator itCatWords;
	for ( itCatWords = mapCatId_WIdProb.begin(); itCatWords != mapCatId_WIdProb.end(); ++itCatWords )
	{
		map<int,double> &mapWIdProb = itCatWords->second;
		const double dbDenominator =
			static_cast<double>( mapCatId_WordCnt[ itCatWords->first ] + nVocab );

		set<int>::const_iterator itWId;
		for ( itWId = setWId.begin(); itWId != setWId.end(); ++itWId )
		{
			// A word never seen with this category starts from a count of 0.
			double &dbValue = mapWIdProb[ *itWId ];
			dbValue = log10( ( 1.0 + dbValue ) / dbDenominator );
		}
	}
}

/**
 * Comparator for (category id, score) pairs that orders by descending score.
 * It returns true when x scores strictly higher than y, which makes it a valid
 * strict weak ordering for std::sort / std::min_element.
 *
 * It no longer derives from std::binary_function, which was deprecated in
 * C++11 and removed in C++17. It also returns bool instead of the Windows-only
 * BOOL. The adaptable-function typedefs are declared by hand so that
 * binder-style adaptors keep working.
 */
struct Pair_More_Than_by_Score
{
	typedef std::pair<int,double> first_argument_type;
	typedef std::pair<int,double> second_argument_type;
	typedef bool result_type;

	bool operator()( const std::pair<int,double> &x, const std::pair<int,double> &y ) const
	{
		return x.second > y.second;
	}
};

/**
 * Classify one document and return the best-scoring category id.
 *
 * For each category c in setCat the score is
 *     log10( mapCatId_CatProb[c] ) + sum_i count_i * mapCatId_WIdProb[c][w_i]
 * where count_i is the word weight truncated to long, and mapCatId_WIdProb
 * holds the log10 likelihoods produced by Train. A word id with no entry for
 * c adds 0, so words outside the training vocabulary are ignored.
 *
 * @param test_doc  document to classify; its word items are only read.
 * @return the highest-scoring category id, with the lowest id winning ties.
 *         Returns -1 when setCat is empty (e.g. the model was never trained).
 */
int CNaiveBayes::ClassifyDocs( DOC& test_doc )
{
	// With no categories there is nothing to pick. Reading element 0 of an
	// empty score list would be undefined behaviour.
	if ( setCat.empty() )
		return -1;

	const WORDITEM *pWordItem = test_doc.content;

	vector< pair<int,double> > vCatIdScore;
	vCatIdScore.reserve( setCat.size() );

	set<int>::const_iterator itCat;
	for ( itCat = setCat.begin(); itCat != setCat.end(); ++itCat )
	{
		const int iCatId = *itCat;
		double dbScore = log10( mapCatId_CatProb[ iCatId ] );

		// find() rather than operator[], so classifying a document never
		// inserts its unseen word ids into the trained model. A missing
		// entry contributes 0, exactly as the inserted default did before.
		map<int, map<int,double> >::const_iterator itCatWords = mapCatId_WIdProb.find( iCatId );
		if ( itCatWords != mapCatId_WIdProb.end() )
		{
			const map<int,double> &mapWIdProb = itCatWords->second;
			for ( int i = 0; i < test_doc.dim_content; i++ )
			{
				const int iWId = pWordItem[i].wnum;
				map<int,double>::const_iterator itWord = mapWIdProb.find( iWId );
				if ( itWord != mapWIdProb.end() )
					dbScore += static_cast<long>( pWordItem[i].weight ) * itWord->second;
			}
		}

		vCatIdScore.push_back( make_pair( iCatId, dbScore ) );
	}

	// Pair_More_Than_by_Score orders by descending score, so the "smallest"
	// element under it is the best category. min_element is a single linear
	// pass and returns the first of any tied maxima.
	return min_element( vCatIdScore.begin(), vCatIdScore.end(), Pair_More_Than_by_Score() )->first;
}


/**
 * Classify every document in a vector file and write one result line each.
 *
 * Each line of sVectorFile is parsed with ReadDoc. If ReadDoc returns > 0,
 * "<DocId>\t<category id>" is written to sResultFile. Otherwise only
 * "<DocId>" is written. A running count of classified documents is printed
 * to cout.
 *
 * If either file cannot be opened, an error is reported on cerr and nothing
 * is classified. Previously this failure was silent.
 *
 * @param sVectorFile  path of the input file, one document per line.
 * @param sResultFile  path of the output file (created or truncated).
 */
void CNaiveBayes::ClassifyDocs( string sVectorFile, string sResultFile )
{
	ofstream ofResult( sResultFile.c_str() );
	if ( !ofResult ) {
		cerr << "cannot open result file: " << sResultFile << endl;
		return;
	}

	ifstream ifVector( sVectorFile.c_str() );
	if ( !ifVector ) {
		cerr << "cannot open vector file: " << sVectorFile << endl;
		return;
	}

	string sLine;
	int iCount = 0;
	while ( getline( ifVector, sLine ) ) {
		DOC test_doc;

		if ( ReadDoc( sLine, test_doc ) <= 0 ) {
			// NOTE(review): content is not freed on this path. Whether ReadDoc
			// allocates it on a failed parse is not visible here — confirm.
			ofResult << test_doc.DocId << '\n';
			continue;
		}

		const int iCat = ClassifyDocs( test_doc );
		// '\n' rather than endl: avoid flushing the file on every line.
		ofResult << test_doc.DocId << '\t' << iCat << '\n';

		free( test_doc.content );

		iCount++;
		cout << "documents classified: " << iCount << "\r";
	}
	// ifVector and ofResult are closed (and flushed) by their destructors.
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -