simpledistanceclassifier.cpp

来自「简单分类器 VC++ 用于简单的分类」· C++ 代码 · 共 82 行

CPP
82
字号
#include "StdAfx.h"
#include ".\simpledistanceclassifier.h"

CSimpleDistanceClassifier::CSimpleDistanceClassifier(void)
{
}

CSimpleDistanceClassifier::~CSimpleDistanceClassifier(void)
{
}

bool CSimpleDistanceClassifier::Train(CString& sTrainDataFile, CString& sResultFile)
{
	//加载训练数据
	CDataSet trainSet;
	if (!trainSet.Load(sTrainDataFile))
		return false;

	//进行统计分析
	//简单距离法:仅仅计算各分类的中心(同一类型数据的平均)--保存在一个数据集中
	int nClasses = trainSet.Classes();
	int *classDataCount = new int [nClasses] ;
	memset(classDataCount, 0, sizeof(int)*nClasses);
	m_ClassCenters.SetSize(nClasses);

	//累加各类数据的各维数据
	for (int i=0; i<trainSet.Count(); i++)
	{
		CDataItem& data = trainSet[i];
		int clsId = data.ClassID();
		if (m_ClassCenters[clsId].Dimension() == 0)
			m_ClassCenters[clsId] = data;
		else
			m_ClassCenters[clsId] += data;
		CString s = m_ClassCenters[clsId].ToString();
		classDataCount[clsId] ++;
	}
	//求平均
	for(i=0; i<nClasses; i++)
	{
		if (classDataCount[i] != 0)
		{
			m_ClassCenters[i] /= classDataCount[i];
		}
	}
	delete [] classDataCount;

	//保存分类中心
	m_ClassCenters.Save(sResultFile);

	return true;
}

bool CSimpleDistanceClassifier::Load(CString& sClassifierDataFile)
{
	//加载训练得到的分类中心
	m_ClassCenters.Load(sClassifierDataFile);
	return true;
}

//识别一个新数据的分类
int CSimpleDistanceClassifier::Recognize(CDataItem& item)
{

	double minDistance;
	int nClassId = 0;
	for(int i=0; i<m_ClassCenters.Count(); i++)
	{
		//计算与各分类中心的距离,并选择最小距离的分类为数据的最终类别
		double dDistance = item.DistanceFrom(m_ClassCenters[i]);
		if (i==0)
			minDistance = dDistance;
		else if(dDistance < minDistance)
		{
			minDistance = dDistance;
			nClassId = i;
		}
	}
	return nClassId;
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?