📄 fda2.m
字号:
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% function fda2()
% z.li, 04-12-2004
% fisher discriminant analysis
% function dependency:
% - n/a
% input:
% x - data: d x n
% y - label: 1 x n
% output:
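% A - the transform: d x min(nC-1, d), one discriminant direction per
%     column (nC = number of distinct labels), strongest direction first
% ev - generalized eigenvalues of (Sb, Sw), sorted in ascending order
% Sb, Sw - inter-class and intra-class scatter matrices: d x d
% lbl,nj - classes: distinct labels and number of samples per label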
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%function [A, ev, Sb, Sw, lbl, nj]=fda2(x, y)
function [A, ev, Sb, Sw, lbl, nj]=fda2(x, y)
% FDA2  Fisher discriminant analysis.
%   [A, ev, Sb, Sw, lbl, nj] = fda2(x, y)
%
%   input:
%     x   - data, d x n (one sample per column)
%     y   - labels, n elements (one label per column of x)
%   output:
%     A   - d x min(nC-1, d) matrix whose columns are the generalized
%           eigenvectors of (Sb, Sw), largest eigenvalue first
%     ev  - all d generalized eigenvalues, sorted in ascending order
%     Sb  - inter-class scatter, d x d
%     Sw  - intra-class scatter, d x d
%     lbl - 1 x nC distinct labels, sorted
%     nj  - 1 x nC number of samples carrying each label
%
%   Classes with minNj (= 2) or fewer samples are excluded from Sb and Sw,
%   but their samples still contribute to the global mean, and they are
%   still reported in lbl/nj and counted in nC.
%
%   As a side effect the function draws a 2-D debug plot (first two data
%   dimensions) whenever the data has at least two dimensions.

% Interactive demo: flip dbg to 'y' to click sample points with the mouse
% instead of using the x/y passed in.
dbg = 'n';
if strcmp(dbg, 'y')
    K = 3; m = [16, 12, 18];
    styl = ['+', '.', 'x'];
    axis([0 100 0 100]); hold on;
    x = zeros(2, sum(m));
    y = zeros(1, sum(m));
    cnt = 0;
    for j = 1:K
        fprintf('\n new class: ');
        for t = 1:m(j)
            [px, py] = ginput(1);
            plot(px, py, styl(j));
            fprintf('%c', styl(j));
            cnt = cnt+1;
            x(:,cnt) = [px; py];
            y(cnt) = j;
        end
    end
end

% const
dbgPlot = 'y';
styl = ['.', '+', 'o', 'x', '.', '+', 'o', 'x'];
minNj = 2;          % classes with this many samples or fewer are skipped

% validate input
[dim, N] = size(x); % sample dimension and total number
if dim == 0 || N == 0
    error('fda2:emptyInput', 'x must be a non-empty d x n matrix.');
end
if numel(y) ~= N
    error('fda2:sizeMismatch', ...
        'y must hold one label per column of x (%d expected, %d given).', ...
        N, numel(y));
end
y = reshape(y, 1, N);

% sample mean for all; mean(x, 2) stays a d x 1 column even when N == 1
M = mean(x, 2);

% process labels: nC, lbl, nj
[lbl, firstIdx, cls] = unique(y);   %#ok<ASGLU> cls(i) = class index of sample i
cls = reshape(cls, 1, N);
nC = numel(lbl);
nj = accumarray(cls(:), 1)';
used = nj > minNj;                  % classes that enter the scatter matrices

if ~any(used)
    warning('fda2:noUsableClass', ...
        'No class has more than %d samples; Sb and Sw are zero.', minNj);
end

% compute Sw - intra class scatter and Sb - inter class scatter: dxd
mj = zeros(dim, nC);                % class means (zero for skipped classes)
Sw = zeros(dim, dim);
Sb = zeros(dim, dim);
for j = find(used)
    xj = x(:, cls == j);
    mj(:,j) = mean(xj, 2);
    dx = xj - repmat(mj(:,j), 1, nj(j));
    Sw = Sw + dx*dx';
    % use the full mean column mj(:,j); mj(j) would be a single scalar
    db = mj(:,j) - M;
    Sb = Sb + nj(j)*(db*db');
end

% find fisher discriminant subspace
[V, D] = eig(Sb, Sw);
[ev, eidx] = sort(diag(D));
nDim = min(nC-1, dim);
% columns ordered by decreasing eigenvalue; yields a d x 0 matrix when
% there is a single class, so A is always assigned
A = V(:, eidx(dim:-1:dim-nDim+1));

% dbg plot (needs at least two data dimensions)
if strcmp(dbgPlot, 'y') && dim >= 2
    hold on;
    for j = find(used)
        xj = x(:, cls == j);
        % cycle through the marker list so any number of classes works
        plot(xj(1,:), xj(2,:), styl(mod(j-1, numel(styl)) + 1));
    end
    plot(mj(1,used), mj(2,used), 'or');
    plot(M(1), M(2), 'or'); plot(M(1), M(2), '+r');
    dplot = 800;
    lineStyl = {'-r', ':m'};
    % draw up to two discriminant directions through the global mean
    for k = 1:min(nDim, 2)
        px = [M(1)-dplot*A(1,k), M(1)+dplot*A(1,k)];
        py = [M(2)-dplot*A(2,k), M(2)+dplot*A(2,k)];
        plot(px, py, '.', px, py, lineStyl{k});
    end
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -