📄 mixed_gaussian_classify_mfold.m
字号:
function [confmat,acc,accstd,weight, mu, sigma] = Mixed_Gaussian_classify_Mfold(data,labels,M,N)
% function [confmat,acc,accstd,weight,mu,sigma] = Mixed_Gaussian_classify_Mfold(data,labels,M,N)
% Gaussian-mixture classification over the entire data set, evaluated with
% M-fold cross-validation whose folds are balanced by class.
%
% Inputs:
%   data   - n (instances) x d (features) matrix
%   labels - n (instances) x 1 vector with the class labels as integers
%            1,2,...,c
%   M      - integer number of folds (default 2; make sure you have
%            enough data)
%   N      - number of Gaussian components fitted per class (default 1)
% Outputs:
%   confmat - 1 x M cell array; confmat{m} is the confusion matrix
%             returned by test() for fold m
%   acc     - 1 x M vector with the accuracy of every fold
%   accstd  - standard deviation of acc across folds
%   weight, mu, sigma - mixture parameters returned by train() for the
%             last fold
% Author: Song Siming
if (nargin<3)
    M=2; % default is 2-fold cross validation
end
if (nargin<4)
    N=1; % default is a single Gaussian component per class
end
ntot=size(data,1); % Total number of examples
c=max(labels); % Labels must be integers 1..c
acc = zeros(1,M);
confmat = cell(1,M);
testdata = cell(1,c);
traindata = cell(1,c);
% Split dataset into M class-balanced folds
% After this, to pull out a specific label i and fold j, just use
% "labelids=find(labels==i); foldids=labelids(folds(labelids)==j); "
folds=zeros(ntot,1);
for i=1:c
    labelids=find(labels==i); % indices of class i examples
    n=length(labelids); % n=number of class i examples
    if (n<M)
        error('Not enough examples in class %d to support %d-fold crossvalidation.',i,M);
    end
    if (n<3*M)
        warning('Warning: Fewer than 3 examples per fold in class %d.',i);
    end
    permutedids=labelids(randperm(n)); % permuted indices
    startinds=1+floor(n*(0:(M-1))/M); % starting points of folds
    endinds=floor(n*(1:M)/M); % ending points of folds
    for m=1:M
        folds(permutedids(startinds(m):endinds(m)))=m; % assign folds
    end
end
% Run classification with each fold used once as test data
pw = zeros(1,c);
for m=1:M
    for i=1:c % select train/test data
        labelids=find(labels==i);
        heldout=(folds(labelids)==m);
        testdata{i}=data(labelids(heldout),:);
        traindata{i}=data(labelids(~heldout),:);
        % Count rows explicitly: length() would return the number of
        % features whenever a class has fewer training rows than columns.
        pw(i) = size(traindata{i},1);
    end
    pw = pw/sum(pw); % class priors estimated from the training split
    [weight, mu, sigma] = train(traindata,N,c);
    % Score the held-out fold, not the data the model was fitted on.
    [confmat{1,m}, acc(m)] = test(testdata,weight,mu,sigma,c,N,pw);
end
accstd = std(acc);
end
function [confmat, acc] = test(data,weight,mu,sigma,c,N,pw)
% Classify the rows of every class with the trained mixtures and tally
% the outcome.
%
% Inputs:
%   data   - 1 x c cell array; data{ii} holds the rows whose true class
%            is ii
%   weight - 1 x c cell array; weight{kk}(jj) is the weight of component
%            jj of class kk
%   mu     - 1 x c cell array; mu{kk}(jj,:) is the mean of component jj
%            of class kk
%   sigma  - N x c cell array; sigma{jj,kk} is the covariance of
%            component jj of class kk
%   c      - number of classes
%   N      - number of components per class
%   pw     - 1 x c vector of class priors
% Outputs:
%   confmat - c x c confusion matrix; confmat(ii,kk) counts the rows of
%             true class ii that were assigned to class kk
%   acc     - fraction of rows on the diagonal of confmat
confmat = zeros(c,c);
for ii = 1:c
    samples = data{ii};
    scores = zeros(size(samples,1),c);
    for kk = 1:c
        for jj = 1:N
            % prior * component weight * component density, summed over
            % the components of class kk
            scores(:,kk) = scores(:,kk) + pwk(mu{kk}(jj,:), sigma{jj,kk}, samples)*weight{kk}(jj)*pw(kk);
        end
    end
    [best, predicted] = max(scores,[],2); % only the arg-max is needed
    % Fill one row of the confusion matrix for any number of classes
    % (the counts used to be hard-coded for exactly two classes).
    for kk = 1:c
        confmat(ii,kk) = sum(predicted==kk);
    end
end
acc = trace(confmat)/sum(confmat(:));
end
function [weight, mu, sigma]=train(data,N,c)
% Fit an N-component Gaussian mixture to every class by alternating hard
% assignment of the rows to their most likely component and
% re-estimation of the component parameters.
%
% Inputs:
%   data - 1 x c cell array; data{jj} holds the training rows of class jj
%   N    - number of components per class
%   c    - number of classes
% Outputs:
%   weight - 1 x c cell array; weight{jj} is a 1 x N vector of component
%            weights
%   mu     - 1 x c cell array; mu{jj} is N x d with one mean per row
%   sigma  - N x c cell array; sigma{ii,jj} is the covariance of
%            component ii of class jj
theta = 0.0001;  % tolerance on the change of each parameter norm
maxiter = 1000;  % safety cap so a non-settling fit cannot loop forever
feature = size(data{1},2);
weight = cell(1,c);
mu = cell(1,c);
sigma = cell(N,c);
for jj = 1:c
    lenofdata = size(data{jj},1);
    weight{jj} = ones(1,N)/N;
    mu{jj} = zeros(N,feature);
    meanmu = mean(data{jj});
    sig = cov(data{jj});
    standard = std(data{jj});
    % Offsets of the initial means from the class mean, in units of the
    % per-feature standard deviation.
    if N == 4
        vars = [ones(1,feature)*2.2;-ones(1,feature)*2.2;(randi(2,1,feature)-1.5);(randi(2,1,feature)-1.5)*2];
    else
        vars = (randi(2,1,feature)-1.5)*2;
        vars = repmat(vars,N,1);
    end
    for ii = 1:N
        mu{jj}(ii,:) = meanmu+standard.*vars(ii,:);
        sigma{ii,jj} = sig;
    end
    p = zeros(lenofdata,N);
    iter = 0;
    converged = false;
    while ~converged
        iter = iter+1;
        % Snapshot the parameters before this pass updates them.
        sigma_old = vertcat(sigma{:,jj});
        mu_old = mu{jj};
        weight_old = weight{jj};
        for ii = 1:N
            p(:,ii) = pwk(mu{jj}(ii,:),sigma{ii,jj},data{jj});
        end
        [pmax, Index] = max(p,[],2); % only the arg-max is needed
        for ii = 1:N
            members = (Index == ii);
            x = data{jj}(members,:);
            pp = p(members,ii);
            Accw = sum(pp);
            n = size(x,1);
            if Accw~=0 && n>feature
                weight{jj}(ii) = Accw;
                mu{jj}(ii,:) = pp'*x/Accw;
                tmp = bsxfun(@minus,x,mu{jj}(ii,:));
                tmp = tmp'*tmp/n;
                % sigmean accumulates pp(k)*tmp over every k, so after
                % the division by Accw this is tmp up to rounding.
                % NOTE(review): a density-weighted covariance may have
                % been intended here - confirm.
                sigma{ii,jj} = sigmean(tmp,pp)/Accw;
            else
                % Too few rows (or zero density) to re-estimate this
                % component: drop it from the mixture for this pass.
                weight{jj}(ii) = 0;
            end
        end
        weightadd = sum(weight{jj});
        if weightadd == 0
            % No component could be re-estimated; keep the previous
            % weights instead of producing 0/0 = NaN.
            warning('train: no component of class %d could be re-estimated; keeping previous weights.',jj);
            weight{jj} = weight_old;
        else
            weight{jj} = weight{jj}/weightadd;
        end
        % Compare against the covariances as they are AFTER the update;
        % comparing the pre-update stack with itself never detects change.
        sigma_new = vertcat(sigma{:,jj});
        converged = ~((abs(norm(sigma_old)-norm(sigma_new)) > theta) ...
            || (abs(norm(mu_old)-norm(mu{jj})) > theta) ...
            || (abs(norm(weight_old)-norm(weight{jj})) > theta));
        if ~converged && iter >= maxiter
            warning('train: class %d did not converge within %d iterations.',jj,maxiter);
            break;
        end
    end
end
end
function sigma = sigmean(tmp,pp)
% Scale the matrix tmp by the total of the weights in pp.
%
% Inputs:
%   tmp - matrix to be scaled
%   pp  - vector of weights
% Output:
%   sigma - sum over k of pp(k)*tmp, i.e. sum(pp)*tmp
%
% tmp does not change between terms, so the sum is formed with a single
% multiplication instead of one matrix addition per weight.
if isempty(pp)
    sigma = zeros(size(tmp)); % empty sum
else
    sigma = sum(pp(:))*tmp;
end
end
function p = pwk(mu, sigma, x)
% Evaluate a multivariate Gaussian density at every row of x.
%
% Inputs:
%   mu    - 1 x d mean
%   sigma - d x d covariance
%   x     - n x d matrix, one point per row
% Output:
%   p     - n x 1 vector of density values
d = length(mu);
value = bsxfun(@minus, x,mu);
% Add a growing multiple of the identity until the covariance is well
% enough conditioned to be solved against.
i=0; dimen=size(sigma,1);
while (rcond(sigma)<10^-12)
    sigma=sigma+eye(dimen)*eps*10^i;
    i=i+1;
end
% Row-wise quadratic form (x-mu)*sigma^-1*(x-mu)'. value/sigma solves the
% linear system directly instead of forming the explicit inverse.
p1 = -.5*sum((value/sigma).*value,2);
if any(isnan(p1)) || any(isinf(p1))
    warning('NAN!');
end
p = 1/(2*pi)^(d/2)/sqrt(abs(det(sigma)))*exp(p1);
if any(isnan(p)) || any(isinf(p))
    warning('NAN!');
end
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -