⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 cnaivebayes.cpp

📁 贝叶斯公式
💻 CPP
📖 第 1 页 / 共 2 页
字号:
#include "CNaiveBayes.h"
#include "ssps.h"

#include <iostream>
#include <math.h>
#include <fstream>
#include <windows.h>
using namespace std;

extern Segment(int, string *,string ,string *);
//-------------------------------------------------
//CNaiveBayes类的构造函数
//-------------------------------------------------
CNaiveBayes::CNaiveBayes()
{
    m_nClassNum = 0;						//类别总数
	m_psClassName = NULL;
	m_mapClassName2ID.clear();
	m_ppfTrainRes = NULL;
	m_pnTrainNum = NULL;
	m_pfPrC = NULL;
	m_nCorrectResNum =0;
	m_pnResNum = NULL;
	m_pnCorrectResNum = NULL;

}
//-----------------------------------------------------------------------------------------//
//   功能:       CNaiveBayes类的析构函数。
//-----------------------------------------------------------------------------------------//
CNaiveBayes::~CNaiveBayes()
{
	if(m_psClassName != NULL)
	{
		delete[] m_psClassName;
	}
	m_psClassName = NULL;

	m_mapClassName2ID.clear();
	
	if(	m_ppfTrainRes != NULL)
	{
		for(int i=0;i<m_nClassNum;i++)
		{
			delete[] m_ppfTrainRes[i];
			m_ppfTrainRes[i] = NULL;
		}
		delete[] m_ppfTrainRes;
	}
	m_ppfTrainRes = NULL;

	if(m_pnTrainNum != NULL)
	{
		delete[] m_pnTrainNum;
	}
	m_pnTrainNum = NULL;

	if(m_pfPrC != NULL)
	{
		delete[] m_pfPrC;
	}
	m_pfPrC = NULL;

	m_vTestRes.clear();

	if(m_pnResNum != NULL)
	{
		delete[] m_pnResNum;
	}
	m_pnResNum = NULL;

	if(m_pnCorrectResNum != NULL)
	{
		delete[] m_pnCorrectResNum;
	}
	m_pnCorrectResNum = NULL;
}
//-----------------------------------------------------------------------------------------//
//   功能:       读入训练文本,训练分类器
//----------------------------------------------------------------------------------------//
int CNaiveBayes::Train()
{
	//从文件_all_ids.lst中读取训练语料的特征项词频,训练特征项的类内先验概率
 	ifstream tfile("..\\Dic\\DF\\_all_ids.lst");
	string strLine;		//读入一行为一个字符串
	string::size_type pos=0, prev_pos=0;
	string word;
	int tatal_num = 0;
	string sTemp="";
	int nClassIndex=0;
	while(getline(tfile,strLine,'\n')!=NULL)//读取行
	{
		//取类名
		string::size_type classname_pos=0;
		pos=0;
		prev_pos=0;
		classname_pos = strLine.find_first_of( ' ',classname_pos );//从字符串开始寻找空格,返回空格所在位置。
		sTemp = strLine.substr( prev_pos, classname_pos - prev_pos );
		//取对应的类标号
		nClassIndex = m_mapClassName2ID[sTemp];		

		string::size_type num_pos;
		//取类别中文章的数量
		num_pos = ++classname_pos;
		num_pos = strLine.find_first_of( ' ',num_pos );
		sTemp = strLine.substr( classname_pos,num_pos-classname_pos );
		m_pnTrainNum[nClassIndex] = atoi(sTemp.c_str());		//每个类别中文章的数量

		pos=++num_pos;
		prev_pos=pos;
		int nWordNum = 0;//类中词的数量
		while((pos = strLine.find_first_of( ' ', pos ))!=string::npos)
		{
			//取得每个word 的key:weight
			string::size_type key_pos=0, weight_pos=0;
			word = strLine.substr( prev_pos, pos - prev_pos );
			prev_pos = ++pos;
			
			//取得每个word的key,weight存入链表elem
			key_pos = word.find_first_of( ':', key_pos );
			string str_key = word.substr( 0 , key_pos );
			string str_weight = word.substr( key_pos+1, string::npos - key_pos );
			int key = atoi(str_key.c_str());
			int weight = atoi(str_weight.c_str());
            m_ppfTrainRes[nClassIndex][key]=weight;
			nWordNum += weight;			
		}
		
		//处理最后一个空格后的word
		word = strLine.substr( prev_pos, pos - prev_pos );
		//cout << word << endl;
		string::size_type key_pos=0;
		key_pos = word.find_first_of( ':', key_pos );
		string str_key = word.substr( 0 , key_pos );
		string str_weight = word.substr( key_pos+1, string::npos - key_pos );
		int key = atoi(str_key.c_str());
		int weight = atoi(str_weight.c_str());
		m_ppfTrainRes[nClassIndex][key]=weight;
		nWordNum += weight;	
		
		char szPrWFile[20];
		sprintf(szPrWFile,"Pr\\PrW in C%d.txt",nClassIndex);
		ofstream Fou;
		Fou.open(szPrWFile,ios::out);
		for(int k=0;k<m_nFeatureNum;k++)
		{
			m_ppfTrainRes[nClassIndex][k]=(m_ppfTrainRes[nClassIndex][k]+1)/(nWordNum+m_nFeatureNum);
			char szTemp[100];
			memset(szTemp,0,100);
			sprintf(szTemp,"%d\t%f\n",k,m_ppfTrainRes[nClassIndex][k]);
			Fou << szTemp;			
		}
		Fou.close();

		tatal_num += m_pnTrainNum[nClassIndex];
	}

	//-------------------------------------------------------------
	//计算每个类的先验概率PrC
	//-------------------------------------------------------------
	ofstream FouPrc;
	FouPrc.open("Pr\\Prc.txt",ios::out);
	for(int k=0 ; k<m_nClassNum ; k++)
	{
		m_pfPrC[k] = (float)m_pnTrainNum[k]/(float)tatal_num;
		char szPrcTemp[100];//存储类的先验概率
		memset(szPrcTemp,0,100);
		sprintf(szPrcTemp,"%d %s %f\n",k,m_psClassName[k].c_str(),m_pfPrC[k]);
		FouPrc << szPrcTemp;
	}
	FouPrc.close();
	return m_nClassNum;
}

//-----------------------------------------------------------------------------------------//
//   功能:       初始化测试结果
//----------------------------------------------------------------------------------------//
bool CNaiveBayes::InitTestRes()
{
	m_pnResNum = new int[m_nClassNum];
	memset(m_pnResNum,0,m_nClassNum*sizeof(int));

	m_pnCorrectResNum = new int[m_nClassNum];
	memset(m_pnCorrectResNum,0,m_nClassNum*sizeof(int));
	return true;
}

//----------------------------------------------------------------------------------------//
//	功能:		  对测试文本进行分类
//   参数:       无
//  (入口)
//	string sTestFile	待测试文本的文件名
//	bool bFlag=false	是否已知该测试文本的所属类别,是true,否false
//	int nClassID=0		测试文本所属类别的ID号,从0开始编号(与m_mapClassName2ID中的ID号保持一致)
//  (出口)      无
//    返回:      测试文本数量
//----------------------------------------------------------------------------------------//
int CNaiveBayes::Test(string sTestFile,bool bFlag,int nClassID)
{
	m_vTestRes.clear();
	m_nCorrectResNum = 0;
	string strLine;		//读入一行为一个字符串
	
	double *pro=new double[m_nClassNum];//存储各特征项取log后的和	
	memset(pro,0,m_nClassNum*sizeof(double));

	string::size_type pos=0, prev_pos=0;
	string word;
	//读取文件内容	
	ifstream tfile1(sTestFile.c_str());	
	//每次处理一个测试文本
	while(getline(tfile1,strLine,'\n')!=NULL)
	{		
		string::size_type wordnum_pos=0;
		wordnum_pos = strLine.find_first_of( ' ',wordnum_pos );
		
		//每篇文章的词数nWordNum		
		string sWordNum = strLine.substr( 0,wordnum_pos );
		int nWordNum = atoi( sWordNum.c_str() );
		
		pos=++wordnum_pos;
		prev_pos=pos;
		while((pos = strLine.find_first_of( ' ', pos ))!=string::npos)
		{
			//取得每个word 的key:weight
			string::size_type key_pos=0, weight_pos=0;
			word = strLine.substr( prev_pos, pos - prev_pos );
			prev_pos = ++pos;
			
			//取得每个word的key并计算在类中的先验概率
			key_pos = word.find_first_of( ':', key_pos );
			string str_key = word.substr( 0 , key_pos );
			int key = atoi(str_key.c_str());
			
			for(int i=0;i<m_nClassNum;i++)
			{	
				pro[i]=pro[i]+log(m_ppfTrainRes[i][key]);
			}
			
		}
		
		//处理最后一个空格后的word
		word = strLine.substr( prev_pos, pos - prev_pos );
		string::size_type key_pos=0;
		key_pos = word.find_first_of( ':', key_pos );
		string str_key = word.substr( 0 , key_pos );
		int key = atoi(str_key.c_str());
		
		
		for(int i=0;i<m_nClassNum;i++)
		{	
			pro[i]=pro[i]+log(m_ppfTrainRes[i][key])+log(m_pfPrC[i]);
		}
		
		//对测试文本进行分类
		
		double t;
		t = pro[0];
		int max_pro_num = 0;
		
		for(int s = 1 ; s < m_nClassNum ; s++ )
		{
			double k = pro[s];
			if( t < k && k != 0 && t != 0 )
				
			{
				max_pro_num = s;
				t = k;	//cout << "t=" << t << "k=" << k << endl;
			}
		}
		
		//打开C[max_pro_num].txt,记录下该文档的绝对路径
		
		m_vTestRes.push_back(max_pro_num);
		if (bFlag)
		{
			m_pnResNum[max_pro_num]++;
			if (max_pro_num==nClassID) 
			{
				m_nCorrectResNum++;
				m_pnCorrectResNum[nClassID]++;
			}
		}
	}	
	return m_vTestRes.size();		//返回文章数量
}

//-----------------------------------------------------------------------------------------//
//   功能:       将分类结果输出。
//----------------------------------------------------------------------------------------//
void CNaiveBayes::OutputRes(int nClassID,int nDocNum)
{
	cout << "类别 : " << m_psClassName[nClassID] << endl;
	cout << "总文档数: " << nDocNum << endl;
	cout << "划分为该类的文档数 = " << m_pnResNum[nClassID] << endl;
	cout << "正确归档数 = " << m_pnCorrectResNum[nClassID] << endl;
	double percent = (float)m_pnCorrectResNum[nClassID]/(float)m_pnResNum[nClassID]; 
	cout << "准确率: " << percent*100 << "%" << endl;
	double recall = (float)m_pnCorrectResNum[nClassID]/(float)nDocNum; 
	cout << "召回率: " << recall*100 << "%" << endl;
}

//-----------------------------------------------------------------------------------------//
//   功能:       对一个文件夹内的文档进行测试。
//   参数:       
//  (入口)
//	string sTestFilesPath	待测试文件夹
//----------------------------------------------------------------------------------------//
void CNaiveBayes::TestFiles(string sTestFilesPath)
{
   

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -