📄 ldadesign.m
字号:
function LWt=LDADesign(n,m,sampleset,classnum,samplelabel,eignum)
% LDADESIGN  Compute a linear discriminant analysis (LDA) transformation matrix.
%
%   LWt = LDADesign(n,m,sampleset,classnum,samplelabel,eignum)
%
%   Inputs
%     n           - dimension of each sample
%     m           - number of samples
%     sampleset   - n-by-m matrix; sampleset(:,j) is the j-th sample
%     classnum    - number of classes; class i is selected by samplelabel==i
%                   for i = 1..classnum
%     samplelabel - 1-by-m vector; samplelabel(j) is the class of sample j.
%                   Samples whose label is not in 1..classnum do not
%                   contribute to any class mean or to Sw, but they still
%                   count toward m and the overall mean.
%     eignum      - requested number of discriminant vectors. It is capped
%                   at classnum-1 and at the PCA dimension, and reduced
%                   further while the selected eigenvalue has magnitude
%                   below 1e-8.
%
%   Output
%     LWt - (effective eignum)-by-n matrix formed as V'*Wt, where Wt is the
%           PCA projection returned by PCADesign and V holds the leading
%           eigenvectors of Sw\Sb in the PCA subspace.
%
%   When called without an output argument, the result is written to the
%   file 'ldawt' instead: eignum as a 'short', then the matrix as 'float'.
%
%   Errors are raised for a wrong argument count, for inputs whose sizes
%   disagree with n and m, for a class with no samples, when no usable PCA
%   dimension or discriminant vector remains, and for file I/O failures.
%
%   NOTE(review): PCADesign is defined outside this file. It is used here
%   as PCADesign(n,m,sampleset,k) returning a k-by-n projection matrix that
%   is left-multiplied onto n-row data - confirm against its definition.

if nargout>1
    error('Too many output arguments.');
end
if nargin~=6
    error('Wrong number of input arguments.');
end
[cn,cm]=size(sampleset);
if cm~=m || cn~=n
    error('Wrong input data.');
end
[cn,cm]=size(samplelabel);
if cn~=1 || cm~=m
    error('Wrong input data');
end

% A holds the per-class mean vectors, B the per-class sample counts.
% centered holds every labelled sample minus its own class mean; columns of
% samples with an out-of-range label stay zero so they add nothing to Sw.
A=zeros(n,classnum);
B=zeros(1,classnum);
centered=zeros(n,m);
for i=1:classnum
    members=(samplelabel==i);
    B(1,i)=sum(members);
    if B(1,i)==0
        % An empty class would give a 0/0 mean and poison Sb and Sw with NaN.
        error('Wrong input data: class %d has no samples.',i);
    end
    A(:,i)=sum(sampleset(:,members),2)/B(1,i);
    centered(:,members)=sampleset(:,members)-repmat(A(:,i),1,B(1,i));
end

% Mean of all samples.
average=sum(sampleset,2)/m;

% Columns of A1 are sqrt(count_i)*(mean_i-average), so that A1*A1' is the
% (unnormalised) between-class scatter.
A1=(A-repmat(average,1,classnum)).*repmat(sqrt(B),n,1);

% Reduce the dimension with PCA so that Sw becomes invertible. Start from
% m-classnum dimensions and shrink until Sw has full rank.
mark=m-classnum;
if mark<1
    error('Wrong input data: the number of samples must exceed the number of classes.');
end
[Wt,Sb,Sw]=ldaScatter(n,m,sampleset,A1,centered,mark);
while rank(Sw)<mark
    mark=mark-1;
    if mark<1
        error('Within-class scatter matrix is singular for every PCA dimension.');
    end
    [Wt,Sb,Sw]=ldaScatter(n,m,sampleset,A1,centered,mark);
end

% Eigen-decomposition of inv(Sw)*Sb, computed with a linear solve rather
% than an explicit inverse, then ordered from largest to smallest
% eigenvalue (by real part, which is what relational operators compare).
[v,d]=eig(Sw\Sb);
lambda=diag(d);
[sortedvals,order]=sort(real(lambda),'descend'); %#ok<ASGLU>
lambda=lambda(order);
v=v(:,order);

% Make sure no (numerically) zero eigenvalue is selected: at most
% classnum-1 eigenvalues can be non-zero and only mark eigenpairs exist.
eignum=min([eignum,classnum-1,mark]);
while eignum>0 && abs(lambda(eignum))<1e-8
    eignum=eignum-1;
end
if eignum<1
    error('No discriminant vector with a non-zero eigenvalue is available.');
end

% Final transformation: selected eigenvectors (as rows) applied after PCA.
LWt1=v(:,1:eignum)'*Wt;

% Return the matrix if an output was requested, otherwise write it to file.
if nargout==1
    LWt=LWt1;
else
    fid=fopen('ldawt','w');
    if fid==-1
        error('Cannot open file ''ldawt'' for writing.');
    end
    headcount=fwrite(fid,eignum,'short');
    count=fwrite(fid,LWt1,'float');
    % Close before reporting so the handle is not leaked on failure.
    fclose(fid);
    if headcount~=1 || count~=eignum*n
        error('file write error');
    end
end

function [Wt,Sb,Sw]=ldaScatter(n,m,sampleset,A1,centered,mark)
% LDASCATTER  Build the PCA projection for `mark` dimensions and the
% between-class (Sb) and within-class (Sw) scatter matrices, both divided
% by m, in the projected space.
Wt=PCADesign(n,m,sampleset,mark);
A2=Wt*A1;
Sb=(A2*A2')/m;
C=Wt*centered;
Sw=(C*C')/m;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -