⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 classifier.cpp

📁 基于径向基函数的神经网络文本自动分类系统。
💻 CPP
📖 第 1 页 / 共 4 页
字号:
			nCount=docNode.ScanChineseStringWithDict(pPath,m_lstTrainWordList);
		else
			nCount=docNode.ScanEnglishStringWithDict(pPath,m_lstTrainWordList,
											m_paramClassifier.m_bStem);
	}

	if((m_lDocNum>0)&&(nCount>0))
	{
		DOC doc;
		CString str;
		docNode.GenDocVector(doc);
		docNode.AllocResultsBuffer(m_nClassNum);
		for(int i=0;i<m_nClassNum;i++)
		{
			str.Format("%s\\%s%d.mdl",m_paramClassifier.m_txtResultDir,m_paramClassifier.m_strModelFile,i+1);
			theClassifier.m_theSVM.com_param.modelfile=str;
			docNode.m_pResults[i]=theClassifier.m_theSVM.svm_classify(doc);
		}
		free(doc.words);
		return true;
	}
	else
		return false;
}

void CClassifier::Prepare()
{
	CTime startTime;
	CTimeSpan totalTime;

	if(m_pDocs!=NULL)
	{
		m_lDocNum=0;
		free(m_pDocs);
		m_pDocs=NULL;
	}
	if(m_pSimilarityRatio!=NULL)
	{
		m_lDocNum=0;
		delete[] m_pSimilarityRatio;
		m_pSimilarityRatio=NULL;
	}

	m_nClassNum=m_lstTrainCatalogList.GetCataNum();
	m_lDocNum=m_lstTrainCatalogList.GetDocNum();
	if(m_paramClassifier.m_nKNN>m_lDocNum) m_paramClassifier.m_nKNN=m_lDocNum;
	m_pSimilarityRatio=new DocWeight[m_lDocNum];
	m_pDocs=(DocCatalog*)malloc(sizeof(DocCatalog)*m_lDocNum);
	long num=0;
	POSITION pos_cata = m_lstTrainCatalogList.GetFirstPosition();
	while(pos_cata != NULL)  // for each catalog 
	{
		CCatalogNode& catanode = m_lstTrainCatalogList.GetNext(pos_cata);
		short idxCata=catanode.m_idxCata;
		POSITION pos_doc  = catanode.GetFirstPosition();
		while(pos_doc!=NULL)
		{
			CDocNode& docnode=catanode.GetNext(pos_doc);
			m_pDocs[num].pDocNode=&docnode;
			m_pDocs[num].nCataID=idxCata;
			num++;
		}
	}
	CDocNode::AllocTempBuffer(m_lstTrainWordList.GetCount());
}

// Sorts pData by handing the index range [0, nSize] to QuickSort.
// NOTE(review): QuickSort reads psData[iHi], i.e. it treats its upper bound as
// an inclusive index, and nSize is forwarded unchanged. Callers presumably pass
// the last valid index rather than the element count -- confirm at call sites.
void CClassifier::Sort(DocWeight *pData,int nSize)
{
	const int nFirst=0;
	const int nLast=nSize;
	QuickSort(pData,nFirst,nLast);
}

// Recursive quicksort of psData[iLo..iHi] (both bounds inclusive) into
// descending order of dWeight. The pivot is the weight stored at the middle
// index of the range.
void CClassifier::QuickSort(DocWeight *psData, int iLo,int iHi)
{
	DocWeight held;
	int left=iLo;
	int right=iHi;
	const double pivot=psData[(left+right)/2].dWeight;

	do
	{
		// Skip entries already on the correct side of the pivot.
		for(;psData[left].dWeight>pivot;left++) {}
		for(;psData[right].dWeight<pivot;right--) {}
		if(left<=right)
		{
			held=psData[left];
			psData[left]=psData[right];
			psData[right]=held;
			left++;
			right--;
		}
	}while(right>left);

	// Recurse into whichever sub-ranges still contain more than one element.
	if(iLo<right) QuickSort(psData,iLo,right);
	if(iHi>left) QuickSort(psData,left,iHi);
}

// Writes the classification result of every document in cataList to the text
// file strFileName, one line per document.
//
// Returns the number of documents whose predicted category equals the category
// of the catalog they were listed under. That counter is only updated in
// single-class mode with evaluation enabled, so it stays 0 in every other mode.
// Returns -1 if the result file cannot be created.
//
// If m_paramClassifier.m_bCopyFiles is set, each document is also copied into
// the result directory of every category assigned to it.
//
// aryType is only read in multi-class mode with evaluation enabled; it supplies
// the answer string of each document, indexed by docID. A missing array or an
// out-of-range index is treated as "no answers" instead of being dereferenced.
long CClassifier::SaveResults(CCatalogList &cataList, CString strFileName, CStringArray *aryType)
{
	FILE *stream;
	if( (stream = fopen(strFileName, "w+" )) == NULL )
	{
		CMessage::PrintError("无法创建分类结果文件"+strFileName+"!");
		return -1;
	}

	CString str1,str2;
	long lCorrect=0;
	long docID=0;
	int i;
	CArray<short,short> aryResult;
	CArray<short,short> aryAnswer;
	double dThreshold=(double)m_paramClassifier.m_dThreshold/100.0;

	POSITION pos_cata=cataList.GetFirstPosition();
	while(pos_cata!=NULL)
	{
		CCatalogNode& cataNode=cataList.GetNext(pos_cata);
		short id=cataNode.m_idxCata;
		POSITION pos_doc=cataNode.GetFirstPosition();
		while(pos_doc!=NULL)
		{
			CDocNode& docNode=cataNode.GetNext(pos_doc);
			// Documents without an assigned category are not written.
			// NOTE(review): docID is not advanced for skipped documents, so the
			// aryType index below only stays aligned if aryType also omits
			// them -- confirm against the caller.
			if(docNode.m_nCataID<0) continue;
			str1.Empty();
			str2.Empty();
			// Multi-class classification
			if(m_paramClassifier.m_nClassifyType==CClassifierParam::nFT_Multi)
			{
				MultiCategory(docNode,aryResult,dThreshold);
				// Copy the document into the directory of every assigned category
				if(m_paramClassifier.m_bCopyFiles)
				{
					for(i=0;i<aryResult.GetSize();i++)
					{
						m_lstTrainCatalogList.GetCataName(aryResult[i],str1);
						str2=str2+str1+",";
						// The catalog directory is passed directly (as in the
						// single-class branch) instead of via a fixed-size copy.
						CopyFile(docNode.m_strDocName.GetBuffer(0),
							cataNode.m_strDirName.GetBuffer(0),
							m_paramClassifier.m_strResultDir.GetBuffer(0),str1.GetBuffer(0));
					}
					// Replace the trailing comma; nothing to replace if no category matched.
					if(!str2.IsEmpty()) str2.SetAt(str2.GetLength()-1,' ');
				}
				// Evaluation output: "<id> <doc> <answer ids> <result ids>"
				if(m_paramClassifier.m_bEvaluation)
				{
					if(aryType!=NULL && docID<aryType->GetSize())
						m_lstTrainCatalogList.GetCataIDArrayFromString(aryType->GetAt(docID).GetBuffer(0),aryAnswer);
					else
						aryAnswer.RemoveAll();

					// Build the answer string. Start from an empty string so the
					// category names collected for copying do not leak into it.
					str2.Empty();
					for(i=0;i<aryAnswer.GetSize();i++)
					{
						str1.Format("%d",aryAnswer[i]);
						str2+=(str1+",");
					}
					if(!str2.IsEmpty()) str2.SetAt(str2.GetLength()-1,' ');
					// CString objects must not be passed as-is through varargs;
					// hand fprintf the underlying character buffers.
					fprintf(stream,"%ld %s %s",docID,docNode.m_strDocName.GetBuffer(0),str2.GetBuffer(0));

					// Build the result string
					str2.Empty();
					for(i=0;i<aryResult.GetSize();i++)
					{
						str1.Format("%d",aryResult[i]);
						str2+=(str1+",");
					}
					str2=str2.Left(str2.GetLength()-1);
					fprintf(stream,"%s\n",str2.GetBuffer(0));
				}
				else
				{
					// Category names were already collected above when copying files.
					if(str2.IsEmpty())
					{
						for(i=0;i<aryResult.GetSize();i++)
						{
							m_lstTrainCatalogList.GetCataName(aryResult[i],str1);
							str2=str2+str1+",";
						}
						if(!str2.IsEmpty()) str2.SetAt(str2.GetLength()-1,' ');
					}
					fprintf(stream,"%ld\t%s\t\t%s\n",docID,docNode.m_strDocName.GetBuffer(0),str2.GetBuffer(0));
				}
			}
			// Single-class classification
			else
			{
				// Copy the document into the directory of its assigned category
				if(m_paramClassifier.m_bCopyFiles)
				{
					m_lstTrainCatalogList.GetCataName(docNode.m_nCataID,str1);
					CopyFile(docNode.m_strDocName.GetBuffer(0),
						cataNode.m_strDirName.GetBuffer(0),
						m_paramClassifier.m_strResultDir.GetBuffer(0),
						str1.GetBuffer(0));
				}
				// Evaluation output: "<id> <doc> <catalog index> <assigned category>"
				if(m_paramClassifier.m_bEvaluation)
				{
					if(docNode.m_nCataID==id) lCorrect++;
					fprintf(stream,"%ld %s %d %d\n",docID,docNode.m_strDocName.GetBuffer(0),
						cataNode.m_idxCata,docNode.m_nCataID);
				}
				else
				{
					if(str1.IsEmpty()) m_lstTrainCatalogList.GetCataName(docNode.m_nCataID,str1);
					fprintf(stream,"%ld\t%s\t\t%s\n",docID,docNode.m_strDocName.GetBuffer(0),str1.GetBuffer(0));
				}
			}
			docID++;
		}
	}
	fclose(stream);
	return lCorrect;
}

void CClassifier::CopyFile(char *pFileName, char *pSource, char *pTarget, char *pCatalog)
{
	char targetFile[MAX_PATH];
	strcpy(targetFile,pTarget);
	strcat(targetFile,"\\");
	strcat(targetFile,pCatalog);
	if(_chdir(targetFile)<0)
		if(_mkdir(targetFile)<0) return;
	char sourceFile[MAX_PATH];
	strcpy(sourceFile,pSource);
	strcat(sourceFile,"\\");
	strcat(sourceFile,pFileName);
	strcat(targetFile,"\\");
	strcat(targetFile,pFileName);
	::CopyFile(sourceFile,targetFile,false);
}

// Launches the external evaluation tool strPath\multieval.exe, passing it
// classes.txt and results.txt from the classifier's result directory.
// Shows a message box when WinExec reports failure (return value below 32).
//
// The executable path and both arguments are wrapped in double quotes so that
// directories containing spaces are passed as single tokens and the
// executable name cannot be mis-resolved.
void CClassifier::Evaluate(CString strPath)
{
	CString strCmdLine;
	strCmdLine="\""+strPath+"\\multieval.exe\" ";
	strCmdLine=strCmdLine+"\""+theClassifier.m_paramClassifier.m_strResultDir+"\\classes.txt\" ";
	strCmdLine=strCmdLine+"\""+theClassifier.m_paramClassifier.m_strResultDir+"\\results.txt\"";
	if(WinExec(strCmdLine,SW_SHOWNORMAL)<32)
		AfxMessageBox("分类结果评测程序不存在!");
}

// Picks the category with the highest score in docNode.m_pResults (one score
// per class, m_nClassNum entries), stores its zero-based index in
// docNode.m_nCataID and returns it. On ties the lowest index wins.
//
// If there is no results buffer or no class at all, the document is marked as
// unclassified: m_nCataID is set to -1 and -1 is returned.
short CClassifier::SingleCategory(CDocNode &docNode)
{
	const double *pSimRatio=docNode.m_pResults;
	if(pSimRatio==NULL || m_nClassNum<=0)
	{
		docNode.m_nCataID=-1;
		return -1;
	}

	// Linear scan for the maximum score
	short nCataID=0;
	double dMaxNum=pSimRatio[0];
	for(int i=1;i<m_nClassNum;i++)
	{
		if(pSimRatio[i]>dMaxNum)
		{
			dMaxNum=pSimRatio[i];
			nCataID=static_cast<short>(i);
		}
	}
	docNode.m_nCataID=nCataID;
	return nCataID;
}

// Runs SVMClassify on the document and, if that succeeds, returns the category
// chosen by SingleCategory. Returns -1 when SVMClassify reports failure.
short CClassifier::SVMCategory(char *pPath, CDocNode &docNode, bool bFile)
{
	if(!SVMClassify(pPath,docNode,bFile))
		return -1;
	return SingleCategory(docNode);
}

// Classifies a single input with the SVM classifiers and returns its category
// index, or -1 on failure.
//
// When bFile is true, `file` is a full path: it is split at the last backslash
// into a directory (passed on as the path) and a file name (stored in the
// document node). -1 is returned if `file` is NULL, contains no backslash, or
// its directory part does not fit into a MAX_PATH buffer.
// When bFile is false, `file` is passed on unchanged.
short CClassifier::SVMCategory(char *file, bool bFile)
{
	CDocNode docNode;
	if(!bFile)
		return SVMCategory(file,docNode,bFile);

	if(file==NULL) return -1;
	char *fname=strrchr(file,'\\');
	if(fname==NULL) return -1;

	// Length of the directory part; reject it if it would overflow `path`.
	const size_t nDirLen=static_cast<size_t>(fname-file);
	if(nDirLen>=MAX_PATH) return -1;

	docNode.m_strDocName=(fname+1);

	char path[MAX_PATH];
	strncpy(path,file,nDirLen);
	path[nDirLen]=0;
	return SVMCategory(path,docNode,bFile);
}

// Classifies the vector file strFileName with each of the m_nClassNum SVM
// models in turn, stores every document's score for model k in
// m_pResults[k-1], and finally assigns each test document the category chosen
// by SingleCategory. Progress and elapsed time per model are reported through
// CMessage::PrintInfo.
void CClassifier::SVMClassifyVectorFile(CString strFileName)
{
	CTime tmBegin;
	CTimeSpan tmElapsed;
	CString strMsg;
	long lTotalDocs=m_lstTestCatalogList.GetDocNum();
	long lDocIdx=0;
	// Scratch buffer receiving one score per test document from each model
	double *pScores=new double[lTotalDocs];
	POSITION posDoc, posCata;

	m_theSVM.com_param.classifyfile=strFileName;
	int nClassifier=1;
	while(nClassifier<=m_nClassNum)
	{
		memset(pScores,0,sizeof(double)*lTotalDocs);
		tmBegin=CTime::GetCurrentTime();
		strMsg.Format("正在使用第%d个SVM分类器对文档进行分类,请稍候...",nClassifier);
		CMessage::PrintInfo(strMsg);
		strMsg.Format("%s\\%s%d.mdl",m_paramClassifier.m_txtResultDir,m_paramClassifier.m_strModelFile,nClassifier);
		m_theSVM.com_param.modelfile=strMsg;
		m_theSVM.svm_classify(nClassifier,pScores);

		// Store each document's score for the current model in m_pResults[nClassifier-1];
		// documents are visited in catalog order, matching the index into pScores.
		lDocIdx=0;
		for(posCata=m_lstTestCatalogList.GetFirstPosition();posCata!=NULL;)
		{
			CCatalogNode& cata=m_lstTestCatalogList.GetNext(posCata);
			for(posDoc=cata.GetFirstPosition();posDoc!=NULL;)
			{
				CDocNode& doc=cata.GetNext(posDoc);
				doc.AllocResultsBuffer(m_nClassNum);
				doc.m_pResults[nClassifier-1]=pScores[lDocIdx];
				lDocIdx++;
			}
		}

		tmElapsed=CTime::GetCurrentTime()-tmBegin;
		strMsg.Format("第%d个SVM分类器分类结束,耗时",nClassifier);
		CMessage::PrintInfo(strMsg+tmElapsed.Format("%H:%M:%S"));
		nClassifier++;
	}
	delete[] pScores;

	// Assign every document the category with the highest score
	for(posCata=m_lstTestCatalogList.GetFirstPosition();posCata!=NULL;)
	{
		CCatalogNode& cata=m_lstTestCatalogList.GetNext(posCata);
		for(posDoc=cata.GetFirstPosition();posDoc!=NULL;)
		{
			CDocNode& doc=cata.GetNext(posDoc);
			doc.m_nCataID=SingleCategory(doc);
		}
	}
}

// Returns the category of `file` using the classifier selected by
// m_paramClassifier.m_nClassifierType (KNN or SVM); -1 for any other type.
short CClassifier::GetCategory(char *file, bool bFile)
{
	if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN)
		return KNNCategory(file,bFile);
	if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM)
		return SVMCategory(file,bFile);
	return -1;
}

// Returns the category of docNode using the classifier selected by
// m_paramClassifier.m_nClassifierType (KNN or SVM); -1 for any other type.
short CClassifier::GetCategory(char *path, CDocNode &docNode, bool bFile)
{
	if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN)
		return KNNCategory(path,docNode,bFile);
	if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM)
		return SVMCategory(path,docNode,bFile);
	return -1;
}

// Classifies docNode with the classifier selected by
// m_paramClassifier.m_nClassifierType (KNN or SVM) and returns that
// classifier's result; returns false for any other type.
bool CClassifier::Classify(char *path, CDocNode &docNode, bool bFile)
{
	if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN)
		return KNNClassify(path,docNode,bFile);
	if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM)
		return SVMClassify(path,docNode,bFile);
	return false;
}

// Classifies every document in cataList with the classifier selected by
// m_paramClassifier.m_nClassifierType (KNN or SVM) and returns the value
// reported by that classifier. For any other type an error is printed and 0
// is returned.
//
// The list given by the caller is the one classified; previously the
// parameter was ignored and m_lstTestCatalogList was always used.
long CClassifier::Classify(CCatalogList &cataList)
{
	long lUnknown=0;
	if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN)
		lUnknown=KNNClassify(cataList);
	else if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM)
		lUnknown=SVMClassify(cataList);
	else
		CMessage::PrintError("无法确定分类器的类型!");
	return lUnknown;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -