📄 simpledistanceclassifier.cpp
字号:
#include "StdAfx.h"
#include ".\simpledistanceclassifier.h"
CSimpleDistanceClassifier::CSimpleDistanceClassifier(void)
{
}
CSimpleDistanceClassifier::~CSimpleDistanceClassifier(void)
{
}
bool CSimpleDistanceClassifier::Train(CString& sTrainDataFile, CString& sResultFile)
{
//加载训练数据
CDataSet trainSet;
if (!trainSet.Load(sTrainDataFile))
return false;
//进行统计分析
//简单距离法:仅仅计算各分类的中心(同一类型数据的平均)--保存在一个数据集中
int nClasses = trainSet.Classes();
int *classDataCount = new int [nClasses] ;
memset(classDataCount, 0, sizeof(int)*nClasses);
m_ClassCenters.SetSize(nClasses);
//累加各类数据的各维数据
for (int i=0; i<trainSet.Count(); i++)
{
CDataItem& data = trainSet[i];
int clsId = data.ClassID();
if (m_ClassCenters[clsId].Dimension() == 0)
m_ClassCenters[clsId] = data;
else
m_ClassCenters[clsId] += data;
CString s = m_ClassCenters[clsId].ToString();
classDataCount[clsId] ++;
}
//求平均
for(i=0; i<nClasses; i++)
{
if (classDataCount[i] != 0)
{
m_ClassCenters[i] /= classDataCount[i];
}
}
delete [] classDataCount;
//保存分类中心
m_ClassCenters.Save(sResultFile);
return true;
}
bool CSimpleDistanceClassifier::Load(CString& sClassifierDataFile)
{
//加载训练得到的分类中心
m_ClassCenters.Load(sClassifierDataFile);
return true;
}
//识别一个新数据的分类
int CSimpleDistanceClassifier::Recognize(CDataItem& item)
{
double minDistance;
int nClassId = 0;
for(int i=0; i<m_ClassCenters.Count(); i++)
{
//计算与各分类中心的距离,并选择最小距离的分类为数据的最终类别
double dDistance = item.DistanceFrom(m_ClassCenters[i]);
if (i==0)
minDistance = dDistance;
else if(dDistance < minDistance)
{
minDistance = dDistance;
nClassId = i;
}
}
return nClassId;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -