📄 cnaivebayes.cpp
字号:
#include "CNaiveBayes.h"
#include "ssps.h"
#include <math.h>
#include <fstream>
#include <iostream>
#include <vector>
#include <windows.h>
using namespace std;
// Declaration of the externally defined Segment routine.
// NOTE(review): no return type is spelled out, so this line relies on the
// pre-standard "implicit int" rule and is rejected by conforming C++ compilers;
// confirm the real signature against the definition. Segment is not referenced
// in the part of this file visible here.
extern Segment(int, string *,string ,string *);
//-------------------------------------------------
// Default constructor of CNaiveBayes.
// Leaves the classifier empty: zero classes, no owned arrays, and an empty
// class-name -> class-ID map.
//-------------------------------------------------
CNaiveBayes::CNaiveBayes()
{
    // Scalar counters.
    m_nClassNum      = 0;   // total number of classes
    m_nCorrectResNum = 0;

    // Class-name lookup.
    m_psClassName = NULL;
    m_mapClassName2ID.clear();

    // Training results.
    m_ppfTrainRes = NULL;
    m_pnTrainNum  = NULL;
    m_pfPrC       = NULL;

    // Per-class test statistics.
    m_pnResNum        = NULL;
    m_pnCorrectResNum = NULL;
}
//-----------------------------------------------------------------------------------------//
// Destructor of CNaiveBayes.
// Frees every array the object holds, empties the containers, and resets each
// pointer to NULL afterwards.
//-----------------------------------------------------------------------------------------//
CNaiveBayes::~CNaiveBayes()
{
    // delete[] applied to a null pointer does nothing, so the flat arrays
    // need no explicit null checks.
    delete[] m_psClassName;
    m_psClassName = NULL;
    m_mapClassName2ID.clear();

    // The training table is an array of m_nClassNum separately allocated rows,
    // so the table pointer itself must be valid before the rows are visited.
    if (NULL != m_ppfTrainRes)
    {
        for (int nRow = 0; nRow < m_nClassNum; ++nRow)
        {
            delete[] m_ppfTrainRes[nRow];
            m_ppfTrainRes[nRow] = NULL;
        }
        delete[] m_ppfTrainRes;
    }
    m_ppfTrainRes = NULL;

    delete[] m_pnTrainNum;
    m_pnTrainNum = NULL;

    delete[] m_pfPrC;
    m_pfPrC = NULL;

    m_vTestRes.clear();

    delete[] m_pnResNum;
    m_pnResNum = NULL;

    delete[] m_pnCorrectResNum;
    m_pnCorrectResNum = NULL;
}
//-----------------------------------------------------------------------------------------//
// 功能: 读入训练文本,训练分类器
//----------------------------------------------------------------------------------------//
int CNaiveBayes::Train()
{
//从文件_all_ids.lst中读取训练语料的特征项词频,训练特征项的类内先验概率
ifstream tfile("..\\Dic\\DF\\_all_ids.lst");
string strLine; //读入一行为一个字符串
string::size_type pos=0, prev_pos=0;
string word;
int tatal_num = 0;
string sTemp="";
int nClassIndex=0;
while(getline(tfile,strLine,'\n')!=NULL)//读取行
{
//取类名
string::size_type classname_pos=0;
pos=0;
prev_pos=0;
classname_pos = strLine.find_first_of( ' ',classname_pos );//从字符串开始寻找空格,返回空格所在位置。
sTemp = strLine.substr( prev_pos, classname_pos - prev_pos );
//取对应的类标号
nClassIndex = m_mapClassName2ID[sTemp];
string::size_type num_pos;
//取类别中文章的数量
num_pos = ++classname_pos;
num_pos = strLine.find_first_of( ' ',num_pos );
sTemp = strLine.substr( classname_pos,num_pos-classname_pos );
m_pnTrainNum[nClassIndex] = atoi(sTemp.c_str()); //每个类别中文章的数量
pos=++num_pos;
prev_pos=pos;
int nWordNum = 0;//类中词的数量
while((pos = strLine.find_first_of( ' ', pos ))!=string::npos)
{
//取得每个word 的key:weight
string::size_type key_pos=0, weight_pos=0;
word = strLine.substr( prev_pos, pos - prev_pos );
prev_pos = ++pos;
//取得每个word的key,weight存入链表elem
key_pos = word.find_first_of( ':', key_pos );
string str_key = word.substr( 0 , key_pos );
string str_weight = word.substr( key_pos+1, string::npos - key_pos );
int key = atoi(str_key.c_str());
int weight = atoi(str_weight.c_str());
m_ppfTrainRes[nClassIndex][key]=weight;
nWordNum += weight;
}
//处理最后一个空格后的word
word = strLine.substr( prev_pos, pos - prev_pos );
//cout << word << endl;
string::size_type key_pos=0;
key_pos = word.find_first_of( ':', key_pos );
string str_key = word.substr( 0 , key_pos );
string str_weight = word.substr( key_pos+1, string::npos - key_pos );
int key = atoi(str_key.c_str());
int weight = atoi(str_weight.c_str());
m_ppfTrainRes[nClassIndex][key]=weight;
nWordNum += weight;
char szPrWFile[20];
sprintf(szPrWFile,"Pr\\PrW in C%d.txt",nClassIndex);
ofstream Fou;
Fou.open(szPrWFile,ios::out);
for(int k=0;k<m_nFeatureNum;k++)
{
m_ppfTrainRes[nClassIndex][k]=(m_ppfTrainRes[nClassIndex][k]+1)/(nWordNum+m_nFeatureNum);
char szTemp[100];
memset(szTemp,0,100);
sprintf(szTemp,"%d\t%f\n",k,m_ppfTrainRes[nClassIndex][k]);
Fou << szTemp;
}
Fou.close();
tatal_num += m_pnTrainNum[nClassIndex];
}
//-------------------------------------------------------------
//计算每个类的先验概率PrC
//-------------------------------------------------------------
ofstream FouPrc;
FouPrc.open("Pr\\Prc.txt",ios::out);
for(int k=0 ; k<m_nClassNum ; k++)
{
m_pfPrC[k] = (float)m_pnTrainNum[k]/(float)tatal_num;
char szPrcTemp[100];//存储类的先验概率
memset(szPrcTemp,0,100);
sprintf(szPrcTemp,"%d %s %f\n",k,m_psClassName[k].c_str(),m_pfPrC[k]);
FouPrc << szPrcTemp;
}
FouPrc.close();
return m_nClassNum;
}
//-----------------------------------------------------------------------------------------//
// Purpose: (re)initialises the per-class test statistics.
// Allocates m_pnResNum and m_pnCorrectResNum with m_nClassNum zeroed entries each.
// Arrays left over from an earlier call are released first, so the function may be
// called repeatedly without leaking.
// Returns: true on success, false when m_nClassNum is negative.
//----------------------------------------------------------------------------------------//
bool CNaiveBayes::InitTestRes()
{
    // Drop the statistics of a previous run (delete[] on NULL is a no-op).
    delete[] m_pnResNum;
    m_pnResNum = NULL;
    delete[] m_pnCorrectResNum;
    m_pnCorrectResNum = NULL;

    if (m_nClassNum < 0)
    {
        return false;
    }

    // The trailing () value-initialises every counter to 0.
    m_pnResNum        = new int[m_nClassNum]();
    m_pnCorrectResNum = new int[m_nClassNum]();
    return true;
}
//----------------------------------------------------------------------------------------//
// Purpose: classifies every document of a test file.
//
// Each line of the file is one document:
//     <word count> <key:weight> <key:weight> ...
// Only the keys are used. The score of class i is
//     log(m_pfPrC[i]) + sum over the document's keys of log(m_ppfTrainRes[i][key])
// and the document is assigned to the class selected by the comparison loop below.
// The chosen class ID of every document is appended to m_vTestRes.
//
// Parameters:
//   sTestFile  name of the test file
//   bFlag      true when the real class of the documents is known; the statistics
//              m_pnResNum / m_nCorrectResNum / m_pnCorrectResNum are then updated
//   nClassID   real class ID of the documents (0-based, same numbering as
//              m_mapClassName2ID); only used when bFlag is true
//
// Returns: number of documents classified (0 when the classifier is not trained or
//          the file cannot be opened).
//
// Blank lines, empty tokens and keys outside [0, m_nFeatureNum) are ignored.
//----------------------------------------------------------------------------------------//
int CNaiveBayes::Test(string sTestFile,bool bFlag,int nClassID)
{
    m_vTestRes.clear();
    m_nCorrectResNum = 0;

    if (m_nClassNum <= 0 || m_ppfTrainRes == NULL || m_pfPrC == NULL)
    {
        cerr << "Test: classifier has not been trained" << endl;
        return 0;
    }

    ifstream tfile1(sTestFile.c_str());
    if (!tfile1)
    {
        cerr << "Test: cannot open " << sTestFile << endl;
        return 0;
    }

    // Per-class sum of log-probabilities; owned by the vector, so nothing leaks.
    vector<double> pro(m_nClassNum, 0.0);
    string strLine;   // one line == one test document

    while (getline(tfile1, strLine))
    {
        if (strLine.empty())
        {
            continue;
        }

        // Every document starts from a clean score; otherwise it would inherit
        // the log-probabilities accumulated for the preceding lines.
        pro.assign(m_nClassNum, 0.0);

        // The first field is the word count of the document; it is not needed for
        // scoring, so tokenising starts at the space that ends it.
        string::size_type token_end = strLine.find(' ');
        while (token_end != string::npos)
        {
            const string::size_type token_begin = token_end + 1;
            token_end = strLine.find(' ', token_begin);
            const string sToken = (token_end == string::npos)
                                      ? strLine.substr(token_begin)
                                      : strLine.substr(token_begin, token_end - token_begin);

            const string::size_type colon_pos = sToken.find(':');
            if (sToken.empty() || colon_pos == 0)
            {
                continue;   // double / trailing space, or token without a key
            }
            const int key = atoi(sToken.substr(0, colon_pos).c_str());
            if (key < 0 || key >= m_nFeatureNum)
            {
                continue;   // would read outside the feature table
            }
            for (int i = 0; i < m_nClassNum; i++)
            {
                pro[i] = pro[i] + log(m_ppfTrainRes[i][key]);
            }
        }

        // Add the class prior exactly once per document.
        for (int i = 0; i < m_nClassNum; i++)
        {
            pro[i] = pro[i] + log(m_pfPrC[i]);
        }

        // Pick the class. The comparison (including the "!= 0" tests, under which a
        // score of exactly 0 never replaces or is replaced) is kept unchanged.
        double t = pro[0];
        int max_pro_num = 0;
        for (int s = 1; s < m_nClassNum; s++)
        {
            const double k = pro[s];
            if (t < k && k != 0 && t != 0)
            {
                max_pro_num = s;
                t = k;
            }
        }

        m_vTestRes.push_back(max_pro_num);

        if (bFlag)
        {
            if (m_pnResNum != NULL)
            {
                m_pnResNum[max_pro_num]++;
            }
            // max_pro_num is a valid class index, so equality also validates nClassID.
            if (max_pro_num == nClassID)
            {
                m_nCorrectResNum++;
                if (m_pnCorrectResNum != NULL)
                {
                    m_pnCorrectResNum[nClassID]++;
                }
            }
        }
    }

    return static_cast<int>(m_vTestRes.size());   // number of documents
}
//-----------------------------------------------------------------------------------------//
// 功能: 将分类结果输出。
//----------------------------------------------------------------------------------------//
void CNaiveBayes::OutputRes(int nClassID,int nDocNum)
{
cout << "类别 : " << m_psClassName[nClassID] << endl;
cout << "总文档数: " << nDocNum << endl;
cout << "划分为该类的文档数 = " << m_pnResNum[nClassID] << endl;
cout << "正确归档数 = " << m_pnCorrectResNum[nClassID] << endl;
double percent = (float)m_pnCorrectResNum[nClassID]/(float)m_pnResNum[nClassID];
cout << "准确率: " << percent*100 << "%" << endl;
double recall = (float)m_pnCorrectResNum[nClassID]/(float)nDocNum;
cout << "召回率: " << recall*100 << "%" << endl;
}
//-----------------------------------------------------------------------------------------//
// 功能: 对一个文件夹内的文档进行测试。
// 参数:
// (入口)
// string sTestFilesPath 待测试文件夹
//----------------------------------------------------------------------------------------//
void CNaiveBayes::TestFiles(string sTestFilesPath)
{
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -