// File: calclda.cpp
// (Scraped page-viewer label "字号:" = "Font size:"; not part of the source.)
// CalcLDA.cpp: implementation of the CCalcLDA class.
//
//////////////////////////////////////////////////////////////////////
#include "stdafx.h"
#include "FaceDV.h"
#include "CalcLDA.h"
#ifdef _DEBUG
#undef THIS_FILE
static char THIS_FILE[]=__FILE__;
#define new DEBUG_NEW
#endif
//////////////////////////////////////////////////////////////////////
// Construction/Destruction
//////////////////////////////////////////////////////////////////////
// Construct an empty, uninitialised LDA object; call Init() before use.
// Every pointer that Destroy() may delete is set to NULL here, so destroying
// an object that never reached CalcEigenVec() (EigenValue) cannot free an
// indeterminate pointer.
CCalcLDA::CCalcLDA()
{
	m_NumClass=0;
	m_CurrClass=0;
	m_nDim=0;
	nHeight=0;
	MainVectors=0;
	nWidth=0;
	m_Meanj=NULL;
	m_Nj=NULL;          // deleted by Destroy(); must not be left indeterminate
	EigenValue=NULL;    // only allocated by CalcEigenVec(), but deleted by Destroy()
	m_Initialized=FALSE;
	m_SampleLoaded=FALSE;
	m_TotalSamples=0;
	m_ClassNum=0;
	m_Samples=NULL;
	m_Images=NULL;
	m_Images1=NULL;
	m_Samples1=NULL;
}
// Destructor: hands all cleanup to Destroy(), which does nothing unless
// Init() has been called (m_Initialized).
CCalcLDA::~CCalcLDA()
{
	this->Destroy();
}
void CCalcLDA::Init(int NumClass, int ClassNum,int height, int width)
{
m_NumClass=NumClass;
m_ClassNum=ClassNum;
nHeight=height;
nWidth=width;
if (m_Initialized)
Destroy();
m_Meanj=(CMatrix *)new CMatrix[m_NumClass];
m_Nj=(int*)new int[m_NumClass];
m_Samples=(CMatrix **)new CMatrix*[m_NumClass];
//EigenValue=(double*)new double[nWidth];
m_Initialized=TRUE;
}
void CCalcLDA::Destroy()
{
if (m_Initialized)
{
if (m_Meanj)
delete [] m_Meanj;
if (m_Nj)
delete [] m_Nj;
if (m_Samples)
delete [] m_Samples;
if (m_Images)
delete [] m_Images;
if (EigenValue)
delete [] EigenValue;
if (m_Images1)
delete [] m_Images1;
if (m_Samples1)
delete [] m_Samples;
}
}
// Register the number of samples (count) of class Num.
// Classes must be registered in order 0, 1, ..., m_NumClass-1. Registering
// the last class computes m_TotalSamples, rewinds m_CurrClass to 0 and
// allocates m_Images through InitMatrix().
// Returns FALSE if the object is not initialised, Num is out of sequence, or
// count is not positive.
BOOL CCalcLDA::SetClassInfo(int Num, int count)
{
	// Only the next class in sequence may be registered.
	if (!m_Initialized || Num!=m_CurrClass)
		return FALSE;
	// A negative count is an invalid new[] size, and an empty class would
	// later divide 0 by 0 in CalcMeanJ(int) (sum/m_Nj[Index]).
	if (count<=0)
		return FALSE;

	m_Nj[Num]=count;
	m_CurrClass++;
	m_Samples[Num]=new CMatrix[count];

	if (Num==(m_NumClass-1))
	{
		// All classes are known: size the flat sample array.
		m_TotalSamples=0;
		for (int i=0;i<m_NumClass;i++)
			m_TotalSamples+=m_Nj[i];
		m_CurrClass=0;
		InitMatrix();
	}
	return TRUE;
}
// Load the Index-th sample of class NumClass. img points to
// nHeight*nWidth values of sample data, already normalised by the caller.
// The same data is also stored as the n-th entry of the flat m_Images array
// (n is 0-based, 0 <= n < m_TotalSamples).
// Returns FALSE if the object is not ready or any index is out of range.
BOOL CCalcLDA::LoadSample(int NumClass, double *img, int Index,int n)
{
	if (!m_Initialized || img==NULL)
		return FALSE;
	// m_Images is only allocated once the last class has been registered via
	// SetClassInfo(). Classes are registered in sequence, so every
	// m_Samples[NumClass] exists at that point.
	if (m_Images==NULL)
		return FALSE;
	if (NumClass<0 || NumClass>=m_NumClass)
		return FALSE;
	if (Index<0 || Index>=m_Nj[NumClass])
		return FALSE;
	if (n<0 || n>=m_TotalSamples)
		return FALSE;

	m_Samples[NumClass][Index].Init(1,nHeight*nWidth);
	m_Samples[NumClass][Index].SetData(img);
	m_Images[n].Init(1,nHeight*nWidth);
	m_Images[n].SetData(img);
	// The last valid flat index is m_TotalSamples-1. The old test
	// (n==m_TotalSamples) could only fire after an out-of-bounds write.
	// NOTE(review): assumes samples are loaded in increasing n order.
	if (n==m_TotalSamples-1)
		m_SampleLoaded=TRUE;   // all samples loaded
	return TRUE;
}
// Initialise the sample data matrix: allocate m_Images with one (empty)
// CMatrix per sample, using m_TotalSamples computed by SetClassInfo().
// Always returns TRUE.
// NOTE(review): a previous m_Images allocation is overwritten without being
// freed — TODO confirm this runs only once per Init().
BOOL CCalcLDA::InitMatrix()
{
	CMatrix *slots=new CMatrix[m_TotalSamples];
	m_Images=slots;
	return TRUE;
}
//归一化
void CCalcLDA::NormImage(double *img)
{
LONG i;
double max=0,min=255;
if (!img) return;
for (i=0;i<nHeight*nWidth;i++)
{
min = MIN(min, img[i]);
max = MAX(max, img[i]);
}
double delta=max-min;
for (i=0;i<nHeight*nWidth;i++)
{
img[i]=(img[i]-min)/delta;
}
}
void CCalcLDA::NormImage1(double *img)
{
LONG i;
if (!img) return;
for (i=0;i<nHeight*nWidth;i++)
img[i]=img[i]/256;
}
//求样本总均值
void CCalcLDA::CalcMean()
{
LONG i,k;//,j;
double sum=0.0;
m_Mean.Init(1,m_nDim);
for (i=0;i<m_nDim;i++)
{
//for (j=0;j<nWidth;j++)
{
sum=0.0;
for (k=0;k<m_TotalSamples;k++)
{
sum+=m_Images1[k].GetElement(0,i);
}
m_Mean.SetElement(0,i,sum/m_TotalSamples);
}
}
}
// Compute the mean of class Index: the element-wise average of the
// m_Nj[Index] PCA-projected samples in m_Samples1[Index], returned as a
// 1 x m_nDim matrix.
// NOTE(review): an empty class (m_Nj[Index] == 0) yields 0/0 (NaN) elements.
CMatrix CCalcLDA::CalcMeanJ(int Index)
{
	CMatrix mean;
	mean.Init(1,m_nDim);
	for (LONG dim=0;dim<m_nDim;dim++)
	{
		const int count=m_Nj[Index];
		double total=0.0;
		for (LONG s=0;s<count;s++)
			total+=m_Samples1[Index][s].GetElement(0,dim);
		mean.SetElement(0,dim,total/count);
	}
	return mean;
}
//求类间离散度矩阵
void CCalcLDA::CalcSB()
{
int i;
m_Sb.Init(m_nDim,m_nDim);
for (i=0;i<m_NumClass;i++)
{
m_Sb=m_Sb+((((m_Meanj[i]-m_Mean).Transpose())*(m_Meanj[i]-m_Mean))*m_ClassNum)*m_Nj[i];
}
}
//求类内离散度矩阵
void CCalcLDA::CalcSW()
{
int i,j;
CMatrix tmp;
m_Sw.Init(m_nDim,m_nDim);
for (i=0;i<m_NumClass;i++)
{
tmp.Init(m_nDim,m_nDim);
for (j=0;j<m_Nj[i];j++)
{
tmp=tmp+(((m_Samples1[i][j]-m_Meanj[i]).Transpose())*(m_Samples1[i][j]-m_Meanj[i]));
}
m_Sw=m_Sw+tmp;
}
}
// Compute the Fisher (LDA) projection.
// 1. Build Sw and Sb.
// 2. Eigen-decompose inv(Sw)*Sb and sort the eigenpairs.
// 3. Map each eigenvector back to pixel space through the PCA eigenvectors.
// 4. Save each of the first m_nDim directions, reshaped to nHeight x nWidth
//    and scaled by 256, as "fisherface<i>.txt".
// 5. Save the projected mean m_Mean*EigenVector as "MeanFisher.txt".
// Returns FALSE if Sw cannot be inverted or the eigen-solver fails.
BOOL CCalcLDA::CalcEigenVec()
{
	LONG i,h,w;
	CalcSW();
	CalcSB();
	// inv(Sw)*Sb is undefined when Sw is singular. Fail instead of
	// decomposing a non-inverted (possibly partially modified) Sw.
	if (!m_Sw.InvertGaussJordan())
		return FALSE;
	m_Sw=m_Sw*m_Sb;

	// NOTE(review): EigenValue is reallocated on every call without freeing
	// the previous buffer — TODO release it once it is guaranteed to start
	// out NULL.
	EigenValue=new double[m_Sw.m_nNumColumns];
	// NOTE(review): JacobiEigenv2 is presumably a Jacobi (symmetric-matrix)
	// eigen-solver, while inv(Sw)*Sb is not symmetric in general — TODO
	// confirm the decomposition is valid here.
	if (!m_Sw.JacobiEigenv2(EigenValue,EigenVector)) return FALSE;
	m_Sw.SortEigen(EigenValue,EigenVector,0,0,m_Sw.m_nNumColumns-1,FALSE);
	// One direction per row, mapped from the PCA subspace back to pixel space.
	EigenVector=EigenVector.Transpose();
	EigenVector=EigenVector*m_PCA.m_EigenVector;

	CMatrix fisherFace;
	fisherFace.Init(nHeight,nWidth);
	CString name;
	for (i=0;i<m_nDim;i++)
	{
		// Reshape row i (nHeight*nWidth values) into an image.
		for (h=0;h<nHeight;h++)
			for (w=0;w<nWidth;w++)
				fisherFace.SetElement(h,w,EigenVector.GetElement(i,h*nWidth+w));
		fisherFace=fisherFace*256;
		name.Format("fisherface%d.txt",i);
		// Save the fisherface vector.
		SaveFisherFace(fisherFace,name);
	}
	m_Mean1=m_Mean*EigenVector;
	SaveFisherFace(m_Mean1,"MeanFisher.txt");
	return TRUE;
}
// Fill m_Meanj[j] with the mean of every class j (0 <= j < m_NumClass) using
// CalcMeanJ(int).
void CCalcLDA::CalcMeanJ()
{
	int cls=0;
	while (cls<m_NumClass)
	{
		m_Meanj[cls]=CalcMeanJ(cls);
		++cls;
	}
}
// Dump matrix m to "Result.txt" in the working directory.
// Each row is written on its own line as "%4.4f,\t"-formatted values ending
// with ';', preceded by three blank lines.
//   mode == 0 : truncate the file and write.
//   otherwise : append to the same file.
// Silently does nothing if the file cannot be opened.
void CCalcLDA::SaveVector(CMatrix &m, int mode)
{
	// Both modes must target the same file. The append branch used to open
	// "result.txt", which is a different file on case-sensitive file systems.
	FILE *fp=fopen("Result.txt",(mode==0) ? "w" : "a");
	if (!fp)
		return;
	fprintf(fp,"\n\n\n");
	for (LONG i=0;i<m.m_nNumRows;i++)
	{
		fprintf(fp,"\n");
		for (LONG j=0;j<m.m_nNumColumns;j++)
			fprintf(fp,"%4.4f,\t",m.GetElement(i,j));
		fprintf(fp,";");
	}
	fclose(fp);
}
// Run PCA on all loaded samples and project them into the PCA subspace.
// 1. Copy every sample from m_Images into m_PCA.
// 2. Compute the principal components via m_PCA.GetEigenFace(m_NumClass,1);
//    the meaning of these arguments is defined by the PCA class.
// 3. Set m_nDim to m_PCA.MainVectors.
// 4. Store the mean-subtracted projection of each sample in m_Images1
//    (flat order) and m_Samples1 (per class).
// Returns FALSE if GetEigenFace() fails.
// NOTE(review): m_Images1 / m_Samples1 are reallocated on every call without
// freeing earlier allocations — TODO confirm CalcPCA runs once per Init().
BOOL CCalcLDA::CalcPCA()
{
	LONG i,j;
	// Scratch buffer holding one image row, reused for every copy below.
	double *v=new double[nHeight*nWidth];

	m_PCA.Init(m_TotalSamples,nHeight,nWidth);
	for (i=0;i<m_TotalSamples;i++)
	{
		m_Images[i].GetRowVector(0,v);
		m_PCA.m_Samples.SetRow(i,v);
		m_PCA.m_Images.SetRow(i,v);
	}
	if (!m_PCA.GetEigenFace(m_NumClass,1))
	{
		delete [] v;   // release the scratch buffer on the failure path too
		return FALSE;
	}
	m_nDim=m_PCA.MainVectors;

	m_Samples1=new CMatrix*[m_NumClass];
	m_Images1=new CMatrix[m_TotalSamples];
	for (i=0;i<m_NumClass;i++)
		m_Samples1[i]=new CMatrix[m_Nj[i]];

	// The projection matrix does not change inside the loops; transpose it once.
	CMatrix proj=m_PCA.m_EigenVector.Transpose();
	CMatrix face;
	face.Init(1,nHeight*nWidth);

	for (i=0;i<m_TotalSamples;i++)
	{
		m_Images[i].GetRowVector(0,v);
		face.SetData(v);
		face=face-m_PCA.m_Mean;
		m_Images1[i]=face*proj;
	}
	for (i=0;i<m_NumClass;i++)
	{
		for (j=0;j<m_Nj[i];j++)
		{
			m_Samples[i][j].GetRowVector(0,v);
			face.SetData(v);
			face=face-m_PCA.m_Mean;
			m_Samples1[i][j]=face*proj;
		}
	}
	delete [] v;
	return TRUE;
}
// Write matrix m to "fisherfaces\<name>" (relative to the working directory).
// Rows are written last-to-first, i.e. vertically flipped. Each row goes on
// its own line as "%4.4f,\t"-formatted values, preceded by three blank lines.
// Does nothing if name is NULL, the resulting path would not fit in the
// 255-byte buffer, or the file cannot be opened.
void CCalcLDA::SaveFisherFace(CMatrix &m, const char *name)
{
	const char prefix[]="fisherfaces\\";
	char path[255];
	if (name==NULL)
		return;
	// prefix + name + terminating NUL must fit; refuse instead of overflowing path.
	if (strlen(prefix)+strlen(name)>=sizeof(path))
		return;
	strcpy(path,prefix);
	strcat(path,name);

	FILE *fp=fopen(path,"w");
	if (!fp)
		return;
	fprintf(fp,"\n\n\n");
	for (LONG i=0;i<m.m_nNumRows;i++)
	{
		fprintf(fp,"\n");
		for (LONG j=0;j<m.m_nNumColumns;j++)
			fprintf(fp,"%4.4f,\t",m.GetElement(m.m_nNumRows-i-1,j));
	}
	fclose(fp);
}
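/*
 * Scraped page-viewer UI residue (not part of the source):
 * Keyboard shortcuts (快捷键说明):
 *   Copy code            Ctrl + C
 *   Search code          Ctrl + F
 *   Full-screen mode     F11
 *   Toggle theme         Ctrl + Shift + D
 *   Show shortcuts       ?
 *   Increase font size   Ctrl + =
 *   Decrease font size   Ctrl + -
 */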