⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 classifier.cpp

📁 基于径向基函数的神经网络文本自动分类系统。
💻 CPP
📖 第 1 页 / 共 4 页
字号:
// Classifier.cpp: implementation of the CClassifier class.
//
//////////////////////////////////////////////////////////////////////

#include "stdafx.h"
#include "svmcls.h"
#include "Classifier.h"
#include "WordSegment.h"
#include "Message.h"
#include <math.h>
#include <direct.h>

#ifdef _DEBUG
#undef THIS_FILE
static char THIS_FILE[]=__FILE__;
#define new DEBUG_NEW
#endif

// Global CClassifier instance defined in this translation unit
// (presumably shared by the rest of the application — confirm against callers).
CClassifier theClassifier;
// Out-of-class definition of the static constant CClassifier::dwModelFileID.
// NOTE(review): the name suggests a magic number that tags model files; confirm in WriteModel.
const DWORD CClassifier::dwModelFileID=0xFFEFFFFF;
//////////////////////////////////////////////////////////////////////
// Construction/Destruction
//////////////////////////////////////////////////////////////////////

// Constructs an empty, untrained classifier: both buffer pointers start out
// NULL and the document / class counters start at zero.
CClassifier::CClassifier()
{
	// counters
	m_nClassNum=0;
	m_lDocNum=0;

	// buffers (nothing allocated yet)
	m_pSimilarityRatio=NULL;
	m_pDocs=NULL;
}

// Destructor: performs no explicit cleanup.
// NOTE(review): m_pDocs / m_pSimilarityRatio are not released here; confirm
// which member function owns (and frees) those buffers.
CClassifier::~CClassifier() {}

// Train the classification model and save it to
// m_paramClassifier.m_txtResultDir + "\model.prj".
//
// Parameters:
//   nType - model type, forwarded to WriteModel: 0 = KNN classifier;
//           any value > 0 additionally trains the SVM classifiers
//           (TrainSVM) once the model file has been written.
//   bFlag - true:  call InitTrain() and GenDic() to rescan the training
//                  documents and rebuild the candidate feature list
//                  m_lstWordList.
//           false: reuse the candidates already in m_lstWordList without
//                  rescanning (typically used for hierarchical
//                  classification).
// Returns false when there are no candidate features or no training
// categories; otherwise true after the model is saved and Prepare() has run.
bool CClassifier::Train(int nType, bool bFlag)
{
	CTime startTime;
	CTimeSpan totalTime;
	if(bFlag)
	{
		InitTrain();
		// Generate every candidate feature and store it in m_lstWordList.
		GenDic();
	}
	CMessage::PrintStatusInfo("");

	// Nothing to train on.
	if(m_lstWordList.GetCount()==0)
		return false;
	if(m_lstTrainCatalogList.GetCataNum()==0)
		return false;

	// Clear the selected-feature list m_lstTrainWordList.
	m_lstTrainWordList.InitWordList();

	// Step 1: score every candidate feature in m_lstWordList by how well it
	// discriminates between categories.
	CMessage::PrintInfo(_T("开始计算候选特征集中每个特征的类别区分度,请稍候..."));
	startTime=CTime::GetCurrentTime();
	FeatherWeight(m_lstWordList);
	totalTime=CTime::GetCurrentTime()-startTime;
	CMessage::PrintInfo(_T("特征区分度计算结束,耗时")+totalTime.Format("%H:%M:%S"));
	CMessage::PrintStatusInfo("");

	// Step 2: select the best features from m_lstWordList into
	// m_lstTrainWordList and assign each selected feature an ID.
	CMessage::PrintInfo(_T("开始进行特征选择,请稍候..."));
	startTime=CTime::GetCurrentTime();
	FeatherSelection(m_lstTrainWordList);
	m_lstTrainWordList.IndexWord();
	totalTime=CTime::GetCurrentTime()-startTime;
	CMessage::PrintInfo(_T("特征选择结束,耗时")+totalTime.Format("%H:%M:%S"));
	CMessage::PrintStatusInfo("");

	// The candidate list is no longer needed; clear it to release its memory.
	m_lstWordList.InitWordList();

	// Step 3: build the document vectors.
	CMessage::PrintInfo("开始生成文档向量,请稍候...");
	startTime=CTime::GetCurrentTime();
	GenModel();
	totalTime=CTime::GetCurrentTime()-startTime;
	CMessage::PrintInfo(_T("文档向量生成结束,耗时")+totalTime.Format("%H:%M:%S"));
	CMessage::PrintStatusInfo("");

	// Step 4: save the classification model.
	CMessage::PrintInfo("开始保存分类模型,请稍候...");
	startTime=CTime::GetCurrentTime();
	WriteModel(m_paramClassifier.m_txtResultDir+"\\model.prj",nType);
	totalTime=CTime::GetCurrentTime()-startTime;
	CMessage::PrintInfo(_T("保存分类模型结束,耗时")+totalTime.Format("%H:%M:%S"));

	// Step 5 (SVM only): the SVM must be trained after the training
	// documents' vectors have been saved by WriteModel above.
	if(nType>0)
	{
		CMessage::PrintInfo("开始训练SVM,请稍候...");
		// Release the memory held by the document vectors.
		m_lstTrainCatalogList.InitCatalogList(2);
		startTime=CTime::GetCurrentTime();
		TrainSVM();
		totalTime=CTime::GetCurrentTime()-startTime;
		CMessage::PrintInfo(_T("SVM分类器训练结束,耗时")+totalTime.Format("%H:%M:%S"));
		CMessage::PrintStatusInfo("");
	}

	// Prepare for classification; classification is not possible without this.
	Prepare();
	CMessage::PrintStatusInfo("");
	return true;
}

void CClassifier::TrainSVM()
{
	CString str;
	CTime tmStart;
	CTimeSpan tmSpan;

	m_paramClassifier.m_strModelFile="model";
	for(int i=1;i<=m_lstTrainCatalogList.GetCataNum();i++)
	{
		tmStart=CTime::GetCurrentTime();
		str.Format("正在训练第%d个SVM分类器,请稍侯...",i);
		CMessage::PrintInfo(str);
		m_theSVM.com_param.trainfile=m_paramClassifier.m_txtResultDir+"\\train.txt";
		m_theSVM.com_param.modelfile.Format("%s\\%s%d.mdl",m_paramClassifier.m_txtResultDir,m_paramClassifier.m_strModelFile,i);
		m_theSVM.svm_learn_main(i);
		tmSpan=CTime::GetCurrentTime()-tmStart;
		str.Format("第%d个SVM分类器训练完成,耗时%s!",i,tmSpan.Format("%H:%M:%S"));
		CMessage::PrintInfo(str);
	}	
}

// Fill psSortBuf with one entry per word of wordList, in list iteration order.
// Each entry receives a pointer to the word node, a copy of the word text and
// a sort key:
//   nCatalog == -1 : the word's total weight over all categories (m_dWeight)
//   nCatalog >= 0  : the word's weight for that single category
//                    (m_pCataWeight[nCatalog])
// psSortBuf must have room for at least wordList.GetCount() entries.
// The sort key is assigned, never accumulated, so the result does not depend
// on what the buffer held before the call.
void CClassifier::GenSortBuf(CWordList& wordList,sSortType *psSortBuf,int nCatalog)
{
	const int nTotalCata=m_lstTrainCatalogList.GetCataNum();
	ASSERT(nCatalog>=-1 && nCatalog<nTotalCata);

	long lWordCount=0;
	CString strWord;
	POSITION posWord=wordList.GetFirstPosition();
	while(posWord!=NULL)	// for each word
	{
		CWordNode& wordnode=wordList.GetNext(posWord,strWord);
		sSortType& entry=psSortBuf[lWordCount];

		entry.pclsWordNode=&wordnode;
		// NOTE(review): unbounded copy — assumes sSortType::word is large
		// enough for the longest feature string; confirm against its declaration.
		strcpy(entry.word,strWord);

		// Every node must carry one weight slot per training category.
		ASSERT(wordnode.m_nAllocLen==nTotalCata);
		entry.dWeight=(nCatalog==-1) ? wordnode.m_dWeight
		                             : wordnode.m_pCataWeight[nCatalog];
		lWordCount++;
	}
}


// Select the best feature subset from m_lstWordList and store it in dstWordList.
//
// dstWordList is cleared first. Candidates are ranked with GenSortBuf + Sort:
//   nFSM_IndividualModel : for every category, its m_nWordSize highest-weighted
//                          features (per-category weight) are selected;
//   otherwise (global)   : the m_nWordSize features with the highest total
//                          weight are selected (capped at the candidate count).
// When m_nWeightMode is nWM_TF_IDF / nWM_TF_IDF_DIFF, each selected candidate's
// weight is recomputed by CWordNode::ComputeWeight before it is copied. The
// weight, document frequency and word frequency are then copied into
// dstWordList under the same word text.
void CClassifier::FeatherSelection(CWordList& dstWordList)
{
	if(m_lstWordList.GetCount()<=0) return;
	dstWordList.InitWordList();
	m_lstWordList.IndexWord();

	const int nDistinctWordNum=m_lstWordList.GetCount();	// number of distinct candidates
	sSortType *psSortBuf=new sSortType[nDistinctWordNum];
	ASSERT(psSortBuf!=NULL);
	const long lDocNum=m_lstTrainCatalogList.GetDocNum();
	for(int k=0;k<nDistinctWordNum;k++)
	{
		psSortBuf[k].pclsWordNode=NULL;
		psSortBuf[k].dWeight=0;
	}

	if(m_paramClassifier.m_nSelMode==CClassifierParam::nFSM_IndividualModel)
	{
		// Per-category selection.
		const int nCatalogWordSize=m_paramClassifier.m_nWordSize;
		const int nTotalCata=m_lstTrainCatalogList.GetCataNum();
		for(int nCata=0;nCata<nTotalCata;nCata++)
		{
			GenSortBuf(m_lstWordList,psSortBuf,nCata);	// key = weight for category nCata
			Sort(psSortBuf,nDistinctWordNum-1);
			int nSelectWordNum=0;
			for(int j=0;j<nDistinctWordNum&&nSelectWordNum<nCatalogWordSize;j++)
			{
				CWordNode wordNode;
				// Recompute the weight of the candidate being selected (index j);
				// this previously used the category index by mistake.
				if(m_paramClassifier.m_nWeightMode==CClassifierParam::nWM_TF_IDF)
					psSortBuf[j].pclsWordNode->ComputeWeight(lDocNum);
				else if(m_paramClassifier.m_nWeightMode==CClassifierParam::nWM_TF_IDF_DIFF)
					psSortBuf[j].pclsWordNode->ComputeWeight(lDocNum,true);
				wordNode.m_dWeight=psSortBuf[j].pclsWordNode->m_dWeight;
				wordNode.m_lDocFreq=psSortBuf[j].pclsWordNode->m_lDocFreq;
				wordNode.m_lWordFreq=psSortBuf[j].pclsWordNode->m_lWordFreq;
				dstWordList.SetAt(psSortBuf[j].word,wordNode);
				nSelectWordNum++;
			}
		}
	}
	else	// global selection (nFSM_GolbalMode)
	{
		GenSortBuf(m_lstWordList,psSortBuf,-1);	// -1: key = total weight over all categories
		Sort(psSortBuf,nDistinctWordNum-1);

		int nSelectWordNum=m_paramClassifier.m_nWordSize;
		if(nSelectWordNum>nDistinctWordNum)
			nSelectWordNum=nDistinctWordNum;

		for(int i=0;i<nSelectWordNum;i++)
		{
			CWordNode wordNode;
			if(m_paramClassifier.m_nWeightMode==CClassifierParam::nWM_TF_IDF)
				psSortBuf[i].pclsWordNode->ComputeWeight(lDocNum);
			else if(m_paramClassifier.m_nWeightMode==CClassifierParam::nWM_TF_IDF_DIFF)
				psSortBuf[i].pclsWordNode->ComputeWeight(lDocNum,true);
			wordNode.m_dWeight=psSortBuf[i].pclsWordNode->m_dWeight;
			wordNode.m_lDocFreq=psSortBuf[i].pclsWordNode->m_lDocFreq;
			wordNode.m_lWordFreq=psSortBuf[i].pclsWordNode->m_lWordFreq;
			dstWordList.SetAt(psSortBuf[i].word,wordNode);
		}
	}
	delete [] psSortBuf;
}

// Score every candidate feature in wordList against every training category.
//
// For each word, m_pCataWeight[k] receives the score for the k-th category
// in m_lstTrainCatalogList iteration order, and m_dWeight receives the sum
// of those per-category scores. Probabilities are estimated from document
// counts by default, or from word-occurrence counts when
// m_nOpMode == nOpWordMode. The scoring formula is chosen by m_nFSMode
// (c = category, t = feature, ~ = "not"):
//   nFS_XXMode : P(c,~t) * log(P(c,~t) / (P(c)P(~t)))  (right half of IG)
//   nFS_MIMode : P(c)    * log(P(c,t)  / (P(c)P(t)))   (mutual information)
//   nFS_CEMode : P(c,t)  * log(P(c,t)  / (P(c)P(t)))   (cross entropy)
//   nFS_X2Mode : (P(c,t)P(~c,~t) - P(~c,t)P(c,~t))^2 / (P(t)P(~c)P(~t))
//   nFS_WEMode : P(c)P(t) * |log(odds(c|t) / odds(c))| (weight of evidence)
//   otherwise  : information gain (both IG terms)
// Each word node must already have m_pCataWeight allocated with one slot
// per category (asserted).
void CClassifier::FeatherWeight(CWordList& wordList)
{
// ------------------------------------------------------------------------------
//  document-count model
	int			N;		// total number of training documents
	int			N_c;	// number of documents in category c
	int			N_ft;	// number of documents containing ft
	int			N_c_ft;	// number of documents in category c containing ft
// ------------------------------------------------------------------------------
//  word-count model
	long		N_W;		// total number of words					m_lWordNum;
	long		N_W_C;		// number of words in category c			CCatalogNode.m_lTotalWordNum;
	long		N_W_f_t;	// total occurrences of ft
	long		N_W_C_f_t;	// occurrences of ft in category c
// ------------------------------------------------------------------------------
	double		P_c_ft,P_c_n_ft,P_n_c_ft,P_n_c_n_ft;
	double		P_c,P_n_c;
	double		P_ft,P_n_ft;

// ------------------------------------------------------------------------------
	POSITION	pos_cata,pos_word;
	CString     strWord;

	// calculate the weight of each word to all catalog
	N = m_lstTrainCatalogList.GetDocNum();
	N_W = wordList.GetWordNum();

	// Weight-of-evidence substitute for the infinite odds of a probability of
	// exactly 0 or 1 (used as 1/bound and bound). Computed in double: the
	// former int expression N * N - 1 overflowed (undefined behaviour) for
	// N > 46340 documents.
	const double dOddsBound = 1.0 * N * N - 1.0;

	int nTotalCata=m_lstTrainCatalogList.GetCataNum();
	pos_word = wordList.GetFirstPosition();
	while(pos_word!= NULL)  // for each word
	{
		CWordNode& wordnode = wordList.GetNext(pos_word,strWord);
		wordnode.m_dWeight=0;

		ASSERT(wordnode.m_pCataWeight);

		CMessage::PrintStatusInfo("特征:"+strWord+"..."); 

		N_ft = wordnode.GetDocNum();  
		N_W_f_t = wordnode.GetWordNum();
		int nCataCount=0;
		pos_cata = m_lstTrainCatalogList.GetFirstPosition();
		while(pos_cata!=NULL)  // for each catalog 
		{
			CCatalogNode&	catanode = m_lstTrainCatalogList.GetNext(pos_cata);
			N_c  = catanode.GetDocNum();
			N_W_C  = catanode.m_lTotalWordNum;			
			N_c_ft = wordnode.GetCataDocNum(catanode.m_idxCata);
			N_W_C_f_t =wordnode.GetCataWordNum(catanode.m_idxCata);
			// probability estimation model
			if(m_paramClassifier.m_nOpMode==CClassifierParam::nOpWordMode)   
			{
				P_c	    = 1.0 * N_W_C /N_W;
				P_ft	= 1.0 * N_W_f_t/N_W;
				P_c_ft  = 1.0 * N_W_C_f_t/N_W;
			}
			else //if(m_paramClassifier.m_nOpMode==CClassifierParam::nOpDocMode)
			{
				P_c			= 1.0 * N_c /N;
				P_ft		= 1.0 * N_ft/N;
				P_c_ft		= 1.0 * N_c_ft/N;
			}
			
			// derived joint / complementary probabilities
			P_n_c		= 1 - P_c;
			P_n_ft		= 1 - P_ft;
			P_n_c_ft	= P_ft - P_c_ft;
			P_c_n_ft	= P_c - P_c_ft;
			P_n_c_n_ft  = P_n_ft - P_c_n_ft;

			wordnode.m_pCataWeight[nCataCount]=0;
			// feature selection model
			if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_XXMode)
			{
				// Right half of IG
				if ( (fabs(P_c * P_n_ft) > dZero) && ( fabs(P_c_n_ft) > dZero) ) 
				{
					wordnode.m_pCataWeight[nCataCount]+=P_c_n_ft * log( P_c_n_ft/(P_c * P_n_ft) );
				}
			}
			else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_MIMode)
			{
				// Mutual information feature selection
				if ( (fabs(P_c * P_ft) > dZero) && (fabs(P_c_ft) > dZero) ) 
				{
					wordnode.m_pCataWeight[nCataCount]+= P_c * log( P_c_ft/(P_c * P_ft) );
				}
			}
			else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_CEMode)
			{
				// Cross entropy for text feature selection
				if ( (fabs(P_c * P_ft) > dZero) && (fabs(P_c_ft) > dZero) ) 
				{
					wordnode.m_pCataWeight[nCataCount]+= P_c_ft * log( P_c_ft/(P_c * P_ft) );
				}
			}
			else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_X2Mode)
			{
				// X^2 statistic feature selection
				if ( (fabs(P_n_c * P_ft * P_n_ft) > dZero) ) 
				{
					wordnode.m_pCataWeight[nCataCount]+= (P_c_ft * P_n_c_n_ft - P_n_c_ft * P_c_n_ft) * (P_c_ft * P_n_c_n_ft - P_n_c_ft * P_c_n_ft) / ( P_ft * P_n_c * P_n_ft);
				}
			}
			else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_WEMode)
			{
				// Weight of evidence for text feature selection.
				// A feature that never occurs (P_ft == 0) contributes nothing;
				// test it explicitly instead of dividing 0 by 0 below.
				if(P_ft > 0.0)
				{
					double		odds_c_ft;
					double		odds_c;
					double		P_c_inv_ft=P_c_ft/P_ft;	// P(c | ft)

					if( fabs(P_c_inv_ft) < dZero )
						odds_c_ft = 1.0 / dOddsBound;
					else if ( fabs(P_c_inv_ft-1) < dZero )
						odds_c_ft = dOddsBound;
					else
						odds_c_ft = P_c_inv_ft / (1.0 - P_c_inv_ft);

					if( fabs(P_c) < dZero )
						odds_c = 1.0 / dOddsBound;
					else if ( fabs(P_c-1) < dZero )
						odds_c = dOddsBound;
					else
						odds_c = P_c / (1.0 - P_c);
					if( fabs(odds_c) > dZero && fabs(odds_c_ft) > dZero )
						wordnode.m_pCataWeight[nCataCount]+= P_c * P_ft * fabs( log(odds_c_ft / odds_c) );
				}
			}
			else //if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_IGMode) 
			{
				// Information gain feature selection (both terms)
				if ( (fabs(P_c * P_n_ft) > dZero) && ( fabs(P_c_n_ft) > dZero) ) 
				{
					wordnode.m_pCataWeight[nCataCount]+=P_c_n_ft * log( P_c_n_ft/(P_c * P_n_ft) );
				}
				if ( (fabs(P_c * P_ft) > dZero) && (fabs(P_c_ft) > dZero) ) 
				{
					wordnode.m_pCataWeight[nCataCount]+= P_c_ft * log( P_c_ft/(P_c * P_ft) );
				}
			}
			// the word's total weight is the sum of its per-category scores
			wordnode.m_dWeight+=wordnode.m_pCataWeight[nCataCount];
			nCataCount++;				
		}
		ASSERT(nCataCount==nTotalCata);
	}
	CMessage::PrintStatusInfo("");
}

//计算每一篇训练文档向量的每一维的权重
void CClassifier::ComputeWeight(bool bMult)
{
	long lWordNum=m_lstTrainWordList.GetCount();
	if(m_lstTrainWordList.GetCount()<=0) return;
	long lDocNum=m_lstTrainCatalogList.GetDocNum();
	if(lDocNum<=0) return;
	m_lstTrainWordList.ComputeWeight(lDocNum,bMult);

	double sum=0.0;
	int i=0;
	POSITION pos_cata = m_lstTrainCatalogList.GetFirstPosition();
	while(pos_cata != NULL)  // for each catalog 

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -