⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 classificationrda.cpp

📁 该程序包实现了几个常用的模式识别分类器算法
💻 CPP
字号:
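/* CRda implements Regularized Discriminant Analysis (RDA, 正则化判别分析); "paramBeta" and "paramGamma" are the algorithm's two hyperparameters.
   Usage: create an object, set the hyperparameters, call train() to fit the classifier, then call test() to classify. */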
#include "stdlib.h"
#include <iostream>
#include <string.h>
#include <fstream>
#include <math.h>
using namespace std;
#include "global.h"

/* Construct an untrained classifier: no model or result buffers are owned yet
   (they are allocated by the training/testing methods and released by the destructor),
   and both regularization hyperparameters start at 0. */
CRda::CRda()
{
	// Model parameter buffers.
	this->paramLogSigma = NULL;
	this->paramInvSigma = NULL;
	this->paramSigma = NULL;
	this->paramU = NULL;

	// Per-class test statistics.
	this->resultRight = NULL;
	this->resultTotal = NULL;

	// Regularization hyperparameters; callers set them before training.
	this->paramGamma = 0;
	this->paramBeta = 0;
}

/* Release every buffer the classifier may own. Any of them may still be NULL
   (never allocated), which is fine: delete[] on a null pointer does nothing. */
CRda::~CRda()
{
	delete[] this->resultRight;
	delete[] this->resultTotal;
	delete[] this->paramLogSigma;
	delete[] this->paramInvSigma;
	delete[] this->paramSigma;
	delete[] this->paramU;
}

/* 训练分类器,输入训练样本的文件名,返回true表示训练成功,返回false表示训练失败。 */
bool CRda::train(char* fileName)
{
	bool ret = true;
	int i, j, k, index, cnum;
	ret = this->sdata.readFile(fileName);//读入训练数据
	if (!ret)
		return false;//如果数据格式不正确,退出程序。
	this->sdata.normalize();
	int numClass = this->sdata.numClass;
	int numFeature = this->sdata.numFeature;
	int numSample = this->sdata.numSample;
	if (this->paramU == NULL)
		this->paramU = new DOUBLE[numClass*numFeature];
	if (this->paramSigma == NULL)
		this->paramSigma = new DOUBLE[numClass*numFeature*numFeature];
	if (this->paramInvSigma == NULL)
		this->paramInvSigma = new DOUBLE[numClass*numFeature*numFeature];
	if (this->paramLogSigma == NULL)
		this->paramLogSigma = new DOUBLE[numClass];
	DOUBLE *sumX = new DOUBLE[numClass*numFeature];
	DOUBLE *sumXX = new DOUBLE[numClass*numFeature*numFeature];
	int* classSample = new int[numClass];
	DOUBLE *originalSigma = new DOUBLE[numClass*numFeature*numFeature];
	DOUBLE *sigma0 = new DOUBLE[numFeature*numFeature];
	DOUBLE *sigmaI = new DOUBLE[numClass];
	int total;
	for (i=0; i<numClass; i++)
	{
		classSample[i]=0;
	}
	total = numClass * numFeature;
	for (i=0; i<total; i++)
	{
		sumX[i]=0;
	}
	total = numClass * numFeature * numFeature;
	for (i=0; i<total; i++)
	{
		sumXX[i]=0;
	}

	for (i=0; i<numSample; i++)
	{
		cnum = this->sdata.ydata[i];
		classSample[cnum] ++;
		for (j=0; j<numFeature; j++)
		{
			sumX[cnum*numFeature+j] += this->sdata.xdata[i*numFeature+j];
			for (k=j; k<numFeature; k++)
			{
				sumXX[(cnum*numFeature+j)*numFeature+k] +=this->sdata.xdata[i*numFeature+j] * this->sdata.xdata[i*numFeature+k];
			}
		}
	}
	for (i=0; i<numClass; i++)
	{
		for (j=0; j<numFeature; j++)
		{
			this->paramU[i*numFeature+j] = sumX[i*numFeature+j] / classSample[i];
		}
	}
	for (i=0; i<numClass; i++)
	{
		for (j=0; j<numFeature; j++)
		{
			for (k=j; k<numFeature; k++)
			{
				index = (i*numFeature+j)*numFeature+k;
				originalSigma[index] = sumXX[index]/classSample[i]-(this->paramU[i*numFeature+j]*this->paramU[i*numFeature+k]);
			}
		}
	}
	for (i=0; i<numClass; i++)
	{
		for (j=1; j<numFeature; j++)
		{
			for (k=0; k<j; k++)
			{
				originalSigma[(i*numFeature+j)*numFeature+k] = originalSigma[(i*numFeature+k)*numFeature+j];
			}
		}
	}
	DOUBLE temp1 = (1-this->paramBeta)*(1-this->paramGamma);
	DOUBLE temp2 = this->paramBeta*(1-this->paramGamma);
	int totalNum = this->sdata.numSample;
	for (j=0; j<numFeature; j++)
	{
		for (k=0; k<numFeature; k++)
		{
			index = j*numFeature+k;
			sigma0[index] = 0;
			for (i=0; i<numClass; i++)
			{
				sigma0[index] += originalSigma[(i*numFeature+j)*numFeature+k] * classSample[i];
			}
			sigma0[index] = sigma0[index]/totalNum;
		}
	}
	for (i=0; i<numClass; i++)
	{
		sigmaI[i] = 0;
		for (j=0; j<numFeature; j++)
		{
			sigmaI[i] += originalSigma[j*numFeature+j];
		}
		sigmaI[i] = sigmaI[i]/numFeature;
	}
	for (i=0; i<numClass; i++)
	{
		for (j=0; j<numFeature; j++)
		{
			for (k=0; k<numFeature; k++)
			{
				index = (i*numFeature+j)*numFeature+k;
				this->paramSigma[index] = temp1*originalSigma[index]+temp2*sigma0[j*numFeature+k];
			}
		}
	}
	for (i=0; i<numClass; i++)
	{
		temp2 = this->paramGamma*sigmaI[i];
		for (j=0; j<numFeature; j++)
		{
			index = (i*numFeature+j)*numFeature+j;
			this->paramSigma[index] += temp2;
			if (this->paramSigma[index]<0.001)//限制协方差矩阵的对角元素的下限
			{
				this->paramSigma[index] = 0.001;
			}
		}
	}
	//模型参数预处理
	CMatrix sigma;
	CMatrix invSigma;
	int size = numFeature * numFeature;
	int totalSample = this->sdata.numSample;
	sigma.nline = numFeature;
	sigma.ncol = numFeature;
	sigma.pdata = new DOUBLE[size];
	for (i=0; i<numClass; i++)
	{
		for (j=0; j<size; j++)
		{
			sigma.pdata[j] = this->paramSigma[i*size+j];
		}
		sigma.inv(invSigma);
		for (j=0; j<size; j++)
		{
			this->paramInvSigma[i*size+j] = invSigma.pdata[j];
		}
		this->paramLogSigma[i] = log(sigma.det())-2*log(((DOUBLE)classSample[i])/((DOUBLE)totalSample));//这里log表示自然对数
	}
	delete[] sumX;
	delete[] sumXX;
	delete[] classSample;
	delete[] originalSigma;
	delete[] sigma0;
	delete[] sigmaI;
	return ret;
}

/* 使用分类器进行分类,输入测试样本的文件名称,返回true表示程序执行正常,false表示程序执行错误。 */
bool CRda::test(char* fileName)
{
	bool ret = true;
	ifstream ifs(fileName);
	int bufSize = 1024;
	char* buf=new char[bufSize];
	char* temp=new char[bufSize];
	int numClass = this->sdata.numClass;
	int numFeature = this->sdata.numFeature;
	DOUBLE* sample=new DOUBLE[numFeature];
	DOUBLE* value=new DOUBLE[numClass];
	int n=0;
	int i, j, k;
	int cnum;
	DOUBLE max;
	int index;

	if (NULL==this->resultTotal)
		this->resultTotal = new DOUBLE[numClass];
	if (NULL==this->resultRight)
		this->resultRight = new DOUBLE[numClass];
	for (i=0; i<numClass; i++)
	{
		this->resultTotal[i] = 0;
		this->resultRight[i] = 0;
	}
	while (ifs.good()) {
		ifs.getline(buf, bufSize);
		n=0;
		while (*(buf+n)!=',' && n<bufSize)
		{
			*(temp+n)=*(buf+n);
			n++;
		}
		*(temp+n)=0;
		cnum=this->sdata.searchClassName(temp);
		if (cnum>-1)
		{
			this->resultTotal[cnum] += 1;
			for (j=0; j<numFeature; j++)
			{
				n++;
				int n2=n;
				while (*(buf+n)!=',' && n<bufSize)
				{
					*(temp+n-n2)=*(buf+n);
					n++;
				}
				*(temp+n-n2)=0;
				*(sample+j)=atof(temp);
				if (this->sdata.maxValue[j] > this->sdata.minValue[j])
				{
					*(sample+j) = (*(sample+j) - this->sdata.minValue[j])/(this->sdata.maxValue[j] - this->sdata.minValue[j]);
				}
				else
					*(sample+j) = 0.5;
			}
			for (i=0; i<numClass; i++)
			{
				*(value+i) = -this->paramLogSigma[i];
				for (j=0; j<numFeature; j++)
				{
					for (k=0; k<numFeature; k++)
					{
						*(value+i) -= (*(sample+j) - this->paramU[i*numFeature+j])*(this->paramInvSigma[(i*numFeature+j)*numFeature+k])*(*(sample+k) - this->paramU[i*numFeature+k]);
					}
				}
			}
			max = *value;
			index = 0;
			for (i=1; i<numClass; i++)
			{
				if (*(value+i)>max)
				{
					max = *(value+i);
					index = i;
				}
			}
			if (index == cnum)
				this->resultRight[cnum] += 1;
		}
	}
	delete[] buf;
	delete[] temp;
	delete[] sample;
	delete[] value;
	return ret;
}

/* Estimate classification accuracy with 4-fold cross-validation.
   fileName: training data file. It is read (and normalized) only when sdata holds no data yet
             (sdata.ydata == NULL); otherwise the already-loaded data is used and fileName is
             ignored.
   Folds are contiguous blocks of numSample/4 samples in file order (no shuffling); the last
   numSample%4 samples are never used. In each round the RDA model is fitted on three folds
   (class means, regularized covariances, inverses, log-determinant/log-prior terms) and the
   held-out fold is classified by the arg-max discriminant.
   Side effect: paramU, paramSigma, paramInvSigma and paramLogSigma are overwritten and end up
   holding the model of the last round.
   Returns the fraction of held-out samples classified correctly, or 0 when the data cannot be
   read or there are fewer samples than folds. */
DOUBLE CRda::crossValidation(char* fileName)
{
	const int fold = 4;
	int i, j, k, index, cnum;
	int n1, n2;
	if (this->sdata.ydata == NULL)
	{
		if(!this->sdata.readFile(fileName))//read the training data
			return 0;//malformed data file
		this->sdata.normalize();
	}
	const int numClass = this->sdata.numClass;
	const int numFeature = this->sdata.numFeature;
	const int numSample = this->sdata.numSample/fold;//samples per fold
	if (numClass <= 0 || numFeature <= 0 || numSample <= 0)
		return 0;//too little data to split into folds; avoids 0/0 below
	const int totalNum = numSample*(fold-1);//training samples in each round
	const int size = numFeature*numFeature;

	// Always reallocate: an earlier call may have sized these buffers for a data set
	// with a different number of classes or features.
	delete[] this->paramU;
	delete[] this->paramSigma;
	delete[] this->paramInvSigma;
	delete[] this->paramLogSigma;
	this->paramU = new DOUBLE[numClass*numFeature];
	this->paramSigma = new DOUBLE[numClass*size];
	this->paramInvSigma = new DOUBLE[numClass*size];
	this->paramLogSigma = new DOUBLE[numClass];
	DOUBLE *sumX = new DOUBLE[numClass*numFeature];
	DOUBLE *sumXX = new DOUBLE[numClass*size];
	int* classSample = new int[numClass];
	DOUBLE *originalSigma = new DOUBLE[numClass*size];
	DOUBLE *sigma0 = new DOUBLE[size];
	DOUBLE *sigmaI = new DOUBLE[numClass];
	DOUBLE* sample = new DOUBLE[numFeature];
	DOUBLE* value = new DOUBLE[numClass];
	DOUBLE temp, max;
	int total;
	int numRight = 0;//held-out samples classified correctly
	const int numTested = fold*numSample;//held-out samples scored in total
	const DOUBLE temp1 = (1-this->paramBeta)*(1-this->paramGamma);//weight of the class covariance
	const DOUBLE temp2 = this->paramBeta*(1-this->paramGamma);//weight of the pooled covariance

	// Matrix workspace reused by every round and class.
	// NOTE(review): sigma.pdata is not deleted in this function; confirm that CMatrix's
	// destructor releases it.
	CMatrix sigma;
	CMatrix invSigma;
	sigma.nline = numFeature;
	sigma.ncol = numFeature;
	sigma.pdata = new DOUBLE[size];

	for (n1=0; n1<fold; n1++)
	{
		// Reset the accumulators; fold n1 is held out in this round.
		for (i=0; i<numClass; i++)
			classSample[i] = 0;
		total = numClass * numFeature;
		for (i=0; i<total; i++)
			sumX[i] = 0;
		total = numClass * size;
		for (i=0; i<total; i++)
			sumXX[i] = 0;
		// Per-class sums of x and of the upper triangle of x*x^T over the training folds.
		for (n2=0; n2<fold; n2++)
		{
			if (n2 == n1)
				continue;
			for (i=0; i<numSample; i++)
			{
				const int s = n2*numSample+i;
				cnum = this->sdata.ydata[s];
				classSample[cnum]++;
				for (j=0; j<numFeature; j++)
				{
					temp = this->sdata.xdata[s*numFeature+j];
					sumX[cnum*numFeature+j] += temp;
					for (k=j; k<numFeature; k++)
					{
						sumXX[(cnum*numFeature+j)*numFeature+k] += temp * this->sdata.xdata[s*numFeature+k];
					}
				}
			}
		}
		// Class means. A class absent from the training folds gets 0 instead of 0/0 = NaN,
		// which would otherwise poison the pooled covariance of every class.
		for (i=0; i<numClass; i++)
		{
			for (j=0; j<numFeature; j++)
			{
				this->paramU[i*numFeature+j] = classSample[i] > 0 ? sumX[i*numFeature+j] / classSample[i] : 0;
			}
		}
		// Upper triangle of the class covariance E[x x^T] - u u^T (divide-by-n_k estimate).
		for (i=0; i<numClass; i++)
		{
			for (j=0; j<numFeature; j++)
			{
				for (k=j; k<numFeature; k++)
				{
					index = (i*numFeature+j)*numFeature+k;
					originalSigma[index] = classSample[i] > 0
						? sumXX[index]/classSample[i]-(this->paramU[i*numFeature+j]*this->paramU[i*numFeature+k])
						: 0;
				}
			}
		}
		// Mirror the upper triangle into the lower triangle.
		for (i=0; i<numClass; i++)
		{
			for (j=1; j<numFeature; j++)
			{
				for (k=0; k<j; k++)
				{
					originalSigma[(i*numFeature+j)*numFeature+k] = originalSigma[(i*numFeature+k)*numFeature+j];
				}
			}
		}
		// Pooled covariance: sample-count-weighted average of the class covariances.
		for (j=0; j<numFeature; j++)
		{
			for (k=0; k<numFeature; k++)
			{
				index = j*numFeature+k;
				sigma0[index] = 0;
				for (i=0; i<numClass; i++)
				{
					sigma0[index] += originalSigma[(i*numFeature+j)*numFeature+k] * classSample[i];
				}
				sigma0[index] = sigma0[index]/totalNum;
			}
		}
		// Mean diagonal variance tr(S_i)/p of each class's OWN covariance matrix.
		for (i=0; i<numClass; i++)
		{
			sigmaI[i] = 0;
			for (j=0; j<numFeature; j++)
			{
				sigmaI[i] += originalSigma[(i*numFeature+j)*numFeature+j];
			}
			sigmaI[i] = sigmaI[i]/numFeature;
		}
		// Blend class and pooled covariances.
		for (i=0; i<numClass; i++)
		{
			for (j=0; j<numFeature; j++)
			{
				for (k=0; k<numFeature; k++)
				{
					index = (i*numFeature+j)*numFeature+k;
					this->paramSigma[index] = temp1*originalSigma[index]+temp2*sigma0[j*numFeature+k];
				}
			}
		}
		// Shrink toward a scaled identity, then floor the diagonal.
		for (i=0; i<numClass; i++)
		{
			const DOUBLE shift = this->paramGamma*sigmaI[i];
			for (j=0; j<numFeature; j++)
			{
				index = (i*numFeature+j)*numFeature+j;
				this->paramSigma[index] += shift;
				if (this->paramSigma[index]<0.001)//lower bound on the covariance diagonal entries
				{
					this->paramSigma[index] = 0.001;
				}
			}
		}
		// Inverse covariances and log-determinant / log-prior terms (natural logs).
		for (i=0; i<numClass; i++)
		{
			for (j=0; j<size; j++)
			{
				sigma.pdata[j] = this->paramSigma[i*size+j];
			}
			sigma.inv(invSigma);
			for (j=0; j<size; j++)
			{
				this->paramInvSigma[i*size+j] = invSigma.pdata[j];
			}
			// For a class absent from training log(0) is -inf (IEEE), so this term is +inf
			// and that class can never win the arg-max below.
			this->paramLogSigma[i] = log(sigma.det())-2*log(static_cast<DOUBLE>(classSample[i])/static_cast<DOUBLE>(totalNum));
		}
		// Score the held-out fold.
		for (n2=0; n2<numSample; n2++)
		{
			const int s = n1*numSample+n2;
			cnum = this->sdata.ydata[s];
			for (j=0; j<numFeature; j++)
			{
				sample[j] = this->sdata.xdata[s*numFeature+j];
			}
			for (i=0; i<numClass; i++)
			{
				value[i] = -this->paramLogSigma[i];
				for (j=0; j<numFeature; j++)
				{
					for (k=0; k<numFeature; k++)
					{
						value[i] -= (sample[j] - this->paramU[i*numFeature+j])*(this->paramInvSigma[(i*numFeature+j)*numFeature+k])*(sample[k] - this->paramU[i*numFeature+k]);
					}
				}
			}
			// Arg-max; ties go to the lowest class index.
			max = value[0];
			index = 0;
			for (i=1; i<numClass; i++)
			{
				if (value[i]>max)
				{
					max = value[i];
					index = i;
				}
			}
			if (index == cnum)
				numRight++;
		}
	}
	const DOUBLE ret = static_cast<DOUBLE>(numRight)/static_cast<DOUBLE>(numTested);
	delete[] sumX;
	delete[] sumXX;
	delete[] classSample;
	delete[] originalSigma;
	delete[] sigma0;
	delete[] sigmaI;
	delete[] sample;
	delete[] value;
	return ret;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -