📄 classifier.cpp
字号:
{
//取类列表中的每一个类
CCatalogNode& cataNode = m_lstTrainCatalogList.GetNext(pos_cata);
POSITION pos_doc = cataNode.GetFirstPosition();
while(pos_doc!=NULL)
{
CDocNode& docNode=cataNode.GetNext(pos_doc);
sum=0.0;
for(i=0;i<docNode.m_nAllocLen;i++)
{
docNode.m_sWeightSet[i].s_dWeight*=docNode.m_sWeightSet[i].s_tfi;
sum+=(docNode.m_sWeightSet[i].s_dWeight*docNode.m_sWeightSet[i].s_dWeight);
}
sum=sqrt(sum);
for(i=0;i<docNode.m_nAllocLen;i++)
docNode.m_sWeightSet[i].s_dWeight/=sum;
CMessage::PrintStatusInfo("计算文档"+docNode.m_strDocName+"向量每一维的权重");
}
}
}
// Sorts psData[iLo..iHi] (both bounds inclusive) in place into descending
// order of dWeight, using a recursive quicksort whose pivot is the weight of
// the middle element of the range.
//   psData - array to sort; elements are exchanged by value
//   iLo    - index of the first element of the range
//   iHi    - index of the last element of the range (inclusive)
// A null array or a range holding fewer than two elements is left untouched.
void CClassifier::QuickSort(sSortType *psData, int iLo,int iHi)
{
    // Nothing to order. This also keeps an empty or inverted range
    // (iHi < iLo) from reading elements outside the array.
    if(psData==NULL || iLo>=iHi)
        return;

    int Lo = iLo;
    int Hi = iHi;
    const double Mid = psData[(Lo + Hi)/2].dWeight;
    do
    {
        // Skip elements that are already on the correct side of the pivot:
        // larger weights stay on the left, smaller weights on the right.
        while(psData[Lo].dWeight > Mid) Lo++;
        while(psData[Hi].dWeight < Mid) Hi--;
        if(Lo <= Hi)
        {
            sSortType t = psData[Lo];
            psData[Lo]=psData[Hi];
            psData[Hi]=t;
            Lo++;
            Hi--;
        }
    }while(Hi>Lo);
    // The loop ends as soon as the two indices meet or cross, so the two
    // sub-ranges below may share one element when Lo == Hi.
    if(Hi > iLo) QuickSort(psData, iLo, Hi);
    if(Lo < iHi) QuickSort(psData, Lo, iHi);
}
// Sorting entry point: runs QuickSort over psData starting at index 0 and
// using nSize as the upper index of the range.
// NOTE(review): nSize is forwarded unchanged as QuickSort's upper bound, which
// QuickSort dereferences (psData[nSize]) - confirm that callers pass the last
// valid index rather than the element count.
void CClassifier::Sort(sSortType *psData,int nSize)
{
    const int nFirst = 0;
    const int nLast = nSize;
    QuickSort(psData, nFirst, nLast);
}
// Give m_lstWordList & m_lstTrainCatalogList
// Output the present vector of each document;
// bFlag=false 层次分类的时候使用
void CClassifier::GenModel()
{
CDocNode::AllocTempBuffer(m_lstTrainWordList.GetCount());
POSITION pos_cata = m_lstTrainCatalogList.GetFirstPosition();
while(pos_cata != NULL) // for each catalog
{
//取类列表中的每一个类
CCatalogNode& cataNode = m_lstTrainCatalogList.GetNext(pos_cata);
POSITION pos_doc = cataNode.GetFirstPosition();
while(pos_doc!=NULL)
{
CDocNode& docNode=cataNode.GetNext(pos_doc);
if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
docNode.ScanChineseWithDict(cataNode.m_strDirName.GetBuffer(0),m_lstTrainWordList);
else
docNode.ScanEnglishWithDict(cataNode.m_strDirName.GetBuffer(0),m_lstTrainWordList,m_paramClassifier.m_bStem);
docNode.GenDocVector();
CMessage::PrintStatusInfo("生成文档"+docNode.m_strDocName+"的文档向量");
}
}
CDocNode::DeallocTempBuffer();
}
// Generates the original dictionary (the largest one) from the training files.
//
// Steps:
//   1. reset m_lstWordList and initialise the global word segmenter g_wordSeg;
//   2. build the training catalog list from m_paramClassifier.m_txtTrainDir;
//   3. scan every document of every catalog into m_lstWordList, adding the
//      count returned by the scan to the catalog's m_lTotalWordNum;
//   4. release the word segmenter.
//
// Returns false when the word segmenter cannot be initialised or when BuildLib
// reports no training documents; returns true otherwise. A document whose scan
// returns 0 ("no content") or a negative value ("cannot be opened") is
// reported through CMessage::PrintError and skipped.
bool CClassifier::GenDic()
{
    m_lstWordList.InitWordList();

    CTime startTime=CTime::GetCurrentTime();
    CTimeSpan totalTime;
    CMessage::PrintInfo(_T("分词程序初始化,请稍候..."));
    if(!g_wordSeg.InitWorgSegment(theApp.m_strPath.GetBuffer(0),m_paramClassifier.m_nLanguageType))
    {
        CMessage::PrintError(_T("分词程序初始化失败!"));
        return false;
    }
    if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
        g_wordSeg.SetSegSetting(CWordSegment::uPlace);
    totalTime=CTime::GetCurrentTime()-startTime;
    CMessage::PrintInfo(_T("分词程序初始化结束,耗时")+totalTime.Format("%H:%M:%S"));

    startTime=CTime::GetCurrentTime();
    CMessage::PrintInfo(_T("开始扫描训练文档,请稍候..."));
    if(m_lstTrainCatalogList.BuildLib(m_paramClassifier.m_txtTrainDir)<=0)
    {
        // The segmenter was initialised above; release it on this failure
        // path too, exactly as the success path does before returning.
        g_wordSeg.FreeWordSegment();
        CMessage::PrintError("训练文档的总数为0!");
        return false;
    }

    const int nCataNum=m_lstTrainCatalogList.GetCataNum();
    POSITION pos_cata=m_lstTrainCatalogList.GetFirstPosition();
    while(pos_cata!=NULL)
    {
        CCatalogNode& catalognode=m_lstTrainCatalogList.GetNext(pos_cata);
        POSITION pos_doc=catalognode.GetFirstPosition();
        while(pos_doc!=NULL)
        {
            CDocNode& docnode=catalognode.GetNext(pos_doc);
            CMessage::PrintStatusInfo(_T("扫描文档")+docnode.m_strDocName);

            // Scan with the scanner that matches the configured language.
            int nCount=0;
            if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
                nCount=docnode.ScanChinese(catalognode.m_strDirName.GetBuffer(0),
                                           m_lstWordList,nCataNum,catalognode.m_idxCata);
            else
                nCount=docnode.ScanEnglish(catalognode.m_strDirName.GetBuffer(0),
                                           m_lstWordList,nCataNum,catalognode.m_idxCata,
                                           m_paramClassifier.m_bStem);

            if(nCount==0)
            {
                CMessage::PrintError("文件"+catalognode.m_strDirName+"\\"+docnode.m_strDocName+"无内容!");
                continue;
            }
            if(nCount<0)
            {
                CMessage::PrintError("文件"+catalognode.m_strDirName+"\\"+docnode.m_strDocName+"无法打开!");
                continue;
            }
            catalognode.m_lTotalWordNum+=nCount;// information collection point
        }
    }

    g_wordSeg.FreeWordSegment();
    totalTime=CTime::GetCurrentTime()-startTime;
    CMessage::PrintInfo(_T("扫描训练文档结束,耗时")+totalTime.Format("%H:%M:%S"));
    return true;
}
void CClassifier::InitTrain()
{
m_lstTrainWordList.InitWordList();
m_lstTrainCatalogList.InitCatalogList();
m_lstWordList.InitWordList();
}
// Writes the classification model index file strFileName together with the
// data files it refers to.
//   strFileName - path of the model index file to create
//   nType       - kind of model: 0 means KNN classifier, 1 means SVM
//                 classifier (every non-zero value takes the SVM path)
// Data files are dumped into m_paramClassifier.m_txtResultDir: features.dat,
// features.txt, train.dat, train.txt, params.dat and, for the SVM path,
// svmparams.dat. The index file receives dwModelFileID followed by the names
// of the .dat files.
// Returns false when strFileName cannot be created, true otherwise.
bool CClassifier::WriteModel(CString strFileName, int nType)
{
    CFile fOut;
    const bool bCreated = fOut.Open(strFileName,CFile::modeCreate | CFile::modeWrite) != 0;
    if(!bCreated)
    {
        CMessage::PrintError("无法创建文件"+strFileName+"!");
        return false;
    }
    CArchive ar(&fOut,CArchive::store);

    const bool bSVM = (nType!=0);
    const CString& strDir = m_paramClassifier.m_txtResultDir;

    // Data files shared by both model kinds; only the train.dat dump differs
    // (the SVM path passes an extra argument of 1).
    m_lstTrainWordList.DumpToFile(strDir+"\\features.dat");
    m_lstTrainWordList.DumpWordList(strDir+"\\features.txt");
    if(bSVM)
        m_lstTrainCatalogList.DumpToFile(strDir+"\\train.dat",1);
    else
        m_lstTrainCatalogList.DumpToFile(strDir+"\\train.dat");
    m_lstTrainCatalogList.DumpDocList(strDir+"\\train.txt");
    m_paramClassifier.DumpToFile(strDir+"\\params.dat");

    if(bSVM)
    {
        // Fill in and dump the SVM-specific parameters.
        m_theSVM.com_param.classifier_num=m_lstTrainCatalogList.GetCataNum();
        m_theSVM.com_param.trainfile="train.txt";
        m_theSVM.com_param.resultpath=strDir;
        m_theSVM.com_param.DumpToFile(strDir+"\\svmparams.dat");
    }

    // Index file: format identifier, then the data file names.
    ar<<dwModelFileID;
    ar<<CString("params.dat");
    ar<<CString("features.dat");
    ar<<CString("train.dat");
    if(bSVM)
        ar<<CString("svmparams.dat");

    ar.Close();
    fOut.Close();
    return true;
}
// Opens the classification model index file strFileName and loads the data
// files it names from the directory that contains the index file.
//
// The index file is expected to hold dwModelFileID followed by the names of
// the parameter file, the feature (word) list file and the training catalog
// list file; when m_paramClassifier.m_nClassifierType is non-zero a fourth
// name, the SVM parameter file, is read as well.
//
// On success m_paramClassifier.m_txtResultDir is set to the model directory,
// Prepare() is called, the elapsed time and the parameter string are printed,
// and true is returned. On any failure an error is printed through
// CMessage::PrintError and false is returned.
bool CClassifier::OpenModel(CString strFileName)
{
    CFile fIn;
    if(!fIn.Open(strFileName,CFile::modeRead))
    {
        CMessage::PrintError("无法打开文件"+strFileName+"!") ;
        return false;
    }
    CTime startTime=CTime::GetCurrentTime();
    CMessage::PrintInfo(_T("正在打开分类模型文件,请稍候..."));
    CArchive ar(&fIn,CArchive::load);
    CString str,strPath,strError;
    DWORD dwFileID;

    // Directory of the model file. ReverseFind returns -1 when the name has
    // no backslash; fall back to the current directory in that case instead
    // of an empty prefix, which would turn "\\params.dat" into a path at the
    // root of the drive.
    const int nSep=strFileName.ReverseFind('\\');
    if(nSep>=0)
        strPath=strFileName.Left(nSep);
    else
        strPath=_T(".");

    // Read the file format identifier.
    ar>>dwFileID;
    if(dwFileID!=dwModelFileID)
        strError="分类模型文件的格式不正确!";

    // Training parameters.
    if(strError.IsEmpty())
    {
        ar>>str;
        if(m_paramClassifier.GetFromFile(strPath+"\\"+str))
            m_paramClassifier.m_txtResultDir=strPath;
        else
            strError=_T("无法打开训练参数文件")+str+_T("!");
    }

    // Feature (word) list.
    if(strError.IsEmpty())
    {
        ar>>str;
        m_lstTrainWordList.InitWordList();
        if(!m_lstTrainWordList.GetFromFile(strPath+"\\"+str))
            strError=_T("无法打开特征类表文件")+str+_T("!");
    }

    // Training catalog list. Per the original note, the SVM classifier does
    // not actually need m_lstTrainCatalogList; it is loaded there only so the
    // left view can display some statistics.
    if(strError.IsEmpty())
    {
        ar>>str;
        m_lstTrainCatalogList.InitCatalogList();
        if(!m_lstTrainCatalogList.GetFromFile(strPath+"\\"+str))
            strError=_T("无法打开训练文档列表文件")+str+_T("!");
    }

    // SVM parameters, only for a non-zero classifier type.
    if(strError.IsEmpty() && m_paramClassifier.m_nClassifierType!=0)
    {
        ar>>str;
        if(m_theSVM.com_param.GetFromFile(strPath+"\\"+str))
        {
            m_theSVM.com_param.trainfile=strPath+"\\train.txt";
            m_theSVM.com_param.resultpath=strPath;
        }
        else
            strError=_T("无法打开SVM训练参数文件")+str+_T("!");
    }

    // Close the archive and the file on every path, success or failure.
    ar.Close();
    fIn.Close();

    if(!strError.IsEmpty())
    {
        CMessage::PrintError(strError);
        return false;
    }

    Prepare();
    CTimeSpan totalTime=CTime::GetCurrentTime()-startTime;
    CMessage::PrintInfo(_T("分类模型文件已经打开,耗时")+totalTime.Format("%H:%M:%S")+"\r\n");
    str.Empty();
    m_paramClassifier.GetParamString(str);
    CMessage::PrintInfo(str);
    return true;
}
// Classifies the documents found under m_paramClassifier.m_strTestDir.
//
// Steps:
//   1. dump the training catalog list to classes.txt in m_strResultDir;
//   2. build the test catalog list - in evaluation mode through BuildLib and
//      BuildCatalogID against the training catalogs, otherwise as a single
//      catalog (index -1) filled by scanning the test directory;
//   3. classify the test catalog list and save the results to results.txt;
//   4. print timing, the number of documents that could not be classified
//      (when there are any) and the total; the accuracy is added only in
//      evaluation mode when both the classified total and the number of
//      correct results are positive.
//
// Returns false when the test set contains a catalog that BuildCatalogID
// rejects or when there are no test documents; returns true otherwise.
bool CClassifier::Classify()
{
    m_lstTrainCatalogList.DumpCataList(m_paramClassifier.m_strResultDir+"\\classes.txt");

    CTime startTime=CTime::GetCurrentTime();
    CTimeSpan totalTime;
    CMessage::PrintInfo(_T("正在扫描测试文档,请稍候..."));
    if(m_paramClassifier.m_bEvaluation)
    {
        // Per the original note, BuildLib already clears the test catalog
        // list, so it is not initialised separately here.
        m_lstTestCatalogList.BuildLib(m_paramClassifier.m_strTestDir);
        if(!m_lstTestCatalogList.BuildCatalogID(m_lstTrainCatalogList))
        {
            CMessage::PrintError("测试文件中包含有无法识别的类别!");
            return false;
        }
    }
    else
    {
        // No evaluation: put every test document into one catalog whose
        // catalog index is -1.
        m_lstTestCatalogList.InitCatalogList();
        CCatalogNode catalognode;
        catalognode.m_strDirName=m_paramClassifier.m_strTestDir;
        catalognode.m_strCatalogName="测试文档";
        catalognode.m_idxCata=-1;
        POSITION posTemp=m_lstTestCatalogList.AddCata(catalognode);
        CCatalogNode& cataTemp=m_lstTestCatalogList.GetAt(posTemp);
        cataTemp.SetStartDocID(0);
        cataTemp.ScanDirectory(m_paramClassifier.m_strTestDir);
    }
    if(m_lstTestCatalogList.GetDocNum()<=0)
    {
        CMessage::PrintError("测试文件总数为0!\r\n如果不需要对分类结果进行评价时,分类文档必须在\"分类文档目录\"下,而不是它的子目录下!");
        return false;
    }
    totalTime=CTime::GetCurrentTime()-startTime;
    CMessage::PrintInfo(_T("扫描测试文档结束,耗时")+totalTime.Format("%H:%M:%S"));

    startTime=CTime::GetCurrentTime();
    CMessage::PrintInfo(_T("正在对测试文档进行分类,请稍候..."));
    const long lUnknown=Classify(m_lstTestCatalogList);
    const long lCorrect=SaveResults(m_lstTestCatalogList,m_paramClassifier.m_strResultDir+"\\results.txt");
    // Held in a long so that it matches the %ld conversions used below.
    const long lDocNum=m_lstTestCatalogList.GetDocNum();
    const long lTotalNum=lDocNum-lUnknown;
    totalTime=CTime::GetCurrentTime()-startTime;
    CMessage::PrintInfo(_T("测试文档分类结束,耗时")+totalTime.Format("%H:%M:%S"));

    CString str;
    if (lUnknown>0)
    {
        // long argument needs %ld; the colon now precedes the number.
        str.Format("无法分类的文档数:%ld",lUnknown);
        CMessage::PrintInfo(str);
    }
    if(m_paramClassifier.m_bEvaluation&&lTotalNum>0&&lCorrect>0)
        str.Format("测试文档总数:%ld,准确率:%f",lDocNum,static_cast<float>(lCorrect)/static_cast<float>(lTotalNum));
    else
        str.Format("测试文档总数:%ld",lDocNum);
    CMessage::PrintInfo(str);
    return true;
}
//对Smart格式的文档进行分类
bool CClassifier::ClassifySmart()
{
m_lstTrainCatalogList.DumpCataList(m_paramClassifier.m_strResultDir+"\\classes.txt");
m_lstTestCatalogList.InitCatalogList();
CCatalogNode catalognode;
catalognode.m_strDirName=m_paramClassifier.m_strTestDir;
catalognode.m_strCatalogName="测试文档";
catalognode.m_idxCata=-1;
POSITION posTemp=m_lstTestCatalogList.AddCata(catalognode);
CCatalogNode& cataTemp=m_lstTestCatalogList.GetAt(posTemp);
FILE *stream1,*stream2;
if( (stream1 = fopen( m_paramClassifier.m_strTestDir, "r" )) == NULL )
{
CMessage::PrintError("无法打开文件"+m_paramClassifier.m_strTestDir+"!");
return false;
}
//如果是SVM分类器,则需要先将所有测试文档转换为向量,保存到文件test.dat
if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM)
{
m_theSVM.com_param.classifyfile=m_paramClassifier.m_strResultDir+"\\test.dat";
if((stream2=fopen(m_theSVM.com_param.classifyfile,"w"))==NULL)
{
CMessage::PrintError("无法创建测试文档向量文件"+m_theSVM.com_param.classifyfile+"!");
return false;
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -