📄 mixedclassify.m
字号:
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% function mixedClassify()
% z.li, 07-14-2004
% mixed classification with multiple modes of operation
% function dependency:
% - n/a
% input:
% sample - to be classified m x d
% training - training data: n x d
% group - training data label: n x 1
% type - discriminant type: 'linear' (default), 'quadratic', or 'mahalanobis'
% prior - bias for each group
% output:
% class - label for each sample
% err - err rate of the model, inf if covar singular
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%function [class, err] = mixedClassify(sample, training, group, type, prior)
function [class, err] = mixedClassify(sample, training, group, type, prior)
% MIXEDCLASSIFY  Discriminant-analysis classification of SAMPLE rows.
%
%   [class, err] = mixedClassify(sample, training, group, type, prior)
%
% Inputs:
%   sample   - m x d matrix; each row is an observation to classify.
%   training - n x d matrix of training observations.
%   group    - grouping variable with one label per row of TRAINING
%              (converted with grp2idx). Rows whose label maps to NaN are
%              dropped from the training set.
%   type     - 'linear' (default), 'quadratic' or 'mahalanobis'. Matched
%              case-insensitively by leading characters.
%   prior    - one of:
%                omitted / []  uniform prior over the groups (default)
%                'empirical'   relative group sizes in TRAINING
%                numeric vector with one non-negative entry per group
%                struct with fields .group (labels) and .prob (values)
%              Numeric and struct priors are normalized to sum to one.
%
% Outputs:
%   class - predicted label for each row of SAMPLE, in the same form as
%           GROUP (numeric, char rows or cell). Empty ([]) when a
%           covariance estimate is found to be singular.
%   err   - prior-weighted apparent error rate: the per-group fraction of
%           TRAINING rows that the fitted model misclassifies, weighted by
%           PRIOR. Inf when a covariance estimate is singular. Only
%           computed when requested as a second output.
%
% Errors are raised for malformed arguments (see the messages below).

if nargin < 3
    error('Requires at least three arguments.');
end

% grp2idx sorts a numeric grouping var ascending, and a string grouping
% var by order of first occurrence
[gindex, groups] = grp2idx(group);

% Check the label count before label positions are used to index TRAINING,
% so that a length mismatch yields this message rather than an index error.
if size(gindex,1) ~= size(training,1)
    error('The length of GROUP must equal the number of rows in TRAINING.');
end

% Drop training rows whose group label is missing
nans = find(isnan(gindex));
if ~isempty(nans)
    training(nans,:) = [];
    gindex(nans) = [];
end

ngroups = length(groups);
gsize = hist(gindex, 1:ngroups);      % number of training rows per group
[n, d] = size(training);
if size(sample,2) ~= d
    error('SAMPLE and TRAINING must have the same number of columns.');
end
m = size(sample,1);

% Resolve TYPE to an integer code: 1 linear, 2 quadratic, 3 mahalanobis
if nargin < 4 || isempty(type)
    type = 1; % 'linear'
elseif ischar(type)
    i = strmatch(lower(type), strvcat('linear','quadratic','mahalanobis'));
    if length(i) > 1
        error(sprintf('Ambiguous value for TYPE: %s.', type));
    elseif isempty(i)
        error(sprintf('Unknown value for TYPE: %s.', type));
    end
    type = i;
else
    error('TYPE must be a string.');
end

% Resolve PRIOR to a normalized 1 x ngroups row vector
if nargin < 5 || isempty(prior)
    % Default to a uniform prior
    prior = ones(1, ngroups) / ngroups;
elseif ischar(prior) && ~isempty(strmatch(lower(prior), 'empirical'))
    % Estimate prior from relative group sizes
    prior = gsize(:)' / sum(gsize);
elseif isnumeric(prior)
    % Explicit prior
    if min(size(prior)) ~= 1 || max(size(prior)) ~= ngroups
        error('PRIOR must be a vector one element for each group.');
    elseif any(prior < 0)
        error('PRIOR cannot contain negative values.');
    end
    prior = prior(:)' / sum(prior); % force a normalized row vector
elseif isstruct(prior)
    % Prior given per label: reorder PRIOR.prob to follow the order of GROUPS
    [pgindex, pgroups] = grp2idx(prior.group);
    ord = repmat(NaN, 1, ngroups);
    for i = 1:ngroups
        j = strmatch(groups(i), pgroups(pgindex), 'exact');
        if ~isempty(j)
            ord(i) = j;
        end
    end
    if any(isnan(ord))
        error('PRIOR.group must contain all of the unique values in GROUP.');
    end
    prior = prior.prob(ord);
    if any(prior < 0)
        error('PRIOR.prob cannot contain negative values.');
    end
    prior = prior(:)' / sum(prior); % force a normalized row vector
else
    error('PRIOR must be a a vector, a structure, or the string ''empirical''.');
end

% Add training data to sample for error rate estimation
if nargout > 1
    sample = [sample; training];
    mm = m + n;
else
    mm = m;
end

% Per-group means of the training data
gmeans = repmat(NaN, ngroups, d);
for k = 1:ngroups
    gmeans(k,:) = mean(training(find(gindex == k),:), 1);
end

% D(i,k) is the score of group k for row i of SAMPLE; larger is better
D = repmat(NaN, mm, ngroups);
switch type
    case 1 % 'linear'
        if n <= ngroups
            error('TRAINING must have more observations than the number of groups.');
        end
        % Pooled estimate of covariance
        [Q,R] = qr(training - gmeans(gindex,:), 0);
        R = R / sqrt(n - ngroups); % SigmaHat = R'*R
        s = svd(R);
        if any(s <= eps^(3/4)*max(s))
            % Pooled covariance is numerically singular: report failure
            % through the outputs instead of raising an error.
            err = inf;
            class = [];
            return;
        end
        % MVN relative log posterior density, by group, for each sample
        for k = 1:ngroups
            A = (sample - repmat(gmeans(k,:), mm, 1)) / R;
            D(:,k) = log(prior(k)) - .5*sum(A .* A, 2);
        end

    case {2,3} % 'quadratic' or 'mahalanobis'
        if any(gsize <= 1)
            error('Each group in TRAINING must have at least two observations.');
        end
        for k = 1:ngroups
            % Stratified estimate of covariance
            [Q,R] = qr(training(find(gindex == k),:) - repmat(gmeans(k,:), gsize(k), 1), 0);
            R = R / sqrt(gsize(k) - 1); % SigmaHat = R'*R
            s = svd(R);
            if any(s <= eps^(3/4)*max(s))
                % Group covariance is numerically singular: report failure
                % through the outputs instead of raising an error.
                err = inf;
                class = [];
                return;
            end
            A = (sample - repmat(gmeans(k,:), mm, 1)) / R;
            switch type
                case 2 % 'quadratic'
                    % log(det(SigmaHat)) = 2*sum(log|diag(R)|). Summing the
                    % logs avoids forming prod(diag(R)), which overflows to
                    % Inf or underflows to 0 for large, small or
                    % high-dimensional data and would make every score
                    % +/-Inf.
                    logDetSigma = 2*sum(log(abs(diag(R))));
                    % MVN relative log posterior density, by group, for each sample
                    D(:,k) = log(prior(k)) - .5*(sum(A .* A, 2) + logDetSigma);
                case 3 % 'mahalanobis'
                    % Negative squared Mahalanobis distance, by group, for each
                    % sample. Prior probabilities are not used
                    D(:,k) = -sum(A .* A, 2);
            end
        end
end

% find nearest group to each observation in sample data
[tmp, class] = max(D, [], 2);

% Compute apparent error rate: percentage of training data that
% are misclassified.
if nargout > 1
    trclass = class(m+(1:n));   % predictions for the appended training rows
    class = class(1:m);         % predictions for the caller's samples
    miss = trclass ~= gindex;
    e = repmat(NaN, ngroups, 1);
    for k = 1:ngroups
        e(k) = sum(miss(find(gindex == k))) / gsize(k);
    end
    err = prior*e;
end

% Convert back to original grouping variable
if isnumeric(group)
    groups = str2num(char(groups));
    class = groups(class);
elseif ischar(group)
    groups = char(groups);
    class = groups(class,:);
else %if iscellstr(group)
    class = groups(class);
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -