📄 knn.cpp
字号:
#include "stdafx.h"
#pragma warning( disable : 4786 )
#include "windows.h"
#include "knn.h"
#include "math.h"
#include "assert.h"
// Creates an untrained classifier: no stored training documents and a
// neighbourhood size of k = 1. Train() must be called before
// ClassifyDocs(), which asserts that training data is present.
CKNN::CKNN()
{
    // Training state: nothing loaded yet.
    lTrainTotal = 0;
    docs_train  = NULL;
    vTargetTrain.clear();

    // Default neighbourhood size (single nearest neighbour).
    k = 1;
}
// Releases the DOC array allocated in Train(). Only the array itself is
// freed; the content buffers its entries point to are not released here.
// NOTE(review): docs_train is allocated with my_malloc() but released with
// free(); this assumes my_malloc wraps malloc -- confirm.
// NOTE(review): unless knn.h forbids copying CKNN, a copied object would
// free the same docs_train twice -- confirm the class is non-copyable.
CKNN::~CKNN()
{
    vTargetTrain.clear();
    if (docs_train)
        free(docs_train);
}
// Stores the training set. For each entry of DocList.vSDoc, the matching
// DOC record is located through DocList.mapDocId_Pos and copied into
// docs_train, and the entry's category set is copied into vTargetTrain.
// The DOC copies are shallow: their content pointers still refer to
// DocList's buffers, so DocList must outlive this classifier's use of them.
// Calling Train() again replaces (and frees) the previous training array.
void CKNN::Train( CDocList& DocList )
{
    size_t lDocTotal = DocList.vSDoc.size();

    // Drop any array from an earlier Train() call; overwriting the pointer
    // without freeing it leaked the old allocation.
    if (docs_train != NULL)
    {
        free(docs_train);
        docs_train = NULL;
    }
    lTrainTotal = 0;

    docs_train = (DOC *)my_malloc(sizeof(DOC) * lDocTotal );
    vTargetTrain.resize( lDocTotal );
    for (size_t l = 0; l < lDocTotal; l++)
    {
        // operator[] would silently insert position 0 for an unknown id and
        // train on the wrong document; catch that in debug builds.
        assert( DocList.mapDocId_Pos.count( DocList.vSDoc[l].lDocId ) != 0 );
        long lPos = DocList.mapDocId_Pos[ DocList.vSDoc[l].lDocId ];
        docs_train[l] = DocList.docs[lPos];
        vTargetTrain[l] = DocList.vSDoc[l].setDocCat;
        lTrainTotal++;
    }
}
// One candidate neighbour: its similarity to the query document and the
// category set of the training document it came from.
struct K_NODE {
    float fSim;
    std::set<int> setDocCat;
};

// Orders K_NODEs from most to least similar (descending fSim).
// Arguments are taken by const reference; the old by-value signature copied
// both category sets on every comparison. It no longer derives from
// std::binary_function (deprecated in C++11, removed in C++17); the
// adaptable-function typedefs are declared directly instead, with
// result_type now matching operator()'s real return type (bool).
struct CKNode_More_Than_by_Sim
{
    typedef K_NODE first_argument_type;
    typedef K_NODE second_argument_type;
    typedef bool   result_type;

    bool operator()(const K_NODE& x, const K_NODE& y) const { return x.fSim > y.fSim; }
};
// Ranking predicate for (similarity, training index) pairs: most similar
// first; equal similarities fall back to the lower training index, so the
// chosen neighbour set does not depend on the sort implementation.
static bool KnnMoreSimilar( const std::pair<float, long>& a,
                            const std::pair<float, long>& b )
{
    if ( a.first != b.first )
        return a.first > b.first;
    return a.second < b.second;
}

// Classifies test_doc by a vote among its k nearest training documents.
//
// Every training document is scored with CalcSim(). The k best (all of them
// if fewer than k were trained) each cast one vote for every category in
// their category set. Returns the category with the most votes, the
// smallest category id on a tie, or -1 if the neighbours carry no
// categories at all. Requires k > 0 and a prior Train() with at least one
// document.
int CKNN::ClassifyDocs( DOC& test_doc )
{
    assert((k>0) && (lTrainTotal>0));

    // Score every training document, remembering only its index rather than
    // copying its category set. A NaN similarity would violate the strict
    // weak ordering the sort requires, so it is pinned below every valid
    // cosine value (which lies in [-1, 1]).
    std::vector< std::pair<float, long> > vecSim;
    vecSim.reserve( static_cast<size_t>(lTrainTotal) );
    for (long lTrainIndex = 0; lTrainIndex < lTrainTotal; lTrainIndex++)
    {
        float fSim = CalcSim(docs_train[lTrainIndex].content, test_doc.content);
        if ( fSim != fSim )
            fSim = -2.0f;
        vecSim.push_back( std::make_pair(fSim, lTrainIndex) );
    }

    // Bring the k nearest neighbours to the front in one O(n log k) pass,
    // instead of re-sorting a candidate list after every training document.
    size_t nNeighbours = static_cast<size_t>(k);
    if ( nNeighbours > vecSim.size() )
        nNeighbours = vecSim.size();
    std::partial_sort( vecSim.begin(), vecSim.begin() + nNeighbours,
                       vecSim.end(), KnnMoreSimilar );

    // One vote per (neighbour, category) pair; map values start at 0.
    std::map<int, float> mapCat2Score;
    for (size_t i = 0; i < nNeighbours; i++)
    {
        const std::set<int>& setCat = vTargetTrain[ vecSim[i].second ];
        for (std::set<int>::const_iterator itCat = setCat.begin();
             itCat != setCat.end(); ++itCat)
        {
            mapCat2Score[*itCat] += 1.0f;
        }
    }

    // Highest vote wins. Every score is >= 1, so the first entry always
    // beats the initial 0; strict '>' over the ascending map keeps the
    // smallest category id among equal scores. -1 is returned only when no
    // votes were cast (it is no longer also used as an "unset" marker that
    // could collide with a real category -1).
    int iTrueCat = -1;
    float fBestScore = 0.0f;
    for (std::map<int, float>::iterator it = mapCat2Score.begin();
         it != mapCat2Score.end(); ++it)
    {
        if ( it->second > fBestScore ) {
            fBestScore = it->second;
            iTrueCat = it->first;
        }
    }
    return iTrueCat;
}
// Classifies every line of sVectorFile and writes exactly one result line
// per input line to sResultFile:
//   "<DocId>\t<category>" when ReadDoc() succeeds (returns > 0),
//   "<DocId>"             otherwise.
// A running count of classified documents is echoed to stdout. If either
// file cannot be opened, an error is reported on stderr and nothing is
// classified (previously this failed silently).
void CKNN::ClassifyDocs( string sVectorFile, string sResultFile )
{
    ofstream ofResult( sResultFile.c_str() );
    if ( !ofResult ) {
        cerr << "cannot open result file: " << sResultFile << endl;
        return;
    }
    ifstream ifVector( sVectorFile.c_str() );
    if ( !ifVector ) {
        cerr << "cannot open vector file: " << sVectorFile << endl;
        return;
    }

    string sLine;
    int iCount = 0;
    while ( getline( ifVector, sLine ) ) {
        // Value-initialised so DocId holds a defined value even if ReadDoc()
        // returns before filling it in; the failure path below prints it.
        DOC test_doc = DOC();
        if ( ReadDoc( sLine, test_doc ) <= 0 ) {
            // Still emit a line so the output stays aligned with the input.
            ofResult << test_doc.DocId << '\n';
            continue;
        }
        int iCat = ClassifyDocs( test_doc );
        ofResult << test_doc.DocId << "\t" << iCat << '\n';
        free( test_doc.content );
        iCount++;
        cout << "documents classified: " << iCount << "\r";
    }
    // '\n' rather than endl avoids flushing the result file after every
    // document; both streams are flushed and closed by their destructors.
}
// Cosine similarity of two sparse vectors: <p,q> / (|p| * |q|), built from
// sprod_ss() dot products and computed in double before narrowing to float.
// Returns 0 when either vector has zero (or non-finite) length, where the
// cosine is undefined; previously this divided by zero and produced NaN or
// inf, which breaks any ranking by similarity.
float CKNN::CalcSim(WORDITEM *p, WORDITEM *q)
{
    double dNorm = sqrt( static_cast<double>( sprod_ss(p,p) ) )
                 * sqrt( static_cast<double>( sprod_ss(q,q) ) );
    if ( !(dNorm > 0.0) )   // also rejects NaN
        return 0.0f;
    return static_cast<float>( static_cast<double>( sprod_ss(p,q) ) / dNorm );
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -