// classificationrda.cpp
// (The original second line was a "font size:" label left over from the web code viewer.)
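/* CRda implements the Regularized Discriminant Analysis (RDA) classifier;
   "paramBeta" and "paramGamma" are the algorithm's two hyperparameters.
   Usage: create an object, set the hyperparameters, call train() to fit the
   classifier, then call test() to classify samples. */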
#include "stdlib.h"
#include <iostream>
#include <string.h>
#include <fstream>
#include <math.h>
using namespace std;
#include "global.h"
/* Build an untrained classifier: both hyperparameters start at zero and no
   model or result buffer is allocated yet (every pointer is NULL). */
CRda::CRda()
{
    // Hyperparameters.
    this->paramGamma = 0;
    this->paramBeta = 0;

    // Per-class result counters.
    this->resultRight = NULL;
    this->resultTotal = NULL;

    // Model parameters.
    this->paramLogSigma = NULL;
    this->paramInvSigma = NULL;
    this->paramSigma = NULL;
    this->paramU = NULL;
}
/* Release every buffer owned by the classifier. Applying delete[] to a null
   pointer is a no-op, so buffers that were never allocated need no guard. */
CRda::~CRda()
{
    // Model parameters.
    delete[] this->paramU;
    delete[] this->paramSigma;
    delete[] this->paramInvSigma;
    delete[] this->paramLogSigma;

    // Per-class result counters.
    delete[] this->resultTotal;
    delete[] this->resultRight;
}
/* 训练分类器,输入训练样本的文件名,返回true表示训练成功,返回false表示训练失败。 */
bool CRda::train(char* fileName)
{
bool ret = true;
int i, j, k, index, cnum;
ret = this->sdata.readFile(fileName);//读入训练数据
if (!ret)
return false;//如果数据格式不正确,退出程序。
this->sdata.normalize();
int numClass = this->sdata.numClass;
int numFeature = this->sdata.numFeature;
int numSample = this->sdata.numSample;
if (this->paramU == NULL)
this->paramU = new DOUBLE[numClass*numFeature];
if (this->paramSigma == NULL)
this->paramSigma = new DOUBLE[numClass*numFeature*numFeature];
if (this->paramInvSigma == NULL)
this->paramInvSigma = new DOUBLE[numClass*numFeature*numFeature];
if (this->paramLogSigma == NULL)
this->paramLogSigma = new DOUBLE[numClass];
DOUBLE *sumX = new DOUBLE[numClass*numFeature];
DOUBLE *sumXX = new DOUBLE[numClass*numFeature*numFeature];
int* classSample = new int[numClass];
DOUBLE *originalSigma = new DOUBLE[numClass*numFeature*numFeature];
DOUBLE *sigma0 = new DOUBLE[numFeature*numFeature];
DOUBLE *sigmaI = new DOUBLE[numClass];
int total;
for (i=0; i<numClass; i++)
{
classSample[i]=0;
}
total = numClass * numFeature;
for (i=0; i<total; i++)
{
sumX[i]=0;
}
total = numClass * numFeature * numFeature;
for (i=0; i<total; i++)
{
sumXX[i]=0;
}
for (i=0; i<numSample; i++)
{
cnum = this->sdata.ydata[i];
classSample[cnum] ++;
for (j=0; j<numFeature; j++)
{
sumX[cnum*numFeature+j] += this->sdata.xdata[i*numFeature+j];
for (k=j; k<numFeature; k++)
{
sumXX[(cnum*numFeature+j)*numFeature+k] +=this->sdata.xdata[i*numFeature+j] * this->sdata.xdata[i*numFeature+k];
}
}
}
for (i=0; i<numClass; i++)
{
for (j=0; j<numFeature; j++)
{
this->paramU[i*numFeature+j] = sumX[i*numFeature+j] / classSample[i];
}
}
for (i=0; i<numClass; i++)
{
for (j=0; j<numFeature; j++)
{
for (k=j; k<numFeature; k++)
{
index = (i*numFeature+j)*numFeature+k;
originalSigma[index] = sumXX[index]/classSample[i]-(this->paramU[i*numFeature+j]*this->paramU[i*numFeature+k]);
}
}
}
for (i=0; i<numClass; i++)
{
for (j=1; j<numFeature; j++)
{
for (k=0; k<j; k++)
{
originalSigma[(i*numFeature+j)*numFeature+k] = originalSigma[(i*numFeature+k)*numFeature+j];
}
}
}
DOUBLE temp1 = (1-this->paramBeta)*(1-this->paramGamma);
DOUBLE temp2 = this->paramBeta*(1-this->paramGamma);
int totalNum = this->sdata.numSample;
for (j=0; j<numFeature; j++)
{
for (k=0; k<numFeature; k++)
{
index = j*numFeature+k;
sigma0[index] = 0;
for (i=0; i<numClass; i++)
{
sigma0[index] += originalSigma[(i*numFeature+j)*numFeature+k] * classSample[i];
}
sigma0[index] = sigma0[index]/totalNum;
}
}
for (i=0; i<numClass; i++)
{
sigmaI[i] = 0;
for (j=0; j<numFeature; j++)
{
sigmaI[i] += originalSigma[j*numFeature+j];
}
sigmaI[i] = sigmaI[i]/numFeature;
}
for (i=0; i<numClass; i++)
{
for (j=0; j<numFeature; j++)
{
for (k=0; k<numFeature; k++)
{
index = (i*numFeature+j)*numFeature+k;
this->paramSigma[index] = temp1*originalSigma[index]+temp2*sigma0[j*numFeature+k];
}
}
}
for (i=0; i<numClass; i++)
{
temp2 = this->paramGamma*sigmaI[i];
for (j=0; j<numFeature; j++)
{
index = (i*numFeature+j)*numFeature+j;
this->paramSigma[index] += temp2;
if (this->paramSigma[index]<0.001)//限制协方差矩阵的对角元素的下限
{
this->paramSigma[index] = 0.001;
}
}
}
//模型参数预处理
CMatrix sigma;
CMatrix invSigma;
int size = numFeature * numFeature;
int totalSample = this->sdata.numSample;
sigma.nline = numFeature;
sigma.ncol = numFeature;
sigma.pdata = new DOUBLE[size];
for (i=0; i<numClass; i++)
{
for (j=0; j<size; j++)
{
sigma.pdata[j] = this->paramSigma[i*size+j];
}
sigma.inv(invSigma);
for (j=0; j<size; j++)
{
this->paramInvSigma[i*size+j] = invSigma.pdata[j];
}
this->paramLogSigma[i] = log(sigma.det())-2*log(((DOUBLE)classSample[i])/((DOUBLE)totalSample));//这里log表示自然对数
}
delete[] sumX;
delete[] sumXX;
delete[] classSample;
delete[] originalSigma;
delete[] sigma0;
delete[] sigmaI;
return ret;
}
/* 使用分类器进行分类,输入测试样本的文件名称,返回true表示程序执行正常,false表示程序执行错误。 */
bool CRda::test(char* fileName)
{
bool ret = true;
ifstream ifs(fileName);
int bufSize = 1024;
char* buf=new char[bufSize];
char* temp=new char[bufSize];
int numClass = this->sdata.numClass;
int numFeature = this->sdata.numFeature;
DOUBLE* sample=new DOUBLE[numFeature];
DOUBLE* value=new DOUBLE[numClass];
int n=0;
int i, j, k;
int cnum;
DOUBLE max;
int index;
if (NULL==this->resultTotal)
this->resultTotal = new DOUBLE[numClass];
if (NULL==this->resultRight)
this->resultRight = new DOUBLE[numClass];
for (i=0; i<numClass; i++)
{
this->resultTotal[i] = 0;
this->resultRight[i] = 0;
}
while (ifs.good()) {
ifs.getline(buf, bufSize);
n=0;
while (*(buf+n)!=',' && n<bufSize)
{
*(temp+n)=*(buf+n);
n++;
}
*(temp+n)=0;
cnum=this->sdata.searchClassName(temp);
if (cnum>-1)
{
this->resultTotal[cnum] += 1;
for (j=0; j<numFeature; j++)
{
n++;
int n2=n;
while (*(buf+n)!=',' && n<bufSize)
{
*(temp+n-n2)=*(buf+n);
n++;
}
*(temp+n-n2)=0;
*(sample+j)=atof(temp);
if (this->sdata.maxValue[j] > this->sdata.minValue[j])
{
*(sample+j) = (*(sample+j) - this->sdata.minValue[j])/(this->sdata.maxValue[j] - this->sdata.minValue[j]);
}
else
*(sample+j) = 0.5;
}
for (i=0; i<numClass; i++)
{
*(value+i) = -this->paramLogSigma[i];
for (j=0; j<numFeature; j++)
{
for (k=0; k<numFeature; k++)
{
*(value+i) -= (*(sample+j) - this->paramU[i*numFeature+j])*(this->paramInvSigma[(i*numFeature+j)*numFeature+k])*(*(sample+k) - this->paramU[i*numFeature+k]);
}
}
}
max = *value;
index = 0;
for (i=1; i<numClass; i++)
{
if (*(value+i)>max)
{
max = *(value+i);
index = i;
}
}
if (index == cnum)
this->resultRight[cnum] += 1;
}
}
delete[] buf;
delete[] temp;
delete[] sample;
delete[] value;
return ret;
}
/* Estimate the classification accuracy by 4-fold cross-validation.
 *
 * fileName: name of the sample file; it is only read (and normalized) when
 *           sdata holds no data yet (sdata.ydata == NULL).
 * Returns the fraction of held-out samples classified correctly, or 0 when
 * the data cannot be read or there are fewer samples than folds.
 *
 * The data is split into 4 contiguous folds of numSample/4 samples each; the
 * last numSample%4 samples are used neither for training nor for testing.
 * For each fold the model is estimated on the other three folds exactly as in
 * train() and evaluated on the held-out fold. The model parameters of the
 * object are overwritten and hold the last fold's model on return.
 */
DOUBLE CRda::crossValidation(char* fileName)
{
    const int fold = 4;
    int i, j, k, index, cnum;
    int n1, n2;

    if (this->sdata.ydata == NULL)
    {
        if (!this->sdata.readFile(fileName)) // load the training data
            return 0;                        // malformed or unreadable data
        this->sdata.normalize();
    }

    const int numClass = this->sdata.numClass;
    const int numFeature = this->sdata.numFeature;
    const int numTotalSample = this->sdata.numSample;
    const int numSample = numTotalSample / fold; // samples per fold
    // With fewer samples than folds every fold is empty and all the averages
    // below (including the final accuracy) would be 0/0.
    if (numClass <= 0 || numFeature <= 0 || numSample <= 0)
        return 0;

    const int size = numFeature * numFeature; // elements in one covariance matrix
    const int meanSize = numClass * numFeature;
    const int covSize = numClass * size;
    const int numTrain = numSample * (fold - 1); // training samples per round

    // Always reallocate the model buffers: they may have been sized for a
    // different data set by an earlier call. delete[] on NULL is a no-op.
    delete[] this->paramU;
    delete[] this->paramSigma;
    delete[] this->paramInvSigma;
    delete[] this->paramLogSigma;
    this->paramU = new DOUBLE[meanSize];
    this->paramSigma = new DOUBLE[covSize];
    this->paramInvSigma = new DOUBLE[covSize];
    this->paramLogSigma = new DOUBLE[numClass];

    DOUBLE *sumX = new DOUBLE[meanSize];         // per-class sum of x_j
    DOUBLE *sumXX = new DOUBLE[covSize];         // per-class sum of x_j*x_k (upper triangle only)
    int *classSample = new int[numClass];        // per-class training-sample count
    DOUBLE *originalSigma = new DOUBLE[covSize]; // unregularized class covariances
    DOUBLE *sigma0 = new DOUBLE[size];           // pooled covariance
    DOUBLE *sigmaI = new DOUBLE[numClass];       // mean diagonal element of each class covariance
    DOUBLE *sample = new DOUBLE[numFeature];
    DOUBLE *value = new DOUBLE[numClass];        // discriminant score per class

    int numCorrect = 0;
    const int numTested = fold * numSample;

    for (n1 = 0; n1 < fold; n1++) // n1 is the held-out fold
    {
        for (i = 0; i < numClass; i++)
            classSample[i] = 0;
        for (i = 0; i < meanSize; i++)
            sumX[i] = 0;
        for (i = 0; i < covSize; i++)
            sumXX[i] = 0;

        // Accumulate first and second moments over the training folds.
        for (n2 = 0; n2 < fold; n2++)
        {
            if (n2 == n1)
                continue;
            for (i = 0; i < numSample; i++)
            {
                const int row = n2 * numSample + i;
                cnum = this->sdata.ydata[row];
                classSample[cnum]++;
                for (j = 0; j < numFeature; j++)
                {
                    const DOUBLE xj = this->sdata.xdata[row * numFeature + j];
                    sumX[cnum * numFeature + j] += xj;
                    for (k = j; k < numFeature; k++)
                        sumXX[(cnum * numFeature + j) * numFeature + k] += xj * this->sdata.xdata[row * numFeature + k];
                }
            }
        }

        // Class means. A class that does not occur in the training folds gets
        // a zero mean instead of 0/0, which would turn its score into NaN and
        // break the arg-max below.
        for (i = 0; i < numClass; i++)
        {
            for (j = 0; j < numFeature; j++)
            {
                if (classSample[i] > 0)
                    this->paramU[i * numFeature + j] = sumX[i * numFeature + j] / classSample[i];
                else
                    this->paramU[i * numFeature + j] = 0;
            }
        }

        // Class covariances E[x_j x_k] - mu_j mu_k: upper triangle first ...
        for (i = 0; i < numClass; i++)
        {
            for (j = 0; j < numFeature; j++)
            {
                for (k = j; k < numFeature; k++)
                {
                    index = (i * numFeature + j) * numFeature + k;
                    if (classSample[i] > 0)
                        originalSigma[index] = sumXX[index] / classSample[i]
                            - (this->paramU[i * numFeature + j] * this->paramU[i * numFeature + k]);
                    else
                        originalSigma[index] = 0;
                }
            }
        }
        // ... then mirrored into the lower triangle.
        for (i = 0; i < numClass; i++)
        {
            for (j = 1; j < numFeature; j++)
            {
                for (k = 0; k < j; k++)
                    originalSigma[(i * numFeature + j) * numFeature + k] = originalSigma[(i * numFeature + k) * numFeature + j];
            }
        }

        // Pooled covariance: class covariances weighted by their sample counts.
        for (j = 0; j < numFeature; j++)
        {
            for (k = 0; k < numFeature; k++)
            {
                index = j * numFeature + k;
                sigma0[index] = 0;
                for (i = 0; i < numClass; i++)
                    sigma0[index] += originalSigma[(i * numFeature + j) * numFeature + k] * classSample[i];
                sigma0[index] = sigma0[index] / numTrain;
            }
        }

        // Mean diagonal element of each class's own covariance. The class
        // offset i*numFeature is required here; without it every class would
        // use the diagonal of class 0.
        for (i = 0; i < numClass; i++)
        {
            sigmaI[i] = 0;
            for (j = 0; j < numFeature; j++)
                sigmaI[i] += originalSigma[(i * numFeature + j) * numFeature + j];
            sigmaI[i] = sigmaI[i] / numFeature;
        }

        // Blend class and pooled covariance.
        const DOUBLE classWeight = (1 - this->paramBeta) * (1 - this->paramGamma);
        const DOUBLE pooledWeight = this->paramBeta * (1 - this->paramGamma);
        for (i = 0; i < numClass; i++)
        {
            for (j = 0; j < numFeature; j++)
            {
                for (k = 0; k < numFeature; k++)
                {
                    index = (i * numFeature + j) * numFeature + k;
                    this->paramSigma[index] = classWeight * originalSigma[index] + pooledWeight * sigma0[j * numFeature + k];
                }
            }
        }

        // Shrink towards a multiple of the identity, then floor the diagonal.
        for (i = 0; i < numClass; i++)
        {
            const DOUBLE ridge = this->paramGamma * sigmaI[i];
            for (j = 0; j < numFeature; j++)
            {
                index = (i * numFeature + j) * numFeature + j;
                this->paramSigma[index] += ridge;
                if (this->paramSigma[index] < 0.001) // lower bound on the covariance diagonal
                    this->paramSigma[index] = 0.001;
            }
        }

        // Precompute inverse covariance and log-determinant term per class.
        CMatrix sigma;
        CMatrix invSigma;
        sigma.nline = numFeature;
        sigma.ncol = numFeature;
        // NOTE(review): ownership of pdata is not visible in this file; confirm
        // that CMatrix releases it, otherwise one buffer leaks per fold.
        sigma.pdata = new DOUBLE[size];
        for (i = 0; i < numClass; i++)
        {
            for (j = 0; j < size; j++)
                sigma.pdata[j] = this->paramSigma[i * size + j];
            sigma.inv(invSigma);
            for (j = 0; j < size; j++)
                this->paramInvSigma[i * size + j] = invSigma.pdata[j];
            // log is the natural logarithm. For a class without training
            // samples the prior is 0, the term becomes +infinity and the class
            // scores -infinity below.
            this->paramLogSigma[i] = log(sigma.det())
                - 2 * log(static_cast<DOUBLE>(classSample[i]) / static_cast<DOUBLE>(numTrain));
        }

        // Evaluate on the held-out fold.
        for (n2 = 0; n2 < numSample; n2++)
        {
            const int row = n1 * numSample + n2;
            cnum = this->sdata.ydata[row];
            for (j = 0; j < numFeature; j++)
                sample[j] = this->sdata.xdata[row * numFeature + j];

            // Score: -paramLogSigma_i - (x-mu_i)^T * paramInvSigma_i * (x-mu_i).
            for (i = 0; i < numClass; i++)
            {
                value[i] = -this->paramLogSigma[i];
                for (j = 0; j < numFeature; j++)
                {
                    for (k = 0; k < numFeature; k++)
                    {
                        value[i] -= (sample[j] - this->paramU[i * numFeature + j])
                            * (this->paramInvSigma[(i * numFeature + j) * numFeature + k])
                            * (sample[k] - this->paramU[i * numFeature + k]);
                    }
                }
            }

            // Pick the class with the largest score (the first one wins on ties).
            DOUBLE best = value[0];
            int bestClass = 0;
            for (i = 1; i < numClass; i++)
            {
                if (value[i] > best)
                {
                    best = value[i];
                    bestClass = i;
                }
            }
            if (bestClass == cnum)
                numCorrect++;
        }
    }

    const DOUBLE ret = static_cast<DOUBLE>(numCorrect) / static_cast<DOUBLE>(numTested);
    delete[] sumX;
    delete[] sumXX;
    delete[] classSample;
    delete[] originalSigma;
    delete[] sigma0;
    delete[] sigmaI;
    delete[] sample;
    delete[] value;
    return ret;
}
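/* Residue from the web code viewer (not part of the program) - keyboard shortcuts:
 *   Copy code:           Ctrl + C
 *   Search code:         Ctrl + F
 *   Full-screen mode:    F11
 *   Toggle theme:        Ctrl + Shift + D
 *   Show shortcuts:      ?
 *   Increase font size:  Ctrl + =
 *   Decrease font size:  Ctrl + -
 */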