📄 lda.m
字号:
function [f, c, post] = lda(X, k, prior, est, nu)%LDA Linear Discriminant Analysis.% F = LDA(X, K, PRIOR) returns a linear discriminant analysis object% F based on the feature matrix X, class indeces in K and the prior% probabilities in PRIOR where PRIOR is optional. See the help for% LDA's parent object CLASSIFIER for information on the input% arguments X, K and PRIOR.%% In addition to the fields defined by the CLASSIFIER class, F% contains the following fields:%% MEANS: a g by p matrix where g is the number of classes and p is% the number of features or variates. Each row gives the mean vector% for each class. %% SCALE: the p by p matrix which transforms the observed% within-groups covariance to identity. This is equivalent to% INV(CHOL(COV(X - F.MEANS(K, :), 1))) for maximum-likelihood% estimates (see below). For unbiased estimates,% INV(CHOL(COV(F.MEANS(K, :)))) will differ from F.SCALE by a factor% of SQRT((n-1)/(n-g)) because the two normalize on a different% number of estimated means.%% EST: either 0, 1, or 't' representing unbiased, maximum likelihood% or t-parameter estimation respectively as explained below.%% NU: This field is only present if EST is 't'. NU gives the degrees% of freedom for the t-parameter estimation as explained in the next% paragraph.%% LDA(X, K, PRIOR, EST, NU) where EST is one of 'unbiased', 'ml', or% 't', uses either bias-corrected, maximum likelihood or t-parameter% estimation respectively. For t-parameter estimation, an additional% argument, NU, gives the degrees of freedom for the estimator (the% default is 5 if not given). The default estimator is unbiased% estimation (which corresponds to the default for the MATLAB% functions STD and COV). Unbiased estimation bias corrects the% estimate of the within-groups covariance matrix by a factor of% 1/(n-g). For maximum likelihood estimation, no correction is% made. For t-parameter estimation, the means and scale matrix are% estimated by an iterative weighted algorithm. When specifying EST,% only the first few disambiguating letters need be given: i.e.,% 'u', 'm' or 't'.%% LDA(X, K, EST) is equivalent to LDA(X, K, [], EST).%% LDA(X, K, OPTS) allows optional arguments to be passed in the% fields of the structure OPTS. Fields that are used by LDA are% PRIOR, EST and NU.%% [F, C, POST] = LDA(X, K, ...) additionally performs leave-one-out% cross-validation on the data in X. C is a length n index vector of% estimated class memberships similar to K corresponding to the% matrix of features X. POST is an n by g matrix of posterior% probabilities. Leave-one-out cross-validation is only defined for% methods 'ml' and 'unbiased'. C and POST will not necessarily% correspond to the output of CROSSVAL(X, K, 'lda', ...) because in% the latter, the prior probabilities are not fixed between% cross-validation estimates unless this is done so explicitly in% the option struct passed to CROSSVAL.%% F = LDA(G) where G is an object of class QDA returns an LDA object% based on G.%% See also CLASSIFIER, QDA, LOGDA, SOFTMAX, COV, CROSSVAL.%% Example:% %generate artificial data with 4 classes and 3 variates% r = randn(3);% C = r'*r; % random positive definite symmetric matrix% M = randn(4, 3)*2; % random means% k = ceil(rand(400, 1)*4); % random class indeces% X = randn(400, 3)*chol(C) + M(k, :); % f = lda(X, k); disp(f)% [lambda ratio] = cvar(f) % canonical variates% cov(f), plotcov(f)% plotcov(shrink(f, .5))% [c post] = classify(f, X);% confmat(k, c)% confmat(k, post)%% References: % B. D. Ripley (1996) Pattern Classification and Neural% Networks. Cambridge.% Copyright (c) 1999 Michael Kiefte. % $Log$if isa(X, 'qda') error(nargchk(1, 1, nargin)) g = shrink(X, 1); if g.est == 't' warning(['Not an exact conversion for t-estimator QDA' ... ' objects.']) nu = g.nu; else nu = []; end f = class(struct('means', g.means, 'scale', g.scale(:,:,1), ... 'est', g.est, 'nu', nu), 'lda', g.classifier); returnenderror(nargchk(2, 5, nargin))if nargin > 2 & isstruct(prior) if nargin > 3 error(sprintf(['Cannot have arguments following option struct:\n' ... '%s'], nargchk(3, 3, 4))) end [prior est nu] = parseopt(prior, 'prior', 'est', 'nu');elseif nargin < 5 nu = []; if nargin < 4 est = []; if nargin < 3 prior = []; end endendif ischar(prior) nu = est; est = prior; prior = [];end[h G] = classifier(X, k, prior);[n p] = size(X);nj = h.counts;g = length(nj);prior = h.prior;if nargout > 1 cv = 1;else cv = 0;endif isempty(est) est = 0;elseif ~ischar(est) | length(est) ~= size(est, 2) | ... size(est, 1) ~= 1 error('EST must be a string.')else t = find(strncmp(est, {'unbiased', 'ml', 't'}, length(est))); if isempty(t) error('EST must be one of ''unbiased'', ''ml'', or ''t''.') end switch t case 1 est = 0; case 2 est = 1; otherwise est = 't'; endendif est == 't' if isempty(nu) nu = 5; elseif ~isa(nu, 'double') | length(nu) ~= 1 | round(nu) ~= nu | ... nu < 3 | isinf(nu) error(['Degrees of freedom NU must be a finite, integer scalar' ... ' greater than 2.']) elseif cv error('Cannot perform cross-validation with t-estimator.') endelseif ~isempty(nu) error('May specify degrees of freedom NU only with t-estimator.') endM = sparse(1:g, 1:g, 1./nj)*G'*X;Xc = X - M(k, :);S = std(Xc);if any(S < n*max(S)*eps) error(sprintf(['Column %d in feature matrix X is constant within' ... ' groups.'], min(find(S < n*max(S)*eps)))) endS = diag(S);switch(est) case {0, 1} [u s v] = svd(Xc*S/sqrt(n - g*(1-est)), 0); r = sum(diag(s) > n*s(1)*eps); if (r < p) warning(sprintf(['Nullity of within-groups covariance matrix is' ' %d.'], p - r)) v = v(:,1:r); s = s(1:r,1:r); end aa = s*v'; [Qbb,Rbb] = qr(aa); cc = triu(Qbb); S = S*inv(cc); if cv Xs = X*S; Ms = M*S; XM = Xs - Ms(k, :); nc = nj(k)'; K = g*(1-est); c = (n - K - 1)/(n - K); D = repmat(sum(Xs.^2, 2), 1, g) - 2*Xs*Ms' + repmat(sum(Ms'.^2), n, 1); Dc = D((k-1)*n+(1:n)'); cc = (n - K)*(nc - 1)./nc; D = c * (D + (repmat(sum(Xs .* XM, 2), 1, g) - XM*Ms').^2 ./ ... repmat(cc - Dc, 1, g)); D((k-1)*n+(1:n)') = Dc * c .* (nc./(nc - 1)).^2 ./ (1 - Dc./cc); D = D/2 - repmat(log(prior), n, 1); [y c] = min(D, [], 2); if nargout > 2 D = exp(y(:, ones(1, g)) - D); post = D./repmat(sum(D, 2), 1, g); end end otherwise w = ones(n,1); c = (nu+p)/(n*nu); sing = 0; while 1 wold = w; [u s v] = svd(repmat(sqrt(w*c), 1, p).*Xc*S, 0); r = sum(diag(s) > n*s(1)*eps); if r < p if ~sing warning(sprintf(['Nullity of within-groups covariance matrix is' ... ' %d.'], p - r)) sing = 1; end v = v(:, 1:r); s = s(1:r, 1:r); end w = 1./(1+(Xc*S*v/s).^2*repmat(1/nu, p, 1)); M = G'*(w(:,ones(1,p)).*X)./repmat(G'*w, 1, p); if max(abs(w - wold)) < max(w)*n*eps break end Xc = X - M(k, :); end S = S*inv(triu(qr(s*v')));endf = class(struct('means', M, 'scale', S, 'est', est, 'nu', nu), ... 'lda', h);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -