📄 classifier.cpp
字号:
nCount=docNode.ScanChineseStringWithDict(pPath,m_lstTrainWordList);
else
nCount=docNode.ScanEnglishStringWithDict(pPath,m_lstTrainWordList,
m_paramClassifier.m_bStem);
}
if((m_lDocNum>0)&&(nCount>0))
{
DOC doc;
CString str;
docNode.GenDocVector(doc);
docNode.AllocResultsBuffer(m_nClassNum);
for(int i=0;i<m_nClassNum;i++)
{
str.Format("%s\\%s%d.mdl",m_paramClassifier.m_txtResultDir,m_paramClassifier.m_strModelFile,i+1);
theClassifier.m_theSVM.com_param.modelfile=str;
docNode.m_pResults[i]=theClassifier.m_theSVM.svm_classify(doc);
}
free(doc.words);
return true;
}
else
return false;
}
void CClassifier::Prepare()
{
CTime startTime;
CTimeSpan totalTime;
if(m_pDocs!=NULL)
{
m_lDocNum=0;
free(m_pDocs);
m_pDocs=NULL;
}
if(m_pSimilarityRatio!=NULL)
{
m_lDocNum=0;
delete[] m_pSimilarityRatio;
m_pSimilarityRatio=NULL;
}
m_nClassNum=m_lstTrainCatalogList.GetCataNum();
m_lDocNum=m_lstTrainCatalogList.GetDocNum();
if(m_paramClassifier.m_nKNN>m_lDocNum) m_paramClassifier.m_nKNN=m_lDocNum;
m_pSimilarityRatio=new DocWeight[m_lDocNum];
m_pDocs=(DocCatalog*)malloc(sizeof(DocCatalog)*m_lDocNum);
long num=0;
POSITION pos_cata = m_lstTrainCatalogList.GetFirstPosition();
while(pos_cata != NULL) // for each catalog
{
CCatalogNode& catanode = m_lstTrainCatalogList.GetNext(pos_cata);
short idxCata=catanode.m_idxCata;
POSITION pos_doc = catanode.GetFirstPosition();
while(pos_doc!=NULL)
{
CDocNode& docnode=catanode.GetNext(pos_doc);
m_pDocs[num].pDocNode=&docnode;
m_pDocs[num].nCataID=idxCata;
num++;
}
}
CDocNode::AllocTempBuffer(m_lstTrainWordList.GetCount());
}
// Sorts pData in place by descending dWeight by delegating to QuickSort.
//
// pData : array of DocWeight entries to reorder.
// nSize : forwarded unchanged as QuickSort's upper index. QuickSort reads
//         psData[iHi], so the element at index nSize is part of the sorted
//         range.
// NOTE(review): this makes nSize the index of the last element rather than an
// element count - confirm that callers pass (count-1).
void CClassifier::Sort(DocWeight *pData,int nSize)
{
QuickSort(pData,0,nSize);
}
// Recursive quicksort that orders psData[iLo..iHi] by descending dWeight.
//
// psData : array to sort in place.
// iLo    : index of the first element of the range (inclusive).
// iHi    : index of the last element of the range (inclusive).
//
// The pivot is the weight of the middle element. Entries with equal weight
// may end up in a different relative order than they started in.
void CClassifier::QuickSort(DocWeight *psData, int iLo,int iHi)
{
    // An empty, inverted or single-element range is already sorted. Without
    // this check an inverted range (e.g. iHi==-1 for an empty array) would
    // read the pivot and scan outside the array.
    if(psData==NULL||iLo>=iHi) return;

    int Lo=iLo;
    int Hi=iHi;
    // iLo+(iHi-iLo)/2 yields the same index as (iLo+iHi)/2 for valid ranges
    // but cannot overflow.
    const double Mid=psData[iLo+(iHi-iLo)/2].dWeight;
    do
    {
        // Skip entries that are already on the correct side of the pivot.
        while(psData[Lo].dWeight>Mid) Lo++;
        while(psData[Hi].dWeight<Mid) Hi--;
        if(Lo<=Hi)
        {
            const DocWeight t=psData[Lo];
            psData[Lo]=psData[Hi];
            psData[Hi]=t;
            Lo++;
            Hi--;
        }
    }while(Hi>Lo);

    // Sort the two partitions; each is strictly smaller than [iLo..iHi].
    if(Hi>iLo) QuickSort(psData,iLo,Hi);
    if(Lo<iHi) QuickSort(psData,Lo,iHi);
}
//将分类结果保存到文件strFileName中,返回正确分类的文档总数
//如果分类参数中要求拷贝文件到结果类别目录,则执行拷贝操作
//参数typeArray只有在多类分类,且需要进行评价的时候才会用到
long CClassifier::SaveResults(CCatalogList &cataList, CString strFileName, CStringArray *aryType)
{
FILE *stream;
if( (stream = fopen(strFileName, "w+" )) == NULL )
{
CMessage::PrintError("无法创建分类结果文件"+strFileName+"!");
return -1;
}
CString str1,str2;
long lCorrect=0;
long docID=0;
int i;
char path[MAX_PATH];
CArray<short,short> aryResult;
CArray<short,short> aryAnswer;
double dThreshold=(double)m_paramClassifier.m_dThreshold/100.0;
POSITION pos_cata=cataList.GetFirstPosition();
while(pos_cata!=NULL)
{
CCatalogNode& cataNode=cataList.GetNext(pos_cata);
short id=cataNode.m_idxCata;
strcpy(path,cataNode.m_strDirName.GetBuffer(0));
POSITION pos_doc=cataNode.GetFirstPosition();
while(pos_doc!=NULL)
{
CDocNode& docNode=cataNode.GetNext(pos_doc);
if(docNode.m_nCataID<0) continue;
str1.Empty();
str2.Empty();
//如果是多类分类
if(m_paramClassifier.m_nClassifyType==CClassifierParam::nFT_Multi)
{
MultiCategory(docNode,aryResult,dThreshold);
//如果需要将分类结果拷贝到分类结果目录
if(m_paramClassifier.m_bCopyFiles)
{
for(i=0;i<aryResult.GetSize();i++)
{
m_lstTrainCatalogList.GetCataName(aryResult[i],str1);
str2=str2+str1+",";
if(m_paramClassifier.m_bCopyFiles)
CopyFile(docNode.m_strDocName.GetBuffer(0),path,
m_paramClassifier.m_strResultDir.GetBuffer(0),str1.GetBuffer(0));
}
str2.SetAt(str2.GetLength()-1,' ');
}
//如果需要对分类结果进行评价
if(m_paramClassifier.m_bEvaluation)
{
m_lstTrainCatalogList.GetCataIDArrayFromString(aryType->GetAt(docID).GetBuffer(0),aryAnswer);
//得到答案字符串
for(i=0;i<aryAnswer.GetSize();i++)
{
str1.Format("%d",aryAnswer[i]);
str2+=(str1+",");
}
str2.SetAt(str2.GetLength()-1,' ');
fprintf(stream,"%d %s %s",docID,docNode.m_strDocName,str2);
//得到分类结果字符串
str2.Empty();
for(i=0;i<aryResult.GetSize();i++)
{
str1.Format("%d",aryResult[i]);
str2+=(str1+",");
}
str2=str2.Left(str2.GetLength()-1);
fprintf(stream,"%s\n",str2);
}
else
{
if(str2.IsEmpty())
{
for(i=0;i<aryResult.GetSize();i++)
{
m_lstTrainCatalogList.GetCataName(aryResult[i],str1);
str2=str2+str1+",";
}
str2.SetAt(str2.GetLength()-1,' ');
}
fprintf(stream,"%d\t%s\t\t%s\n",docID,docNode.m_strDocName,str2);
}
}
//如果是单类分类
else
{
//如果需要将分类结果拷贝到分类结果目录
if(m_paramClassifier.m_bCopyFiles)
{
m_lstTrainCatalogList.GetCataName(docNode.m_nCataID,str1);
CopyFile(docNode.m_strDocName.GetBuffer(0),
cataNode.m_strDirName.GetBuffer(0),
m_paramClassifier.m_strResultDir.GetBuffer(0),
str1.GetBuffer(0));
}
//如果需要对分类结果进行评价
if(m_paramClassifier.m_bEvaluation)
{
if(docNode.m_nCataID==id) lCorrect++;
fprintf(stream,"%d %s %d %d\n",docID,docNode.m_strDocName,
cataNode.m_idxCata,docNode.m_nCataID);
}
else
{
if(str1.IsEmpty()) m_lstTrainCatalogList.GetCataName(docNode.m_nCataID,str1);
fprintf(stream,"%d\t%s\t\t%s\n",docID,docNode.m_strDocName,str1);
}
}
docID++;
}
}
fclose(stream);
return lCorrect;
}
void CClassifier::CopyFile(char *pFileName, char *pSource, char *pTarget, char *pCatalog)
{
char targetFile[MAX_PATH];
strcpy(targetFile,pTarget);
strcat(targetFile,"\\");
strcat(targetFile,pCatalog);
if(_chdir(targetFile)<0)
if(_mkdir(targetFile)<0) return;
char sourceFile[MAX_PATH];
strcpy(sourceFile,pSource);
strcat(sourceFile,"\\");
strcat(sourceFile,pFileName);
strcat(targetFile,"\\");
strcat(targetFile,pFileName);
::CopyFile(sourceFile,targetFile,false);
}
// Launches the external evaluation program on the saved classification
// results.
//
// strPath : directory that contains multieval.exe.
//
// The program is started as
//   "<strPath>\multieval.exe" "<resultDir>\classes.txt" "<resultDir>\results.txt"
// where <resultDir> is theClassifier.m_paramClassifier.m_strResultDir. Every
// path is quoted so that directories containing blanks are passed as a single
// argument. A message box is shown when the program cannot be started.
void CClassifier::Evaluate(CString strPath)
{
    const CString &strResultDir=theClassifier.m_paramClassifier.m_strResultDir;

    CString strCmdLine="\""+strPath+"\\multieval.exe\" ";
    strCmdLine+="\""+strResultDir+"\\classes.txt\" ";
    strCmdLine+="\""+strResultDir+"\\results.txt\"";

    // WinExec reports failure with a return value of 31 or less.
    if(WinExec(strCmdLine,SW_SHOWNORMAL)<32)
        AfxMessageBox("分类结果评测程序不存在!");
}
// Picks the single best category for a document from its per-category scores.
//
// docNode : document whose m_pResults array holds one score per category;
//           its m_nCataID is overwritten with the chosen index.
//
// Returns the zero-based index of the largest score among the first
// m_nClassNum entries; on ties the lowest index wins. Returns -1 (and stores
// -1 in m_nCataID) when there is no score buffer or no category to choose
// from, instead of reading from a missing buffer.
short CClassifier::SingleCategory(CDocNode &docNode)
{
    const double *pSimRatio=docNode.m_pResults;
    if(pSimRatio==NULL||m_nClassNum<=0)
    {
        docNode.m_nCataID=-1;
        return -1;
    }

    short nCataID=0;
    double dMaxNum=pSimRatio[0];
    for(int i=1;i<m_nClassNum;i++)
    {
        // Strict '>' keeps the first of several equal maxima.
        if(pSimRatio[i]>dMaxNum)
        {
            dMaxNum=pSimRatio[i];
            nCataID=static_cast<short>(i);
        }
    }
    docNode.m_nCataID=nCataID;
    return nCataID;
}
// Scores docNode with SVMClassify and, if that succeeds, reduces the scores
// to one category via SingleCategory.
//
// pPath, docNode, bFile : forwarded unchanged to SVMClassify.
//
// Returns the category index chosen by SingleCategory, or -1 when
// SVMClassify reports failure.
short CClassifier::SVMCategory(char *pPath, CDocNode &docNode, bool bFile)
{
    if(!SVMClassify(pPath,docNode,bFile))
        return -1;
    return SingleCategory(docNode);
}
short CClassifier::SVMCategory(char *file, bool bFile)
{
CDocNode docNode;
short id=-1;
if(bFile)
{
char *fname=strrchr(file,'\\');
if(fname==NULL) return -1;
docNode.m_strDocName=(fname+1);
char path[MAX_PATH];
strncpy(path,file,fname-file);
path[fname-file]=0;
id=SVMCategory(path,docNode,bFile);
}
else
id=SVMCategory(file,docNode,bFile);
return id;
}
// Runs each of the m_nClassNum SVM models over the vector file strFileName
// and assigns every document of m_lstTestCatalogList the category with the
// highest score.
//
// For classifier k (1-based) the model file "<m_txtResultDir>\<m_strModelFile>k.mdl"
// is selected and the scores it yields are stored in m_pResults[k-1] of the
// documents, taken in catalog-then-document order. Progress and the time
// spent per classifier are reported through CMessage::PrintInfo.
void CClassifier::SVMClassifyVectorFile(CString strFileName)
{
    // Scratch array receiving one score per test document from each classifier
    const long lTotalDocs=m_lstTestCatalogList.GetDocNum();
    double *pScores=new double[lTotalDocs];
    CString strMsg;

    m_theSVM.com_param.classifyfile=strFileName;
    for(int nModel=1;nModel<=m_nClassNum;++nModel)
    {
        memset(pScores,0,sizeof(double)*lTotalDocs);
        const CTime tmBegin=CTime::GetCurrentTime();
        strMsg.Format("正在使用第%d个SVM分类器对文档进行分类,请稍候...",nModel);
        CMessage::PrintInfo(strMsg);
        strMsg.Format("%s\\%s%d.mdl",m_paramClassifier.m_txtResultDir,m_paramClassifier.m_strModelFile,nModel);
        m_theSVM.com_param.modelfile=strMsg;
        m_theSVM.svm_classify(nModel,pScores);

        // Hand the scores of this classifier to the documents, in list order.
        long lDocIdx=0;
        for(POSITION posCata=m_lstTestCatalogList.GetFirstPosition();posCata!=NULL;)
        {
            CCatalogNode& cataNode=m_lstTestCatalogList.GetNext(posCata);
            for(POSITION posDoc=cataNode.GetFirstPosition();posDoc!=NULL;)
            {
                CDocNode& doc=cataNode.GetNext(posDoc);
                doc.AllocResultsBuffer(m_nClassNum);
                doc.m_pResults[nModel-1]=pScores[lDocIdx];
                ++lDocIdx;
            }
        }

        const CTimeSpan tsElapsed=CTime::GetCurrentTime()-tmBegin;
        strMsg.Format("第%d个SVM分类器分类结束,耗时",nModel);
        CMessage::PrintInfo(strMsg+tsElapsed.Format("%H:%M:%S"));
    }
    delete[] pScores;

    // Choose, for every document, the category with the largest score.
    for(POSITION posCata=m_lstTestCatalogList.GetFirstPosition();posCata!=NULL;)
    {
        CCatalogNode& cataNode=m_lstTestCatalogList.GetNext(posCata);
        for(POSITION posDoc=cataNode.GetFirstPosition();posDoc!=NULL;)
        {
            CDocNode& doc=cataNode.GetNext(posDoc);
            doc.m_nCataID=SingleCategory(doc);
        }
    }
}
// Returns the category of the given input using the classifier selected in
// m_paramClassifier.m_nClassifierType.
//
// file, bFile : forwarded unchanged to KNNCategory or SVMCategory.
//
// Returns the value of the selected classifier, or -1 when the configured
// type is neither nCT_KNN nor nCT_SVM.
short CClassifier::GetCategory(char *file, bool bFile)
{
    if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN)
        return KNNCategory(file,bFile);
    if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM)
        return SVMCategory(file,bFile);
    return -1;
}
// Returns the category of docNode using the classifier selected in
// m_paramClassifier.m_nClassifierType.
//
// path, docNode, bFile : forwarded unchanged to KNNCategory or SVMCategory.
//
// Returns the value of the selected classifier, or -1 when the configured
// type is neither nCT_KNN nor nCT_SVM.
short CClassifier::GetCategory(char *path, CDocNode &docNode, bool bFile)
{
    const bool bKNN=(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN);
    if(bKNN)
        return KNNCategory(path,docNode,bFile);

    const bool bSVM=(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM);
    return bSVM ? SVMCategory(path,docNode,bFile) : -1;
}
// Scores docNode with the classifier selected in
// m_paramClassifier.m_nClassifierType.
//
// path, docNode, bFile : forwarded unchanged to KNNClassify or SVMClassify.
//
// Returns the result of the selected classifier, or false when the configured
// type is neither nCT_KNN nor nCT_SVM.
bool CClassifier::Classify(char *path, CDocNode &docNode, bool bFile)
{
    if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN)
        return KNNClassify(path,docNode,bFile);
    if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM)
        return SVMClassify(path,docNode,bFile);
    return false;
}
// Classifies the documents of m_lstTestCatalogList with the classifier
// selected in m_paramClassifier.m_nClassifierType.
//
// cataList : not used by this function.
// NOTE(review): the test catalog list member is always classified, whatever
// list is passed in - confirm whether cataList was meant to be used instead.
//
// Returns the value reported by KNNClassify / SVMClassify. For an unknown
// classifier type an error is printed and 0 is returned.
long CClassifier::Classify(CCatalogList &cataList)
{
    if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN)
        return KNNClassify(m_lstTestCatalogList);
    if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM)
        return SVMClassify(m_lstTestCatalogList);

    CMessage::PrintError("无法确定分类器的类型!");
    return 0;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -