⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 classifier.cpp

📁 基于径向基函数的神经网络文本自动分类系统。
💻 CPP
📖 第 1 页 / 共 4 页
字号:
	{
		//取类列表中的每一个类
		CCatalogNode& cataNode = m_lstTrainCatalogList.GetNext(pos_cata);
		POSITION pos_doc  = cataNode.GetFirstPosition();
		while(pos_doc!=NULL)
		{
			CDocNode& docNode=cataNode.GetNext(pos_doc);

			sum=0.0;
			for(i=0;i<docNode.m_nAllocLen;i++)
			{
				docNode.m_sWeightSet[i].s_dWeight*=docNode.m_sWeightSet[i].s_tfi;
				sum+=(docNode.m_sWeightSet[i].s_dWeight*docNode.m_sWeightSet[i].s_dWeight);
			}
			sum=sqrt(sum);
			for(i=0;i<docNode.m_nAllocLen;i++)
				docNode.m_sWeightSet[i].s_dWeight/=sum;

			CMessage::PrintStatusInfo("计算文档"+docNode.m_strDocName+"向量每一维的权重");
		}
	}
}

// Sorts psData[iLo..iHi] (both bounds INCLUSIVE) in place into DESCENDING
// order of dWeight, using a Hoare-style partition around the value of the
// middle element. The sort is not stable: entries with equal dWeight may be
// reordered.
// A NULL buffer or a range with fewer than two elements (iLo >= iHi) is a
// no-op. Without this guard an inverted range such as (0,-1) would read
// psData[-1].
void CClassifier::QuickSort(sSortType *psData, int iLo,int iHi)
{
	if(psData==NULL || iLo>=iHi)
		return;

	int Lo = iLo;
	int Hi = iHi;
	// Pivot VALUE (not index). Lo+(Hi-Lo)/2 equals (Lo+Hi)/2 for the
	// non-negative indices used here, but cannot overflow.
	const double Mid = psData[Lo + (Hi - Lo) / 2].dWeight;
	sSortType t;
	do
	{
		// Skip entries already on the correct side of the pivot:
		// heavier ones stay on the left, lighter ones on the right.
		while(psData[Lo].dWeight > Mid) Lo++;
		while(psData[Hi].dWeight < Mid) Hi--;
		if(Lo <= Hi)
		{
			t = psData[Lo];
			psData[Lo] = psData[Hi];
			psData[Hi] = t;
			Lo++;
			Hi--;
		}
	}while(Hi > Lo);
	// Recurse into both partitions. If the loop ends with Lo == Hi, that
	// element falls into both sub-ranges. This is harmless: each sub-range
	// is still strictly smaller than [iLo, iHi].
	if(Hi > iLo) QuickSort(psData, iLo, Hi);
	if(Lo < iHi) QuickSort(psData, Lo, iHi);
}

// Sorts psData into descending order of dWeight (see QuickSort).
// NOTE(review): nSize is forwarded to QuickSort as the INCLUSIVE upper
// index, so this sorts psData[0..nSize], i.e. nSize+1 entries. The callers
// are not visible here. If any caller passes an element count rather than
// the last index, it touches one element past the end; confirm at the call
// sites.
void CClassifier::Sort(sSortType *psData,int nSize)
{
	// For nSize < 1 there is at most one element, whether nSize means "last
	// index" or "count", so there is nothing to sort. Returning early also
	// avoids reading psData[0] / psData[-1] when the array is empty.
	if(nSize < 1)
		return;
	QuickSort(psData,0,nSize);
}


// Give m_lstWordList & m_lstTrainCatalogList
// Output the present vector of each document;
// bFlag=false 层次分类的时候使用
void CClassifier::GenModel()
{
	CDocNode::AllocTempBuffer(m_lstTrainWordList.GetCount());
	POSITION pos_cata = m_lstTrainCatalogList.GetFirstPosition();
	while(pos_cata != NULL)  // for each catalog 
	{
		//取类列表中的每一个类
		CCatalogNode& cataNode = m_lstTrainCatalogList.GetNext(pos_cata);
		POSITION pos_doc  = cataNode.GetFirstPosition();
		while(pos_doc!=NULL)
		{
			CDocNode& docNode=cataNode.GetNext(pos_doc);
			if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
				docNode.ScanChineseWithDict(cataNode.m_strDirName.GetBuffer(0),m_lstTrainWordList);
			else
				docNode.ScanEnglishWithDict(cataNode.m_strDirName.GetBuffer(0),m_lstTrainWordList,m_paramClassifier.m_bStem);
			docNode.GenDocVector();
			CMessage::PrintStatusInfo("生成文档"+docNode.m_strDocName+"的文档向量");
		}
	}
	CDocNode::DeallocTempBuffer();
}


// Generates the original (largest) dictionary m_lstWordList from the
// training files.
// Steps:
//  1. Re-initialise m_lstWordList.
//  2. Initialise the global word segmenter g_wordSeg for the configured
//     language. Chinese additionally selects CWordSegment::uPlace.
//  3. Build m_lstTrainCatalogList from m_paramClassifier.m_txtTrainDir.
//  4. Scan every training document into m_lstWordList, adding the returned
//     word count to its catalog's m_lTotalWordNum. Documents that are empty
//     (count 0) or cannot be opened (count < 0) are reported and skipped.
//  5. Release the word segmenter.
// Returns false if the segmenter cannot be initialised or no training
// documents are found; true otherwise.
bool CClassifier::GenDic()
{
	m_lstWordList.InitWordList();
	CTime startTime;
	CTimeSpan totalTime;

	startTime=CTime::GetCurrentTime();
	CMessage::PrintInfo(_T("分词程序初始化,请稍候..."));
	if(!g_wordSeg.InitWorgSegment(theApp.m_strPath.GetBuffer(0),m_paramClassifier.m_nLanguageType))
	{
		CMessage::PrintError(_T("分词程序初始化失败!"));
		return false;
	}
	if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
		g_wordSeg.SetSegSetting(CWordSegment::uPlace);
	totalTime=CTime::GetCurrentTime()-startTime;
	CMessage::PrintInfo(_T("分词程序初始化结束,耗时")+totalTime.Format("%H:%M:%S"));


	startTime=CTime::GetCurrentTime();
	CMessage::PrintInfo(_T("开始扫描训练文档,请稍候..."));
	if(m_lstTrainCatalogList.BuildLib(m_paramClassifier.m_txtTrainDir)<=0)
	{
		// The segmenter was successfully initialised above. Release it on
		// this failure path too, just as the success path does at the end.
		g_wordSeg.FreeWordSegment();
		CMessage::PrintError("训练文档的总数为0!");
		return false;
	}

	int nCataNum=m_lstTrainCatalogList.GetCataNum();
	POSITION pos_cata=m_lstTrainCatalogList.GetFirstPosition();
	while(pos_cata!=NULL)
	{
		CCatalogNode& catalognode=m_lstTrainCatalogList.GetNext(pos_cata);
		POSITION pos_doc=catalognode.GetFirstPosition();
		while(pos_doc!=NULL)
		{
			CDocNode& docnode=catalognode.GetNext(pos_doc);
			CMessage::PrintStatusInfo(_T("扫描文档")+docnode.m_strDocName);

			int nCount;
			if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
				nCount=docnode.ScanChinese(catalognode.m_strDirName.GetBuffer(0),
							m_lstWordList,nCataNum,catalognode.m_idxCata);
			else
				nCount=docnode.ScanEnglish(catalognode.m_strDirName.GetBuffer(0),
							m_lstWordList,nCataNum,catalognode.m_idxCata,
							m_paramClassifier.m_bStem);
			if(nCount==0)
			{
				CMessage::PrintError("文件"+catalognode.m_strDirName+"\\"+docnode.m_strDocName+"无内容!");
				continue;
			}
			if(nCount<0)
			{
				CMessage::PrintError("文件"+catalognode.m_strDirName+"\\"+docnode.m_strDocName+"无法打开!");
				continue;
			}
			// Accumulate per-catalog word statistics.
			catalognode.m_lTotalWordNum+=nCount;
		}
	}
	g_wordSeg.FreeWordSegment();
	totalTime=CTime::GetCurrentTime()-startTime;
	CMessage::PrintInfo(_T("扫描训练文档结束,耗时")+totalTime.Format("%H:%M:%S"));
	return true;
}

void CClassifier::InitTrain()
{
	m_lstTrainWordList.InitWordList();
	m_lstTrainCatalogList.InitCatalogList();
	m_lstWordList.InitWordList();
}

// Saves the trained model.
// nType selects the model kind: 0 = KNN. Any other value is treated as SVM,
// which is documented as nType=1.
// Component files are written into m_paramClassifier.m_txtResultDir:
//   features.dat / features.txt - training feature dictionary (binary / text)
//   train.dat / train.txt       - training catalog/document list
//                                 (train.dat is dumped with flag 1 for SVM)
//   params.dat                  - classifier parameters
//   svmparams.dat               - SVM parameters (SVM only)
// strFileName itself receives an archive holding dwModelFileID followed by
// the names of the .dat files, in the order OpenModel() reads them back.
// Returns false only when strFileName cannot be created.
bool CClassifier::WriteModel(CString strFileName, int nType)
{
	CFile fOut;
	if(!fOut.Open(strFileName,CFile::modeCreate|CFile::modeWrite))
	{
		CMessage::PrintError("无法创建文件"+strFileName+"!");
		return false;
	}
	CArchive ar(&fOut,CArchive::store);

	const bool bSVM=(nType!=0);
	const CString& strDir=m_paramClassifier.m_txtResultDir;

	// Component files shared by both model kinds.
	m_lstTrainWordList.DumpToFile(strDir+"\\features.dat");
	m_lstTrainWordList.DumpWordList(strDir+"\\features.txt");
	if(bSVM)
		m_lstTrainCatalogList.DumpToFile(strDir+"\\train.dat",1);
	else
		m_lstTrainCatalogList.DumpToFile(strDir+"\\train.dat");
	m_lstTrainCatalogList.DumpDocList(strDir+"\\train.txt");
	m_paramClassifier.DumpToFile(strDir+"\\params.dat");

	// SVM-only parameter file.
	if(bSVM)
	{
		m_theSVM.com_param.classifier_num=m_lstTrainCatalogList.GetCataNum();
		m_theSVM.com_param.trainfile="train.txt";
		m_theSVM.com_param.resultpath=strDir;
		m_theSVM.com_param.DumpToFile(strDir+"\\svmparams.dat");
	}

	// Index archive: format identifier, then the component file names.
	ar<<dwModelFileID;
	ar<<CString("params.dat");
	ar<<CString("features.dat");
	ar<<CString("train.dat");
	if(bSVM)
		ar<<CString("svmparams.dat");

	ar.Close();
	fOut.Close();
	return true;
}

// Loads a model previously saved by WriteModel().
// strFileName is the index archive. The component files it names are
// resolved relative to strFileName's directory, which also becomes
// m_paramClassifier.m_txtResultDir. If strFileName has no directory part,
// the current directory (".") is used instead of the drive root.
// Items are read in this order:
//   1. the format ID, which must equal dwModelFileID;
//   2. the parameter file;
//   3. the feature dictionary;
//   4. the training document list;
//   5. the SVM parameter file, only when m_paramClassifier.m_nClassifierType
//      is not 0.
// On success it calls Prepare() and prints the elapsed time and the
// parameter summary.
// Returns false after reporting the error if any file cannot be opened or
// the format ID does not match. The archive and the file are closed on
// every failure path.
bool CClassifier::OpenModel(CString strFileName)
{
	CFile fIn;
	if(!fIn.Open(strFileName,CFile::modeRead))
	{
		CMessage::PrintError("无法打开文件"+strFileName+"!") ;
		return false;
	}
	CTime startTime=CTime::GetCurrentTime();
	CMessage::PrintInfo(_T("正在打开分类模型文件,请稍候..."));

	CArchive ar(&fIn,CArchive::load);
	CString str,strPath;
	DWORD dwFileID=0;

	// Component files live next to the index file. Accept either separator.
	// Fall back to "." when there is no directory part, so that "\\"+name
	// does not resolve to the drive root.
	int nSep=strFileName.ReverseFind('\\');
	int nAltSep=strFileName.ReverseFind('/');
	if(nAltSep>nSep)
		nSep=nAltSep;
	if(nSep>=0)
		strPath=strFileName.Left(nSep);
	else
		strPath=".";

	// Read and check the file-format identifier.
	ar>>dwFileID;
	if(dwFileID!=dwModelFileID)
	{
		ar.Close();
		fIn.Close();
		CMessage::PrintError("分类模型文件的格式不正确!");
		return false;
	}

	// Classifier parameters (they also decide which classifier type follows).
	ar>>str;
	if(!m_paramClassifier.GetFromFile(strPath+"\\"+str))
	{
		ar.Close();
		fIn.Close();
		CMessage::PrintError(_T("无法打开训练参数文件")+str+_T("!"));
		return false;
	}
	m_paramClassifier.m_txtResultDir=strPath;

	// Feature dictionary, read identically for every classifier type.
	ar>>str;
	m_lstTrainWordList.InitWordList();
	if(!m_lstTrainWordList.GetFromFile(strPath+"\\"+str))
	{
		ar.Close();
		fIn.Close();
		CMessage::PrintError(_T("无法打开特征类表文件")+str+_T("!"));
		return false;
	}

	// Training document list. Per the original note, the SVM classifier does
	// not actually need it; it is loaded only so that CLeftViw (presumably
	// CLeftView) can display some statistics.
	ar>>str;
	m_lstTrainCatalogList.InitCatalogList();
	if(!m_lstTrainCatalogList.GetFromFile(strPath+"\\"+str))
	{
		ar.Close();
		fIn.Close();
		CMessage::PrintError(_T("无法打开训练文档列表文件")+str+_T("!"));
		return false;
	}

	// SVM-only parameter file.
	if(m_paramClassifier.m_nClassifierType!=0)
	{
		ar>>str;
		if(!m_theSVM.com_param.GetFromFile(strPath+"\\"+str))
		{
			ar.Close();
			fIn.Close();
			CMessage::PrintError(_T("无法打开SVM训练参数文件")+str+_T("!"));
			return false;
		}
		m_theSVM.com_param.trainfile=strPath+"\\train.txt";
		m_theSVM.com_param.resultpath=strPath;
	}
	ar.Close();
	fIn.Close();

	Prepare();
	CTimeSpan totalTime=CTime::GetCurrentTime()-startTime;
	CMessage::PrintInfo(_T("分类模型文件已经打开,耗时")+totalTime.Format("%H:%M:%S")+"\r\n");

	str.Empty();
	m_paramClassifier.GetParamString(str);
	CMessage::PrintInfo(str);
	return true;
}

// Classifies the test documents under m_paramClassifier.m_strTestDir and
// reports the outcome.
// - Writes the training catalog list to <m_strResultDir>\classes.txt.
// - With m_bEvaluation set, the test set is built with BuildLib (presumably
//   one sub-directory per catalog) and its catalogs are matched against the
//   training catalogs. An unrecognised catalog aborts with false.
// - Otherwise all documents directly in m_strTestDir go into a single
//   pseudo-catalog ("测试文档", index -1).
// Per-document results are written to <m_strResultDir>\results.txt by
// SaveResults(). Accuracy is correct / (total - unclassifiable); it is
// printed in evaluation mode even when it is 0.
// Returns false if catalogs cannot be matched or there are no test
// documents; true otherwise.
bool CClassifier::Classify()
{
	m_lstTrainCatalogList.DumpCataList(m_paramClassifier.m_strResultDir+"\\classes.txt");
	CTime startTime;
	CTimeSpan totalTime;
	startTime=CTime::GetCurrentTime();
	CMessage::PrintInfo(_T("正在扫描测试文档,请稍候..."));
	if(m_paramClassifier.m_bEvaluation)
	{
		// Per the original note, BuildLib already clears the test catalog
		// list, so it is not re-initialised here.
		m_lstTestCatalogList.BuildLib(m_paramClassifier.m_strTestDir);
		if(!m_lstTestCatalogList.BuildCatalogID(m_lstTrainCatalogList))
		{
			CMessage::PrintError("测试文件中包含有无法识别的类别!");
			return false;
		}
	}
	else
	{
		m_lstTestCatalogList.InitCatalogList();
		CCatalogNode catalognode;
		catalognode.m_strDirName=m_paramClassifier.m_strTestDir;
		catalognode.m_strCatalogName="测试文档";
		catalognode.m_idxCata=-1;
		POSITION posTemp=m_lstTestCatalogList.AddCata(catalognode);
		CCatalogNode& cataTemp=m_lstTestCatalogList.GetAt(posTemp);
		cataTemp.SetStartDocID(0);
		cataTemp.ScanDirectory(m_paramClassifier.m_strTestDir);
	}
	if(m_lstTestCatalogList.GetDocNum()<=0)
	{
		CMessage::PrintError("测试文件总数为0!\r\n如果不需要对分类结果进行评价时,分类文档必须在\"分类文档目录\"下,而不是它的子目录下!");
		return false;
	}
	totalTime=CTime::GetCurrentTime()-startTime;
	CMessage::PrintInfo(_T("扫描测试文档结束,耗时")+totalTime.Format("%H:%M:%S"));

	startTime=CTime::GetCurrentTime();
	CMessage::PrintInfo(_T("正在对测试文档进行分类,请稍候..."));
	// The list overload returns the number of documents it could not classify.
	long lUnknown=Classify(m_lstTestCatalogList);
	long lCorrect=SaveResults(m_lstTestCatalogList,m_paramClassifier.m_strResultDir+"\\results.txt");
	const long lDocNum=(long)m_lstTestCatalogList.GetDocNum();
	const long lTotalNum=lDocNum-lUnknown;

	CString str;
	totalTime=CTime::GetCurrentTime()-startTime;
	CMessage::PrintInfo(_T("测试文档分类结束,耗时")+totalTime.Format("%H:%M:%S"));
	if(lUnknown>0)
	{
		str.Format("无法分类的文档数:%ld",lUnknown);
		CMessage::PrintInfo(str);
	}
	// A count of 0 correct is a valid result (accuracy 0). Only a negative
	// value, possibly an error code from SaveResults, suppresses accuracy.
	if(m_paramClassifier.m_bEvaluation&&lTotalNum>0&&lCorrect>=0)
		str.Format("测试文档总数:%ld,准确率:%f",lDocNum,(double)lCorrect/(double)lTotalNum);
	else
		str.Format("测试文档总数:%ld",lDocNum);
	CMessage::PrintInfo(str);
	return true;
}

//对Smart格式的文档进行分类
bool CClassifier::ClassifySmart()
{
	m_lstTrainCatalogList.DumpCataList(m_paramClassifier.m_strResultDir+"\\classes.txt");
	m_lstTestCatalogList.InitCatalogList();
	CCatalogNode catalognode;
	catalognode.m_strDirName=m_paramClassifier.m_strTestDir;
	catalognode.m_strCatalogName="测试文档";
	catalognode.m_idxCata=-1;
	POSITION posTemp=m_lstTestCatalogList.AddCata(catalognode);
	CCatalogNode& cataTemp=m_lstTestCatalogList.GetAt(posTemp);

	FILE *stream1,*stream2;
	if( (stream1 = fopen( m_paramClassifier.m_strTestDir, "r" )) == NULL )
	{
		CMessage::PrintError("无法打开文件"+m_paramClassifier.m_strTestDir+"!");
		return false;
	}

	//如果是SVM分类器,则需要先将所有测试文档转换为向量,保存到文件test.dat
	if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM)
	{
		m_theSVM.com_param.classifyfile=m_paramClassifier.m_strResultDir+"\\test.dat";
		if((stream2=fopen(m_theSVM.com_param.classifyfile,"w"))==NULL)
		{
			CMessage::PrintError("无法创建测试文档向量文件"+m_theSVM.com_param.classifyfile+"!");
			return false;
		}
	}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -