
mixedclassify.m
MATLAB (M) source file, for Face Recognition
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% function mixedClassify()
%   z.li, 07-14-2004
%   mixed classification with multiple modes of operation (linear, quadratic, Mahalanobis)
% function dependency:
%   - n/a
% input:
%   sample          - data to be classified: m x d
%   training        - training data: n x d
%   group           - training data labels: n x 1
%   type            - discriminant type: 'linear' (default), 'quadratic', or 'mahalanobis'
%   prior           - prior probability for each group: vector, struct with
%                     fields .group/.prob, or 'empirical'; default is uniform
% output:
%   class           - label for each sample
%   err             - apparent error rate of the model; Inf if a covariance estimate is singular
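% decision rule (as implemented below), per sample x and group k:
%   linear:      argmax_k  log(prior_k) - 0.5*(x-mu_k)*inv(Sigma)*(x-mu_k)'      (pooled Sigma)
%   quadratic:   argmax_k  log(prior_k) - 0.5*(x-mu_k)*inv(Sigma_k)*(x-mu_k)' - 0.5*log(det(Sigma_k))
%   mahalanobis: argmin_k  (x-mu_k)*inv(Sigma_k)*(x-mu_k)'                       (prior not used)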
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%function [class, err] = mixedClassify(sample, training, group, type, prior)
function [class, err] = mixedClassify(sample, training, group, type, prior)

if nargin < 3
    error('Requires at least three arguments.');
end

% grp2idx sorts a numeric grouping var ascending, and a string grouping
% var by order of first occurrence
[gindex,groups] = grp2idx(group);
nans = find(isnan(gindex));
if ~isempty(nans)
    training(nans,:) = [];
    gindex(nans) = [];
end
ngroups = length(groups);
gsize = hist(gindex,1:ngroups);

[n,d] = size(training);
if size(gindex,1) ~= n
    error('The length of GROUP must equal the number of rows in TRAINING.');
elseif size(sample,2) ~= d
    error('SAMPLE and TRAINING must have the same number of columns.');
end
m = size(sample,1);

if nargin < 4 || isempty(type)
    type = 1; % 'linear'
elseif ischar(type)
    i = strmatch(lower(type), strvcat('linear','quadratic','mahalanobis'));
    if length(i) > 1
        error('Ambiguous value for TYPE: %s.', type);
    elseif isempty(i)
        error('Unknown value for TYPE: %s.', type);
    end
    type = i;
else
    error('TYPE must be a string.');
end

% Default to a uniform prior
if nargin < 5 || isempty(prior)
    prior = ones(1, ngroups) / ngroups;
% Estimate prior from relative group sizes
elseif ischar(prior) && ~isempty(strmatch(lower(prior), 'empirical'))
    prior = gsize(:)' / sum(gsize);
% Explicit prior
elseif isnumeric(prior)
    if min(size(prior)) ~= 1 || max(size(prior)) ~= ngroups
        error('PRIOR must be a vector with one element for each group.');
    elseif any(prior < 0)
        error('PRIOR cannot contain negative values.');
    end
    prior = prior(:)' / sum(prior); % force a normalized row vector
elseif isstruct(prior)
    [pgindex,pgroups] = grp2idx(prior.group);
    ord = repmat(NaN,1,ngroups);
    for i = 1:ngroups
        j = strmatch(groups(i), pgroups(pgindex), 'exact');
        if ~isempty(j)
            ord(i) = j;
        end
    end
    if any(isnan(ord))
        error('PRIOR.group must contain all of the unique values in GROUP.');
    end
    prior = prior.prob(ord);
    if any(prior < 0)
        error('PRIOR.prob cannot contain negative values.');
    end
    prior = prior(:)' / sum(prior); % force a normalized row vector
else
    error('PRIOR must be a vector, a structure, or the string ''empirical''.');
end

% Add training data to sample for error rate estimation
if nargout > 1
    sample = [sample; training];
    mm = m+n;
else
    mm = m;
end

gmeans = repmat(NaN, ngroups, d);
for k = 1:ngroups
    gmeans(k,:) = mean(training(find(gindex == k),:),1);
end

D = repmat(NaN, mm, ngroups);
switch type
case 1 % 'linear'
    if n <= ngroups
        error('TRAINING must have more observations than the number of groups.');
    end
    % Pooled estimate of covariance
    [Q,R] = qr(training - gmeans(gindex,:), 0);
    R = R / sqrt(n - ngroups); % SigmaHat = R'*R
    s = svd(R);
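    % Singular values of R are the square roots of SigmaHat's eigenvalues,
    % so a near-zero value indicates a (near-)singular pooled covariance.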
    if any(s <= eps^(3/4)*max(s))
        %fprintf('\n mixedClassify: pooled covariance not positive definite !');
        err=inf;
        class = [];
        return;
    end

    % MVN relative log posterior density, by group, for each sample
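    % A = (x - mu_k)/R gives A*R = (x - mu_k), so sum(A.*A,2) is the squared
    % Mahalanobis distance (x-mu_k)*inv(SigmaHat)*(x-mu_k)' under the pooled SigmaHat.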
    for k = 1:ngroups
        A = (sample - repmat(gmeans(k,:), mm, 1)) / R;
        D(:,k) = log(prior(k)) - .5*sum(A .* A, 2);
    end

case {2,3} % 'quadratic' or 'mahalanobis'
    if any(gsize <= 1)
        error('Each group in TRAINING must have at least two observations.');
    end
    for k = 1:ngroups
        % Stratified estimate of covariance
        [Q,R] = qr(training(find(gindex == k),:) - repmat(gmeans(k,:), gsize(k), 1), 0);
        R = R / sqrt(gsize(k) - 1); % SigmaHat = R'*R
        s = svd(R);
        if any(s <= eps^(3/4)*max(s))
            %fprintf('\n mixedClassify: group covariance not positive definite.');
            err = inf; class=[];
            return;
        end

        A = (sample - repmat(gmeans(k,:), mm, 1)) / R;
        switch type
        case 2 % 'quadratic'
            % MVN relative log posterior density, by group, for each sample
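            % log(prod(diag(R))^2) equals log(det(SigmaHat_k)), since
            % SigmaHat_k = R'*R and R is triangular.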
            D(:,k) = log(prior(k)) - .5*(sum(A .* A, 2) + log(prod(diag(R))^2));
        case 3 % 'mahalanobis'
            % Negative squared Mahalanobis distance, by group, for each
            % sample. Prior probabilities are not used.
            D(:,k) = -sum(A .* A, 2);
        end
    end
end

% find nearest group to each observation in sample data
[tmp, class] = max(D, [], 2);

% Compute apparent error rate: percentage of training data that
% are misclassified.
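% The overall rate err is the prior-weighted average of the per-group
% misclassification rates.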
if nargout > 1
    trclass = class(m+(1:n));
    class = class(1:m);
    
    miss = trclass ~= gindex;
    e = repmat(NaN,ngroups,1);
    for k = 1:ngroups
        e(k) = sum(miss(find(gindex == k))) / gsize(k);
    end
    err = prior*e;
end

% Convert back to original grouping variable
if isnumeric(group)
    groups = str2num(char(groups));
    class = groups(class);
elseif ischar(group)
    groups = char(groups);
    class = groups(class,:);
else %if iscellstr(group)
    class = groups(class);
end
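
A minimal usage sketch follows; the synthetic data and variable names are illustrative only and are not part of the original file. It classifies two 2-D Gaussian classes with each of the three modes.

% --- usage example (illustrative; not part of mixedclassify.m) ---
% two synthetic 2-D Gaussian classes
n1 = 50; n2 = 50;
training = [randn(n1,2) + 2; randn(n2,2) - 2];
group    = [ones(n1,1); 2*ones(n2,1)];
sample   = [randn(10,2) + 2; randn(10,2) - 2];

% linear discriminant with the default uniform prior
[class_lin, err_lin] = mixedClassify(sample, training, group, 'linear');

% quadratic discriminant with an explicit prior
[class_quad, err_quad] = mixedClassify(sample, training, group, 'quadratic', [0.5 0.5]);

% Mahalanobis mode; the prior is ignored
class_mah = mixedClassify(sample, training, group, 'mahalanobis');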
