// classifier.h
// (Code-viewer page residue: "Font size:" label; not part of the original header.)
// Classifier.h: interface for the CClassifier class.
//
//////////////////////////////////////////////////////////////////////
#if !defined(AFX_CLASSIFIER_H__FA4DB8D8_AC36_44A8_884B_0D715575B7A1__INCLUDED_)
#define AFX_CLASSIFIER_H__FA4DB8D8_AC36_44A8_884B_0D715575B7A1__INCLUDED_
#if _MSC_VER > 1000
#pragma once
#endif // _MSC_VER > 1000
#include "CatalogList.h"
#include "ClassifierParam.h"
#include "WordList.h"
#include "Compute_Param.h"
#include "Compute_Prompt.h"
#include "Compute_Result.h"
#include "svm.h"
// Sort record associating a word with a weight and its word-list node.
// CClassifier::Sort, QuickSort and GenSortBuf operate on arrays of this type.
struct sSortType
{
// Fixed 40-byte buffer for the word text (presumably NUL-terminated - confirm
// against the code that fills it).
char word[40];
// Weight associated with the word (presumably the sort key - confirm in Sort()).
double dWeight;
// Pointer to the corresponding CWordNode; ownership is not visible from this header.
CWordNode *pclsWordNode;
};
// Pairs a document identifier with a weight.
// CClassifier::m_pSimilarityRatio points at records of this type, where the
// weight is the similarity between the current test document and a training
// document; CClassifier::Sort/QuickSort also operate on arrays of this type.
struct DocWeight
{
// Weight (similarity) value for the document.
double dWeight;
// Identifier of the document (presumably an index into the training set - confirm).
long lDocID;
};
// Pairs a document node with a category identifier.
// CClassifier::m_pDocs points at records of this type to reach the training
// documents quickly.
struct DocCatalog
{
// Pointer to the document node; ownership is not visible from this header.
CDocNode * pDocNode;
// Identifier of the category (catalog) associated with the document.
short nCataID;
};
// Text classifier interface: declares the training entry points (Train,
// TrainSVM, GenDic, WriteModel, ...) and the classification entry points
// (OpenModel, Classify, KNN* and SVM* variants, ...), together with the
// word lists, catalog lists, parameters and SVM object they work on.
//
// NOTE(review): m_pDocs and m_pSimilarityRatio are raw pointers and no copy
// constructor / copy assignment is declared here. If the class owns those
// buffers, copying an instance would be unsafe - confirm in the implementation.
class CClassifier
{
public:
CClassifier();
virtual ~CClassifier();
public: // Public member methods needed during training
// The original comment calls the flag "bGenDic" (the declaration below names
// it bFlag): passing false is meant for hierarchical classification.
// nType selects the kind of classification model.
bool Train(int nType=0,bool bFlag=true);
void TrainSVM();
bool GenDic();
void InitTrain();
// Writes the model to strFileName; nType presumably matches Train()'s nType - confirm.
bool WriteModel(CString strFileName, int nType=0);
void Evaluate(CString strPath);
// NOTE(review): the Win32 headers define CopyFile as a macro
// (CopyFileA/CopyFileW), so this member is renamed wherever <windows.h> is
// included first - confirm that every translation unit sees it consistently.
void CopyFile(char *pFileName, char *pSource, char *pTarget, char *pCatalog);
long SaveResults(CCatalogList &cataList, CString strFileName, CStringArray *aryType=NULL);
private: // Private member methods needed during training
void Sort(struct sSortType *,int);
void QuickSort(struct sSortType *,int,int);
void GenSortBuf(CWordList& wordList,sSortType *psSortBuf,int nCatalog);
void GenModel();
// NOTE(review): "Feather" in the two names below is presumably a misspelling
// of "Feature" (feature selection / feature weighting); the names are kept
// because they are part of the class interface.
void FeatherSelection(CWordList& dstWordList);
void FeatherWeight(CWordList& wordList);
public: // Public member methods needed during classification
bool OpenModel(CString strFileName);
void Prepare();
bool Classify();
bool ClassifySmart();
// KNN* overloads. bFile presumably tells whether the char* argument is a file
// path or in-memory text - TODO confirm in the implementation.
bool KNNClassify(char *, CDocNode &, bool bFile=true, int nCmpType=0);
long KNNClassify(CCatalogList&,int nCmpType=0);
short KNNCategory(char *,CDocNode &,bool bFile=true, int nCmpType=0);
short KNNCategory(char *pPath, bool bFile=true, int nCmpType=0);
// SVM* overloads, parallel to the KNN* ones above.
long SVMClassify(CCatalogList& cataList);
bool SVMClassify(char *pPath, CDocNode &docNode, bool bFile=true);
short SVMCategory(char *pPath, CDocNode &docNode, bool bFile=true);
short SVMCategory(char *file, bool bFile=true);
void SVMClassifyVectorFile(CString strFileName);
// Entry points that do not name a classifier type (presumably they dispatch
// on m_paramClassifier - confirm).
short GetCategory(char *file, bool bFile=true);
short GetCategory(char *path, CDocNode &docNode, bool bFile=true);
bool Classify(char *path, CDocNode &docNode, bool bFile=true);
long Classify(CCatalogList& cataList);
short SingleCategory(CDocNode &docNode);
// Fills aryResult with category ids; dThreshold is presumably the acceptance
// threshold for a category - confirm.
bool MultiCategory(CDocNode &docNode, CArray<short,short>& aryResult, double dThreshold);
private: // Private member methods needed during classification
void ComputeWeight(bool bMult=false);
void ComputeSimRatio(CDocNode &, int nCmpType=0);
void Sort(DocWeight *,int);
void QuickSort(DocWeight*, int, int);
private: // Private member variables needed during classification
// Pointer to the training document set, used to speed up reading
DocCatalog *m_pDocs;
// Temporarily holds the similarity between the current test document and
// each document of the training set
DocWeight *m_pSimilarityRatio;
// Number of training documents
long m_lDocNum;
// Number of categories among the training documents
short m_nClassNum;
public:
// Identifier placed in the header of a classification model file
static const DWORD dwModelFileID;
CClassifierParam m_paramClassifier;
CWordList m_lstTrainWordList;
// Used during training: holds all features of the training set before
// feature selection has been performed
CWordList m_lstWordList;
CCatalogList m_lstTrainCatalogList;
CCatalogList m_lstTestCatalogList;
CSVM m_theSVM;
};
// Global CClassifier instance; it is only declared here and must be defined
// in exactly one source file.
extern CClassifier theClassifier;
#endif // !defined(AFX_CLASSIFIER_H__FA4DB8D8_AC36_44A8_884B_0D715575B7A1__INCLUDED_)
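// ---------------------------------------------------------------------
// Code-viewer page residue (not part of the original header):
// Keyboard shortcuts
//   Copy code:          Ctrl + C
//   Search code:        Ctrl + F
//   Full-screen mode:   F11
//   Toggle theme:       Ctrl + Shift + D
//   Show shortcuts:     ?
//   Increase font size: Ctrl + =
//   Decrease font size: Ctrl + -
// ---------------------------------------------------------------------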