
📄 lda.m

📁 Implementation of linear, quadratic and logistic discriminant analysis, with examples
💻 MATLAB
function [f, c, post] = lda(X, k, prior, est, nu)
%LDA Linear Discriminant Analysis.
%   F = LDA(X, K, PRIOR) returns a linear discriminant analysis object
%   F based on the feature matrix X, class indices in K and the prior
%   probabilities in PRIOR where PRIOR is optional. See the help for
%   LDA's parent object CLASSIFIER for information on the input
%   arguments X, K and PRIOR.
%
%   In addition to the fields defined by the CLASSIFIER class, F
%   contains the following fields:
%
%   MEANS: a g by p matrix where g is the number of classes and p is
%   the number of features or variates. Each row gives the mean vector
%   for each class.
%
%   SCALE: the p by p matrix which transforms the observed
%   within-groups covariance to identity. This is equivalent to
%   INV(CHOL(COV(X - F.MEANS(K, :), 1))) for maximum-likelihood
%   estimates (see below). For unbiased estimates,
%   INV(CHOL(COV(F.MEANS(K, :)))) will differ from F.SCALE by a factor
%   of SQRT((n-1)/(n-g)) because the two normalize on a different
%   number of estimated means.
%
%   EST: either 0, 1, or 't' representing unbiased, maximum likelihood
%   or t-parameter estimation respectively as explained below.
%
%   NU: This field is only present if EST is 't'. NU gives the degrees
%   of freedom for the t-parameter estimation as explained in the next
%   paragraph.
%
%   LDA(X, K, PRIOR, EST, NU) where EST is one of 'unbiased', 'ml', or
%   't', uses either bias-corrected, maximum likelihood or t-parameter
%   estimation respectively. For t-parameter estimation, an additional
%   argument, NU, gives the degrees of freedom for the estimator (the
%   default is 5 if not given). The default estimator is unbiased
%   estimation (which corresponds to the default for the MATLAB
%   functions STD and COV). Unbiased estimation bias-corrects the
%   estimate of the within-groups covariance matrix by a factor of
%   1/(n-g). For maximum likelihood estimation, no correction is
%   made. For t-parameter estimation, the means and scale matrix are
%   estimated by an iterative weighted algorithm. When specifying EST,
%   only the first few disambiguating letters need be given: i.e.,
%   'u', 'm' or 't'.
%
%   LDA(X, K, EST) is equivalent to LDA(X, K, [], EST).
%
%   LDA(X, K, OPTS) allows optional arguments to be passed in the
%   fields of the structure OPTS. Fields that are used by LDA are
%   PRIOR, EST and NU.
%
%   [F, C, POST] = LDA(X, K, ...) additionally performs leave-one-out
%   cross-validation on the data in X. C is a length n index vector of
%   estimated class memberships similar to K corresponding to the
%   matrix of features X. POST is an n by g matrix of posterior
%   probabilities. Leave-one-out cross-validation is only defined for
%   methods 'ml' and 'unbiased'. C and POST will not necessarily
%   correspond to the output of CROSSVAL(X, K, 'lda', ...) because in
%   the latter, the prior probabilities are not fixed between
%   cross-validation estimates unless this is done explicitly in
%   the option struct passed to CROSSVAL.
%
%   F = LDA(G) where G is an object of class QDA returns an LDA object
%   based on G.
%
%   See also CLASSIFIER, QDA, LOGDA, SOFTMAX, COV, CROSSVAL.
%
%   Example:
%   %generate artificial data with 4 classes and 3 variates
%   r = randn(3);
%   C = r'*r;     % random positive definite symmetric matrix
%   M = randn(4, 3)*2;    % random means
%   k = ceil(rand(400, 1)*4);    % random class indices
%   X = randn(400, 3)*chol(C) + M(k, :);
%   f = lda(X, k); disp(f)
%   [lambda ratio] = cvar(f)   % canonical variates
%   cov(f), plotcov(f)
%   plotcov(shrink(f, .5))
%   [c post] = classify(f, X);
%   confmat(k, c)
%   confmat(k, post)
%
%   References:
%   B. D. Ripley (1996) Pattern Recognition and Neural
%   Networks. Cambridge.

%   Copyright (c) 1999 Michael Kiefte.
%
%   $Log$

if isa(X, 'qda')
  error(nargchk(1, 1, nargin))
  g = shrink(X, 1);
  if g.est == 't'
    warning(['Not an exact conversion for t-estimator QDA' ...
             ' objects.'])
    nu = g.nu;
  else
    nu = [];
  end
  f = class(struct('means', g.means, 'scale', g.scale(:,:,1), ...
                   'est', g.est, 'nu', nu), 'lda', g.classifier);
  return
end

error(nargchk(2, 5, nargin))

if nargin > 2 & isstruct(prior)
  if nargin > 3
    error(sprintf(['Cannot have arguments following option struct:\n' ...
                   '%s'], nargchk(3, 3, 4)))
  end
  [prior est nu] = parseopt(prior, 'prior', 'est', 'nu');
elseif nargin < 5
  nu = [];
  if nargin < 4
    est = [];
    if nargin < 3
      prior = [];
    end
  end
end

if ischar(prior)
  % LDA(X, K, EST) form: shift the string argument into EST.
  nu = est;
  est = prior;
  prior = [];
end

[h G] = classifier(X, k, prior);
[n p] = size(X);
nj = h.counts;
g = length(nj);
prior = h.prior;

if nargout > 1
  cv = 1;
else
  cv = 0;
end

if isempty(est)
  est = 0;
elseif ~ischar(est) | length(est) ~= size(est, 2) | ...
      size(est, 1) ~= 1
  error('EST must be a string.')
else
  t = find(strncmp(est, {'unbiased', 'ml', 't'}, length(est)));
  if isempty(t)
    error('EST must be one of ''unbiased'', ''ml'', or ''t''.')
  end
  switch t
   case 1
    est = 0;
   case 2
    est = 1;
   otherwise
    est = 't';
  end
end

if est == 't'
  if isempty(nu)
    nu = 5;
  elseif ~isa(nu, 'double') | length(nu) ~= 1 | round(nu) ~= nu | ...
        nu < 3 | isinf(nu)
    error(['Degrees of freedom NU must be a finite, integer scalar' ...
           ' greater than 2.'])
  elseif cv
    error('Cannot perform cross-validation with t-estimator.')
  end
elseif ~isempty(nu)
  error('May specify degrees of freedom NU only with t-estimator.')
end

% Class means and within-groups centred data.
M = sparse(1:g, 1:g, 1./nj)*G'*X;
Xc = X - M(k, :);
S = std(Xc);
if any(S < n*max(S)*eps)
  error(sprintf(['Column %d in feature matrix X is constant within' ...
                 ' groups.'], min(find(S < n*max(S)*eps))))
end
S = diag(S);

switch est
 case {0, 1}
  [u s v] = svd(Xc*S/sqrt(n - g*(1-est)), 0);
  r = sum(diag(s) > n*s(1)*eps);
  if r < p
    warning(sprintf(['Nullity of within-groups covariance matrix is' ...
                     ' %d.'], p - r))
    v = v(:,1:r);
    s = s(1:r,1:r);
  end
  % Use the upper-triangular R factor, as in the t-branch below.
  [Qbb, Rbb] = qr(s*v');
  S = S*inv(Rbb);
  if cv
    % Leave-one-out cross-validation by downdating the fitted model.
    Xs = X*S;
    Ms = M*S;
    XM = Xs - Ms(k, :);
    nc = nj(k)';
    K = g*(1-est);
    c = (n - K - 1)/(n - K);
    D = repmat(sum(Xs.^2, 2), 1, g) - 2*Xs*Ms' + repmat(sum(Ms'.^2), n, 1);
    Dc = D((k-1)*n+(1:n)');
    cc = (n - K)*(nc - 1)./nc;
    D = c * (D + (repmat(sum(Xs .* XM, 2), 1, g) - XM*Ms').^2 ./ ...
             repmat(cc - Dc, 1, g));
    D((k-1)*n+(1:n)') = Dc * c .* (nc./(nc - 1)).^2 ./ (1 - Dc./cc);
    D = D/2 - repmat(log(prior), n, 1);
    [y c] = min(D, [], 2);
    if nargout > 2
      D = exp(y(:, ones(1, g)) - D);
      post = D./repmat(sum(D, 2), 1, g);
    end
  end
 otherwise
  % Iteratively reweighted estimation for the t-parameter model.
  w = ones(n,1);
  c = (nu+p)/(n*nu);
  sing = 0;
  while 1
    wold = w;
    [u s v] = svd(repmat(sqrt(w*c), 1, p).*Xc*S, 0);
    r = sum(diag(s) > n*s(1)*eps);
    if r < p
      if ~sing
        warning(sprintf(['Nullity of within-groups covariance matrix is' ...
                         ' %d.'], p - r))
        sing = 1;
      end
      v = v(:, 1:r);
      s = s(1:r, 1:r);
    end
    w = 1./(1+(Xc*S*v/s).^2*repmat(1/nu, p, 1));
    M = G'*(w(:,ones(1,p)).*X)./repmat(G'*w, 1, p);
    if max(abs(w - wold)) < max(w)*n*eps
      break
    end
    Xc = X - M(k, :);
  end
  S = S*inv(triu(qr(s*v')));
end

f = class(struct('means', M, 'scale', S, 'est', est, 'nu', nu), ...
          'lda', h);
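Not part of the original file: the help text above defines the MEANS field as the per-class mean vectors and the SCALE field as the matrix that transforms the within-groups covariance to identity, i.e. INV(CHOL(COV(X - F.MEANS(K, :)))) for the unbiased estimator. A minimal NumPy sketch of just that computation (the function name `lda_fit` is hypothetical; MATLAB's CHOL returns the upper-triangular factor, so the lower-triangular factor from `np.linalg.cholesky` is transposed first):

```python
import numpy as np

def lda_fit(X, k):
    """Sketch of the MEANS/SCALE fields (unbiased within-groups estimate).

    X : (n, p) feature matrix; k : length-n vector of class labels.
    Returns the (g, p) class-mean matrix and the (p, p) scale matrix
    satisfying scale.T @ W @ scale == I for the within-groups covariance W.
    """
    classes = np.unique(k)
    n, p = X.shape
    g = len(classes)
    # Row i of `means` is the mean vector of class classes[i].
    means = np.vstack([X[k == c].mean(axis=0) for c in classes])
    idx = np.searchsorted(classes, k)
    Xc = X - means[idx]              # centre each row on its class mean
    W = Xc.T @ Xc / (n - g)          # unbiased within-groups covariance
    # MATLAB: SCALE = INV(CHOL(W)); CHOL gives upper-triangular R, R'R = W.
    scale = np.linalg.inv(np.linalg.cholesky(W).T)
    return means, scale
```

Multiplying the within-groups centred data by `scale` then yields identity covariance, which is what makes Euclidean distances in the transformed space equivalent to Mahalanobis distances for classification.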
