// ⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
// ⭐ 虫虫下载站
//
// 📄 knndiscriminanttraining.cpp
//
// 📁 线性判别分析算法和knn最近邻算法的实现
// 💻 CPP
// 字号:
// KNNDiscriminantTraining.cpp: implementation of the KNNDiscriminantTraining class.
//
//////////////////////////////////////////////////////////////////////

#include "KNNDiscriminantTraining.h"
#include "FeatureSelection.h"
#include <math.h>
//////////////////////////////////////////////////////////////////////
// Construction/Destruction
//////////////////////////////////////////////////////////////////////

// Builds one (class label, distance) record.
// i: class label stored in dest (TestingSingleClass passes the training class).
// j: distance stored in distance (TestingSingleClass passes GetEulDistance's result).
DistanceRecord::DistanceRecord(int i, double j)
    : dest(i),
      distance(j)
{
}

// No training work is performed here: the classifier compares test samples
// directly against the stored training groups (see TestingSingleClass).
// Always returns 0.
int KNNDiscriminantTraining::TrainingProcedure()
{
    const int status = 0;
    return status;
}

// Runs the KNN classifier over every test group (classes 0, 1, 4, 5, 6, 7, 8, 9).
//
// Per-sample predictions, per-class precision and the overall precision are
// written to "haha.txt" through the member stream `out`.
//
// Returns the overall recognition precision: correct predictions divided by
// (number of classes * TESTNUM).
// NOTE(review): the denominator assumes every test group holds exactly
// TESTNUM samples — confirm against the data loader.
double KNNDiscriminantTraining::TestingProcedure()
{
	//FeatureVectorProcess();

	// Reopen cleanly if an earlier call left the stream open. Otherwise open()
	// fails, sets failbit, and all later output is silently dropped.
	if (out.is_open())
		out.close();
	out.clear();
	out.open("haha.txt");

	// One distance list per test sample. assign() resets any records left over
	// from a previous call, so repeated calls do not append stale/empty entries.
	disRecord0.assign(TESTNUM, vector< DR >());
	disRecord1.assign(TESTNUM, vector< DR >());
	disRecord4.assign(TESTNUM, vector< DR >());
	disRecord5.assign(TESTNUM, vector< DR >());
	disRecord6.assign(TESTNUM, vector< DR >());
	disRecord7.assign(TESTNUM, vector< DR >());
	disRecord8.assign(TESTNUM, vector< DR >());
	disRecord9.assign(TESTNUM, vector< DR >());

	// Fill every test sample's list with its distances to each training class.
	TestingSingleClass(trainingGroup0,0);
	TestingSingleClass(trainingGroup1,1);
	TestingSingleClass(trainingGroup4,4);
	TestingSingleClass(trainingGroup5,5);
	TestingSingleClass(trainingGroup6,6);
	TestingSingleClass(trainingGroup7,7);
	TestingSingleClass(trainingGroup8,8);
	TestingSingleClass(trainingGroup9,9);

	// Classify and count correct predictions per class.
	const int classCount = 8;
	int hits = 0;
	hits += ComputSingleClassRecongValue(disRecord0,0);
	hits += ComputSingleClassRecongValue(disRecord1,1);
	hits += ComputSingleClassRecongValue(disRecord4,4);
	hits += ComputSingleClassRecongValue(disRecord5,5);
	hits += ComputSingleClassRecongValue(disRecord6,6);
	hits += ComputSingleClassRecongValue(disRecord7,7);
	hits += ComputSingleClassRecongValue(disRecord8,8);
	hits += ComputSingleClassRecongValue(disRecord9,9);

	const double r = static_cast<double>(hits) / (classCount * TESTNUM);
	out <<"The total recongnize precision: "<<r<<endl;
	return r;
}

// Classifies each test sample of one class with FindKnearestNeigh (k = KNN).
//
// Each predicted label is written to `out`, followed by this class's precision
// (correct / TESTNUM).
// Note: FindKnearestNeigh reorders each disRecord entry in place.
//
// disRecord:  one list of (label, distance) records per test sample.
// class_sign: the true class of every sample in disRecord.
// Returns the number of correctly classified samples.
int KNNDiscriminantTraining::ComputSingleClassRecongValue(vector< vector< DR > >& disRecord,int class_sign)
{
	typedef vector< vector< DR > >::size_type Index;

	int correct = 0;
	for (Index idx = 0; idx < disRecord.size(); ++idx)
	{
		const int predicted = FindKnearestNeigh(disRecord[idx], KNN);
		if (predicted == class_sign)
		{
			++correct;
		}
		out << predicted << " ";
	}

	const double precision = static_cast<double>(correct) / TESTNUM;
	out << endl << "The Recongnize precision of " << class_sign << ": " << precision << endl;
	return correct;
}

// Standard k-nearest-neighbour majority vote over class labels.
//
// Selects the k records of samp with the smallest distance and returns the
// label (dest) that occurs most often among them. When two labels have the same
// vote count, the one whose neighbours have the smaller mean distance wins; on
// an exact tie the higher label wins.
//
// Side effect: samp is partially reordered in place. The k nearest records end
// up at the back of the vector, with the nearest last.
//
// k is clamped to samp.size(). Returns -1 when there is nothing to vote on
// (samp empty or k <= 0).
// NOTE(review): assumes every dest is a class label in [0, 10) — confirm.
int KNNDiscriminantTraining::FindKnearestNeigh(vector< DR >& samp,int k)
{
	typedef vector< DR >::size_type Index;
	const int kClassCount = 10;

	const Index n = samp.size();
	Index neighbours = (k > 0) ? static_cast<Index>(k) : 0;
	if (neighbours > n)
		neighbours = n;
	if (neighbours == 0)
		return -1;

	// `neighbours` passes of a descending bubble sort. Pass p moves the smallest
	// remaining record to index n-1-p. The bound j+1 < n-p keeps samp[j+1] in range.
	for (Index pass = 0; pass < neighbours; ++pass)
	{
		for (Index j = 0; j + 1 < n - pass; ++j)
		{
			if (samp[j].distance < samp[j + 1].distance)
			{
				const DR tmp = samp[j];
				samp[j] = samp[j + 1];
				samp[j + 1] = tmp;
			}
		}
	}

	// Vote count and summed distance per label among the nearest records.
	int votes[kClassCount];
	double meanDist[kClassCount];
	for (int c = 0; c < kClassCount; ++c)
	{
		votes[c] = 0;
		meanDist[c] = 0.0;
	}
	for (Index p = 0; p < neighbours; ++p)
	{
		const DR& nb = samp[n - 1 - p];
		votes[nb.dest]++;
		meanDist[nb.dest] += nb.distance;
	}
	for (int c = 0; c < kClassCount; ++c)
	{
		if (votes[c] != 0)
			meanDist[c] /= votes[c];
	}

	// Highest vote count wins. Equal counts fall back to the smaller mean distance.
	int best = 0;
	int bestVotes = votes[0];
	for (int c = 1; c < kClassCount; ++c)
	{
		if (votes[c] > bestVotes)
		{
			bestVotes = votes[c];
			best = c;
		}
		else if (votes[c] == bestVotes && meanDist[best] >= meanDist[c])
		{
			best = c;
		}
	}
	return best;
}


// Measures every test sample of every test group (classes 0, 1, 4-9) against
// all samples of one training class.
//
// For test sample i of group N, one DR(class_sign, distance) record per
// training sample is appended to disRecordN[i]. A distance table is grown when
// it has fewer entries than its test group, so indexing never goes out of range.
//
// trainingGroup: feature vectors of the training class being compared.
// class_sign:    label stored in each appended record.
// Always returns 0.
int KNNDiscriminantTraining::TestingSingleClass(vector < vector<double> >& trainingGroup,
												 int class_sign)
{
	typedef vector< vector<double> >::size_type Index;
	const Index trainCount = trainingGroup.size();

	if (disRecord0.size() < testGroup0.size())
		disRecord0.resize(testGroup0.size());
	for (Index i = 0; i < testGroup0.size(); ++i)
	{
		for (Index j = 0; j < trainCount; ++j)
		{
			const double d = GetEulDistance(testGroup0[i], trainingGroup[j]);
			disRecord0[i].push_back(DR(class_sign, d));
		}
	}

	if (disRecord1.size() < testGroup1.size())
		disRecord1.resize(testGroup1.size());
	for (Index i = 0; i < testGroup1.size(); ++i)
	{
		for (Index j = 0; j < trainCount; ++j)
		{
			const double d = GetEulDistance(testGroup1[i], trainingGroup[j]);
			disRecord1[i].push_back(DR(class_sign, d));
		}
	}

	if (disRecord4.size() < testGroup4.size())
		disRecord4.resize(testGroup4.size());
	for (Index i = 0; i < testGroup4.size(); ++i)
	{
		for (Index j = 0; j < trainCount; ++j)
		{
			const double d = GetEulDistance(testGroup4[i], trainingGroup[j]);
			disRecord4[i].push_back(DR(class_sign, d));
		}
	}

	if (disRecord5.size() < testGroup5.size())
		disRecord5.resize(testGroup5.size());
	for (Index i = 0; i < testGroup5.size(); ++i)
	{
		for (Index j = 0; j < trainCount; ++j)
		{
			const double d = GetEulDistance(testGroup5[i], trainingGroup[j]);
			disRecord5[i].push_back(DR(class_sign, d));
		}
	}

	if (disRecord6.size() < testGroup6.size())
		disRecord6.resize(testGroup6.size());
	for (Index i = 0; i < testGroup6.size(); ++i)
	{
		for (Index j = 0; j < trainCount; ++j)
		{
			const double d = GetEulDistance(testGroup6[i], trainingGroup[j]);
			disRecord6[i].push_back(DR(class_sign, d));
		}
	}

	if (disRecord7.size() < testGroup7.size())
		disRecord7.resize(testGroup7.size());
	for (Index i = 0; i < testGroup7.size(); ++i)
	{
		for (Index j = 0; j < trainCount; ++j)
		{
			const double d = GetEulDistance(testGroup7[i], trainingGroup[j]);
			disRecord7[i].push_back(DR(class_sign, d));
		}
	}

	if (disRecord8.size() < testGroup8.size())
		disRecord8.resize(testGroup8.size());
	for (Index i = 0; i < testGroup8.size(); ++i)
	{
		for (Index j = 0; j < trainCount; ++j)
		{
			const double d = GetEulDistance(testGroup8[i], trainingGroup[j]);
			disRecord8[i].push_back(DR(class_sign, d));
		}
	}

	if (disRecord9.size() < testGroup9.size())
		disRecord9.resize(testGroup9.size());
	for (Index i = 0; i < testGroup9.size(); ++i)
	{
		for (Index j = 0; j < trainCount; ++j)
		{
			const double d = GetEulDistance(testGroup9[i], trainingGroup[j]);
			disRecord9[i].push_back(DR(class_sign, d));
		}
	}

	return 0;
}

double KNNDiscriminantTraining::GetEulDistance(const vector <double>& s,const vector <double>& d)
{
   int i;
   double h=0;
   for(i=1;i<s.size();i++)
   {
     h += (s[i]-d[i])*(s[i]-d[i]);
   }
   return sqrt(h);
}

// ⌨️ 快捷键说明
//
// 复制代码 Ctrl + C
// 搜索代码 Ctrl + F
// 全屏模式 F11
// 切换主题 Ctrl + Shift + D
// 显示快捷键 ?
// 增大字号 Ctrl + =
// 减小字号 Ctrl + -