📄 classifier.cpp
// Classifier.cpp: implementation of the CClassifier class.
//
//////////////////////////////////////////////////////////////////////
#include "stdafx.h"
#include "svmcls.h"
#include "Classifier.h"
#include "WordSegment.h"
#include "Message.h"
#include <math.h>
#include <direct.h>
#ifdef _DEBUG
#undef THIS_FILE
static char THIS_FILE[]=__FILE__;
#define new DEBUG_NEW
#endif
CClassifier theClassifier;
const DWORD CClassifier::dwModelFileID=0xFFEFFFFF; // magic number identifying saved classifier model files
//////////////////////////////////////////////////////////////////////
// Construction/Destruction
//////////////////////////////////////////////////////////////////////
CClassifier::CClassifier()
{
m_pDocs=NULL;
m_pSimilarityRatio=NULL;
m_lDocNum=0;
m_nClassNum=0;
}
CClassifier::~CClassifier()
{
}
//bFlag=false means the training documents are not rescanned to rebuild the candidate
//feature set; this is typically used for hierarchical classification.
//nType selects the classification model: nType=0 builds a KNN classifier, nType=1 an SVM classifier
bool CClassifier::Train(int nType, bool bFlag)
{
CTime startTime;
CTimeSpan totalTime;
if(bFlag)
{
InitTrain();
//Generate all candidate features and store them in m_lstWordList
GenDic();
}
CMessage::PrintStatusInfo("");
if(m_lstWordList.GetCount()==0)
return false;
if(m_lstTrainCatalogList.GetCataNum()==0)
return false;
//Clear the selected-feature list m_lstTrainWordList
m_lstTrainWordList.InitWordList();
//Weight every feature in the candidate list m_lstWordList
CMessage::PrintInfo(_T("Computing the category discrimination of each candidate feature, please wait..."));
startTime=CTime::GetCurrentTime();
FeatherWeight(m_lstWordList);
totalTime=CTime::GetCurrentTime()-startTime;
CMessage::PrintInfo(_T("特征区分度计算结束,耗时")+totalTime.Format("%H:%M:%S"));
CMessage::PrintStatusInfo("");
//Select the best features from the candidate list m_lstWordList
CMessage::PrintInfo(_T("Starting feature selection, please wait..."));
startTime=CTime::GetCurrentTime();
FeatherSelection(m_lstTrainWordList);
//Assign an ID to every feature in the selected set m_lstTrainWordList
m_lstTrainWordList.IndexWord();
totalTime=CTime::GetCurrentTime()-startTime;
CMessage::PrintInfo(_T("特征选择结束,耗时")+totalTime.Format("%H:%M:%S"));
CMessage::PrintStatusInfo("");
//Clear m_lstWordList and release the memory it holds
m_lstWordList.InitWordList();
CMessage::PrintInfo("开始生成文档向量,请稍候...");
startTime=CTime::GetCurrentTime();
GenModel();
totalTime=CTime::GetCurrentTime()-startTime;
CMessage::PrintInfo(_T("文档向量生成结束,耗时")+totalTime.Format("%H:%M:%S"));
CMessage::PrintStatusInfo("");
CMessage::PrintInfo("开始保存分类模型,请稍候...");
startTime=CTime::GetCurrentTime();
WriteModel(m_paramClassifier.m_txtResultDir+"\\model.prj",nType);
totalTime=CTime::GetCurrentTime()-startTime;
CMessage::PrintInfo(_T("保存分类模型结束,耗时")+totalTime.Format("%H:%M:%S"));
//训练SVM分类器必须在保存训练文档的文档向量后进行
if(nType>0)
{
CMessage::PrintInfo("开始训练SVM,请稍候...");
m_lstTrainCatalogList.InitCatalogList(2); //删除文档向量所占用的空间
startTime=CTime::GetCurrentTime();
TrainSVM();
totalTime=CTime::GetCurrentTime()-startTime;
CMessage::PrintInfo(_T("SVM分类器训练结束,耗时")+totalTime.Format("%H:%M:%S"));
CMessage::PrintStatusInfo("");
}
//Prepare for classification; without this call classification cannot run
Prepare();
CMessage::PrintStatusInfo("");
return true;
}
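// Usage sketch. The call site below is hypothetical (nothing in this file
// calls Train itself); only theClassifier, m_paramClassifier and Train() are
// taken from the code above:
//   theClassifier.m_paramClassifier.m_txtResultDir = "C:\\result";  // hypothetical path
//   theClassifier.Train(1, true);   // nType=1: SVM; bFlag=true: rescan corpus, rebuild dictionary
//   theClassifier.Train(0, false);  // nType=0: KNN; reuse the existing candidate feature list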
void CClassifier::TrainSVM()
{
CString str;
CTime tmStart;
CTimeSpan tmSpan;
m_paramClassifier.m_strModelFile="model";
for(int i=1;i<=m_lstTrainCatalogList.GetCataNum();i++)
{
tmStart=CTime::GetCurrentTime();
str.Format("正在训练第%d个SVM分类器,请稍侯...",i);
CMessage::PrintInfo(str);
m_theSVM.com_param.trainfile=m_paramClassifier.m_txtResultDir+"\\train.txt";
m_theSVM.com_param.modelfile.Format("%s\\%s%d.mdl",m_paramClassifier.m_txtResultDir,m_paramClassifier.m_strModelFile,i);
m_theSVM.svm_learn_main(i);
tmSpan=CTime::GetCurrentTime()-tmStart;
str.Format("第%d个SVM分类器训练完成,耗时%s!",i,tmSpan.Format("%H:%M:%S"));
CMessage::PrintInfo(str);
}
}
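// Note: this is one-vs-rest training, one binary SVM per catalog. For example
// (illustrative directory only), with 3 catalogs and m_txtResultDir set to
// "C:\\result", the loop above reads C:\result\train.txt once per class and
// writes C:\result\model1.mdl, model2.mdl and model3.mdl, following the
// Format() call on com_param.modelfile.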
// Fill an array of sSortType entries, one per distinct training word.
// If nCatalog >= 0, each entry's dWeight is set to that word's weight for
// catalog nCatalog (an index into the catalog list); if nCatalog == -1,
// dWeight receives the word's weight summed over all catalogs (m_dWeight).
void CClassifier::GenSortBuf(CWordList& wordList,sSortType *psSortBuf,int nCatalog)
{
int nTotalCata=m_lstTrainCatalogList.GetCataNum();
ASSERT(nCatalog<nTotalCata);
long lWordCount=0 ;
POSITION pos_word = wordList.GetFirstPosition();
CString str;
while(pos_word!= NULL) // for each word
{
CWordNode& wordnode = wordList.GetNext(pos_word,str);
psSortBuf[lWordCount].pclsWordNode = &wordnode;
strcpy(psSortBuf[lWordCount].word,str);
ASSERT(wordnode.m_nAllocLen==nTotalCata);
if(nCatalog==-1)
{
psSortBuf[lWordCount].dWeight+=wordnode.m_dWeight;
}
else
psSortBuf[lWordCount].dWeight=wordnode.m_pCataWeight[nCatalog];
lWordCount++;
}
}
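// Illustration (a sketch; the word and the weight values are invented): after
// GenSortBuf(list, buf, 2), buf[k].dWeight holds word k's score for catalog 2
// only, e.g. buf[0] = { word="economy", dWeight=0.031 }. After
// GenSortBuf(list, buf, -1), it holds the all-catalog sum m_dWeight that
// FeatherWeight() accumulated, e.g. buf[0] = { word="economy", dWeight=0.214 }.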
//Select the best feature subset from m_lstWordList and store it in dstWordList
void CClassifier::FeatherSelection(CWordList& dstWordList)
{
if(m_lstWordList.GetCount()<=0) return;
dstWordList.InitWordList();
m_lstWordList.IndexWord();
sSortType *psSortBuf;
int nDistinctWordNum = m_lstWordList.GetCount();
psSortBuf = new sSortType[nDistinctWordNum]; // one entry per distinct word
ASSERT(psSortBuf!=NULL);
long lDocNum=m_lstTrainCatalogList.GetDocNum();
for(int i=0;i<nDistinctWordNum ;i++)
{
psSortBuf[i].pclsWordNode = NULL;
psSortBuf[i].dWeight = 0;
}
// per-catalog (individual) selection
if(m_paramClassifier.m_nSelMode==CClassifierParam::nFSM_IndividualModel)
{
int nCatalogWordSize=m_paramClassifier.m_nWordSize;
int nTotalCata=m_lstTrainCatalogList.GetCataNum();
for(int i=0;i<nTotalCata;i++)
{
GenSortBuf(m_lstWordList,psSortBuf,i); // weights for catalog i only
Sort(psSortBuf,nDistinctWordNum-1);
int nSelectWordNum=0;
for(int j=0;j<nDistinctWordNum&&nSelectWordNum<nCatalogWordSize;j++)
{
CWordNode wordNode;
if(m_paramClassifier.m_nWeightMode==CClassifierParam::nWM_TF_IDF)
psSortBuf[j].pclsWordNode->ComputeWeight(lDocNum); // j indexes the sorted words; i is the catalog index
else if(m_paramClassifier.m_nWeightMode==CClassifierParam::nWM_TF_IDF_DIFF)
psSortBuf[j].pclsWordNode->ComputeWeight(lDocNum,true);
wordNode.m_dWeight=psSortBuf[j].pclsWordNode->m_dWeight;
wordNode.m_lDocFreq=psSortBuf[j].pclsWordNode->m_lDocFreq;
wordNode.m_lWordFreq=psSortBuf[j].pclsWordNode->m_lWordFreq;
dstWordList.SetAt(psSortBuf[j].word,wordNode);
nSelectWordNum++;
}
}
}
// global selection over all catalogs
else //if(m_paramClassifier.m_nSelMode==CClassifierParam::nFSM_GolbalMode)
{
GenSortBuf(m_lstWordList,psSortBuf,-1);// -1 means: sum weights over all catalogs
Sort(psSortBuf,nDistinctWordNum-1);
int nSelectWordNum=m_paramClassifier.m_nWordSize;
if (nSelectWordNum>nDistinctWordNum)
nSelectWordNum=nDistinctWordNum;
for(int i=0;i<nSelectWordNum;i++)
{
CWordNode wordNode;
if(m_paramClassifier.m_nWeightMode==CClassifierParam::nWM_TF_IDF)
psSortBuf[i].pclsWordNode->ComputeWeight(lDocNum);
else if(m_paramClassifier.m_nWeightMode==CClassifierParam::nWM_TF_IDF_DIFF)
psSortBuf[i].pclsWordNode->ComputeWeight(lDocNum,true);
wordNode.m_dWeight=psSortBuf[i].pclsWordNode->m_dWeight;
wordNode.m_lDocFreq=psSortBuf[i].pclsWordNode->m_lDocFreq;
wordNode.m_lWordFreq=psSortBuf[i].pclsWordNode->m_lWordFreq;
dstWordList.SetAt(psSortBuf[i].word,wordNode);
}
}
delete [] psSortBuf;
}
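// Worked example of the two selection modes (counts invented for
// illustration): with 10 catalogs and m_nWordSize = 1000,
// nFSM_IndividualModel re-sorts the candidates once per catalog and keeps up
// to 1000 words for each, so dstWordList ends up with at most 10*1000 entries
// (fewer in practice, since SetAt() merges words picked by several catalogs);
// the global mode sorts once by the summed weight and keeps the top 1000 overall.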
void CClassifier::FeatherWeight(CWordList& wordList)
{
// ------------------------------------------------------------------------------
// counts based on the document-number model
int N; //total number of documents
int N_c; //number of documents in class C
int N_ft; //number of documents containing ft
int N_c_ft; //number of documents in class C that contain ft
// ------------------------------------------------------------------------------
// counts based on the word-number model
long N_W; //total word count (m_lWordNum)
long N_W_C; //word count of class C (CCatalogNode.m_lTotalWordNum)
long N_W_f_t; //total number of occurrences of f_t
long N_W_C_f_t;//occurrences of f_t within class C
// ------------------------------------------------------------------------------
double P_c_ft,P_c_n_ft,P_n_c_ft,P_n_c_n_ft;
double P_c,P_n_c;
double P_ft,P_n_ft;
// ------------------------------------------------------------------------------
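// Scores computed below, per word t and catalog c, writing P_c_ft as the
// joint P(c,t), P_n_ft as P(~t), and so on (this mirrors the code, not a textbook):
//   IG:  P(c,~t)*log(P(c,~t)/(P(c)*P(~t))) + P(c,t)*log(P(c,t)/(P(c)*P(t)))
//   MI:  P(c)*log(P(c,t)/(P(c)*P(t)))
//   CE:  P(c,t)*log(P(c,t)/(P(c)*P(t)))
//   X^2: (P(c,t)*P(~c,~t) - P(~c,t)*P(c,~t))^2 / (P(t)*P(~c)*P(~t))
//   WE:  P(c)*P(t)*|log(odds(c|t)/odds(c))|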
POSITION pos_cata,pos_word;
CString strWord;
// calculate each word's weight with respect to every catalog
N = m_lstTrainCatalogList.GetDocNum();
N_W = wordList.GetWordNum();
int nTotalCata=m_lstTrainCatalogList.GetCataNum();
pos_word = wordList.GetFirstPosition();
while(pos_word!= NULL) // for each word
{
CWordNode& wordnode = wordList.GetNext(pos_word,strWord);
wordnode.m_dWeight=0;
ASSERT(wordnode.m_pCataWeight);
CMessage::PrintStatusInfo("特征:"+strWord+"...");
N_ft = wordnode.GetDocNum();
N_W_f_t = wordnode.GetWordNum();
int nCataCount=0;
pos_cata = m_lstTrainCatalogList.GetFirstPosition();
while(pos_cata!=NULL) // for each catalog
{
CCatalogNode& catanode = m_lstTrainCatalogList.GetNext(pos_cata);
N_c = catanode.GetDocNum();
N_W_C = catanode.m_lTotalWordNum;
N_c_ft = wordnode.GetCataDocNum(catanode.m_idxCata);
N_W_C_f_t =wordnode.GetCataWordNum(catanode.m_idxCata);
// calculation model
if(m_paramClassifier.m_nOpMode==CClassifierParam::nOpWordMode)
{
P_c = 1.0 * N_W_C /N_W;
P_ft = 1.0 * N_W_f_t/N_W;
P_c_ft = 1.0 * N_W_C_f_t/N_W;
}
else //if(m_paramClassifier.m_nOpMode==CClassifierParam::nOpDocMode)
{
P_c = 1.0 * N_c /N;
P_ft = 1.0 * N_ft/N;
P_c_ft = 1.0 * N_c_ft/N;
}
P_n_c = 1 - P_c;
P_n_ft = 1 - P_ft;
P_n_c_ft = P_ft - P_c_ft;
P_c_n_ft = P_c - P_c_ft;
P_n_c_n_ft = P_n_ft - P_c_n_ft;
wordnode.m_pCataWeight[nCataCount]=0;
// feature selection model
if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_XXMode)
{
// Right half of IG
if ( (fabs(P_c * P_n_ft) > dZero) && ( fabs(P_c_n_ft) > dZero) )
{
wordnode.m_pCataWeight[nCataCount]+=P_c_n_ft * log( P_c_n_ft/(P_c * P_n_ft) );
}
}
else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_MIMode)
{
// Mutual Information feature selection
if ( (fabs(P_c * P_ft) > dZero) && (fabs(P_c_ft) > dZero) )
{
wordnode.m_pCataWeight[nCataCount]+= P_c * log( P_c_ft/(P_c * P_ft) );
}
}
else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_CEMode)
{
// Cross Entropy for text feature selection
if ( (fabs(P_c * P_ft) > dZero) && (fabs(P_c_ft) > dZero) )
{
wordnode.m_pCataWeight[nCataCount]+= P_c_ft * log( P_c_ft/(P_c * P_ft) );
}
}
else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_X2Mode)
{
// X^2 Statistics feature selection
if ( (fabs(P_n_c * P_ft * P_n_ft) > dZero) )
{
wordnode.m_pCataWeight[nCataCount]+= (P_c_ft * P_n_c_n_ft - P_n_c_ft * P_c_n_ft) * (P_c_ft * P_n_c_n_ft - P_n_c_ft * P_c_n_ft) / ( P_ft * P_n_c * P_n_ft);
}
}
else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_WEMode)
{
// Weight of Evidence for text feature selection
double odds_c_ft;
double odds_c;
double P_c_inv_ft=P_c_ft/P_ft;
if( fabs(P_c_inv_ft) < dZero )
odds_c_ft = 1.0 / ( N * N -1);
else if ( fabs(P_c_inv_ft-1) < dZero )
odds_c_ft = N * N -1;
else
odds_c_ft = P_c_inv_ft / (1.0 - P_c_inv_ft);
if( fabs(P_c) < dZero )
odds_c = 1.0 / ( N * N -1);
else if ( fabs(P_c-1) < dZero )
odds_c = N * N -1;
else
odds_c = P_c / (1.0 - P_c);
if( fabs(odds_c) > dZero && fabs(odds_c_ft) > dZero )
wordnode.m_pCataWeight[nCataCount]+= P_c * P_ft * fabs( log(odds_c_ft / odds_c) );
}
else //if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_IGMode)
{
// Information gain feature selection
if ( (fabs(P_c * P_n_ft) > dZero) && ( fabs(P_c_n_ft) > dZero) )
{
wordnode.m_pCataWeight[nCataCount]+=P_c_n_ft * log( P_c_n_ft/(P_c * P_n_ft) );
}
if ( (fabs(P_c * P_ft) > dZero) && (fabs(P_c_ft) > dZero) )
{
wordnode.m_pCataWeight[nCataCount]+= P_c_ft * log( P_c_ft/(P_c * P_ft) );
}
}
wordnode.m_dWeight+=wordnode.m_pCataWeight[nCataCount];
nCataCount++;
}
ASSERT(nCataCount==nTotalCata);
}
CMessage::PrintStatusInfo("");
}
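// Note: the m_dWeight accumulated above (the sum of a word's per-catalog
// scores) is exactly the value GenSortBuf() copies into dWeight when it is
// called with nCatalog == -1 during global feature selection.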
//Compute the weight of every dimension of each training document vector
void CClassifier::ComputeWeight(bool bMult)
{
long lWordNum=m_lstTrainWordList.GetCount();
if(m_lstTrainWordList.GetCount()<=0) return;
long lDocNum=m_lstTrainCatalogList.GetDocNum();
if(lDocNum<=0) return;
m_lstTrainWordList.ComputeWeight(lDocNum,bMult);
double sum=0.0;
int i=0;
POSITION pos_cata = m_lstTrainCatalogList.GetFirstPosition();
while(pos_cata != NULL) // for each catalog