⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 lans_classifier.m

📁 模式识别工具包
💻 M
字号:
%	lans_classifier	- General LANS classifier framework
%
%	[result,classifier]	= lans_classifier(pdata,classifier<,options><,state>)
%
%	_____OUTPUTS____________________________________________________________
%	result	classification result				(cell/structure)
%		result{p}
%		.out		model output labels		(row vector)
%		.confuse	confusion matrix		(matrix)
%		.confuseidx	indices of misclassified data	(2-D cell)
%			e.g.	confuseidx{2,3}	= [2 3 4] refers to samples with
%				true class 2 but misclassified as 3
%		.cerr		% of samples misclassified, one per k	(row vector)
%		.rejectidx	indices of rejected data	(row vector)
%				(for likelihood based classification)
%		.rejectrate	% of data that got rejected	(scalar)
%
%		For training without forward pass
%		.M		# nodes				(scalar)
%		.L		# latent basis			(scalar)
%		
%   classifier	trained/initial classifier			(structure)
%
%	_____INPUTS_____________________________________________________________
%	pdata	partitioned LANS formatted dataset		(cell/structure)
%		pdata{p}	during training
%		see lans_load
%  classifier	classifier model (trained/untrained)		(structure)
%		if [], defaults to training mode 
%     options	classifier (top-level) options			(string)
%		-iter	# iterations for each model
%		-mode	classifier mode
%			train*	training a classifier from scratch or
%				revising a previously trained classifier
%			test	classify a single (non-cell) dataset pdata
%				with a trained classifier
%		-model	classifier model
%			pps	Probabilistic Principal Surface Classifier
%			knn*	Nearest neighbor
%			gmm	Gaussian Mixture Model (1 mixture per class)
%				-M1	# nodes			(scalar)
%					10*
%					capped at N, the # samples in the class
%				-covtype covariance type	(string)
%					'spherical'*
%					'diag'
%					'full'
%
%		-class	type of classification (restricted by each -model)
%			model	class
%			-----	-----
%			gmm
%				1nn*	Nearest data space node
%				ml	Maximum Likelihood (provides rejection)
%				
%			knn
%				1nn*	1 nearest neighbor
%				{k}nn	k-nearest neighbor k = +ve integer
%					(results for 1:k returned)
%			pps
%				1nn*	Nearest data space node
%				ml	Maximum Likelihood (provides rejection)
%				grid	Nearest grid in data space
%
%		-forward	whether to classify all supplied datasets
%			{0,1*}	using trained classifier (in training mode)
%		-confuse whether to compute confusion matrix
%			{0*,1}
%		-showtext whether to print progress messages
%			{0*,1}
%
%	state	state used for random initialization		(col vector)
%		[]*
%
%		* default
%
%	_____EXAMPLE____________________________________________________________
%
%	_____NOTES______________________________________________________________
%	- for demo, call function without parameters
%	- where available 
%		p =	1	Train
%			2	Validation	(useful only during training)
%			3	Test
%	- For testing,
%		pdata & result are NOT cells
%	- -class options restricted by -model
%	- lower KNN_THRESH to make KNN testing split the test data into
%	  more, smaller parts (uses less memory per distance computation)
%	- works only for integer numbered class
%	  non-integer or non-consecutive or cell-type classes must be
%	  preprocessed by lans_class2cidx
%	- gmminit initializes sigma to 1, which will fail for high-D data, so
%	  here we initialize it to the median squared distance from each
%	  centre to its nearest neighbouring centre (but at least 1)
%	
%
%	_____SEE ALSO___________________________________________________________
%	lans_class2idx	lans_partition	lans_confuse	lans_load
%	NETLAB:	gmminit gmmem gmmpost
%
%	(C) 2000.06.29 Kui-yu Chang
%	http://lans.ece.utexas.edu/~kuiyu

%	This program is free software; you can redistribute it and/or modify
%	it under the terms of the GNU General Public License as published by
%	the Free Software Foundation; either version 2 of the License, or
%	(at your option) any later version.
%
%	This program is distributed in the hope that it will be useful,
%	but WITHOUT ANY WARRANTY; without even the implied warranty of
%	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
%	GNU General Public License for more details.
%
%	You should have received a copy of the GNU General Public License
%	along with this program; if not, write to the Free Software
%	Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA
%	or check
%			http://www.gnu.org/

%	_____TO DO______________________________________________________________
%	- procedure to save partial result to disk
%	- scale covars for non-isotropic gmm
%	- MSE info. in result
%	- string in result
%	- handle 'real' test data with no .out entries!
%	- reject data too far away from any classes for ML method
%	- discard the use of foptions

function [result,classifier] = lans_classifier(pdata,classifier,options,state)
%LANS_CLASSIFIER  Train/apply a GMM, k-NN or PPS classifier (see header above).
%
%   train mode (-mode train, default): trains one model per class on
%     pdata{1} (for knn, simply stores pdata{1}).  If classifier is
%     non-empty, the gmm/pps models it contains are refined instead of
%     being re-initialised.  With -forward 1 every non-empty dataset
%     pdata{p} is then classified and result is a cell array (at least
%     1x3); with -forward 0 result only holds .M and .L.
%   test mode: pdata is a single dataset structure (.in is D x N, .out is
%     a 1 x N row of integer labels) and result is a structure with fields
%     .out, .confuse, .confuseidx, .cerr, .rejectidx and .rejectrate.
%   With no arguments a demo on the iris dataset is run.

% Largest number of elements (Nt*Nc*D) handled by one k-NN distance
% computation; bigger test sets are split into parts below this size.
KNN_THRESH	= 20*1024*1024;

if nargin>0
%__________ REGULAR ____________________________________________________________
% tolerance on |sum(R)-1| for the PPS 'ml' method (R = first output of
% lans_ppspost); samples outside it are treated as too far from the class
TOL_ML	= .1;
if nargin<4
	state	= [];	% use current random state
	if nargin<3
		options	= [];
		if nargin<2
			classifier	= [];	% train from scratch
		end
	end
end

runmode	= paraget('-mode',options,'train');
model	= paraget('-model',options,'knn');
cmethod	= paraget('-class',options,'1nn');
forward	= paraget('-forward',options,1);
confuse	= paraget('-confuse',options,0);
showtext= paraget('-showtext',options,0);

if strcmp(runmode,'train')
	%________________________________________	train
	data	= pdata{1};
	C	= length(lans_finduniq(sort(data.out)));	% # classes
	M	= [];
	L	= [];

	switch lower(model)
	case 'gmm'
		% -covtype is the documented name; -covartype is still accepted
		covtype	= paraget('-covartype',options,'spherical');
		covtype	= paraget('-covtype',options,covtype);

		% NETLAB options vector for gmminit/gmmem
		gmm_o	= foptions;
		gmm_o(1)= -1;				% display nothing
		gmm_o(5)= 1;				% prevent zero variances
		gmm_o(14)=paraget('-iter',options,20);	% # EM iterations

		trgdata	= lans_group(pdata{1});		% trgdata{c}: samples of class c
		if isempty(classifier)
			mog	= cell(1,C);
		else
			mog	= classifier;		% refine a previously trained model
		end
		for c=1:C
			if showtext
				pstr	= sprintf('GMM:Training class %s (%d)',num2str(trgdata{c}.out(:,1)),c);
				disp(pstr);
			end
			y	= trgdata{c}.in';		% NETLAB expects N x D
			[N,D]	= size(y);
			M	= min(paraget('-M1',options,10),N);	% # nodes
			L	= 1;
			if isempty(classifier)		% train from scratch
				mog{c}	= gmm(D,M,covtype);
				mog{c}	= gmminit(mog{c},y,gmm_o);
				% gmminit's unit variances fail in high dimensions, so use
				% the median squared distance from each centre to its
				% nearest neighbouring centre (at least 1) instead
				if M>1
					fdist2	= lans_dist(mog{c}.centres','-metric Euclidean2') + eye(M)*realmax;
					nnd	= max(median(min(fdist2,[],1)),1);
				else
					nnd	= 1;		% single centre: no neighbour distance
				end
				mog{c}.covars	= nnd*ones(size(mog{c}.covars));
			end
			mog{c}	= gmmem(mog{c},y,gmm_o);
		end	% for c
		classifier	= mog;
	case 'knn'
		% the classifier is simply the labelled training set (.in, .out)
		classifier	= data;
		M		= size(data.in,2);
		L		= 1;
	case 'pps'
		plotpps	= paraget('-plotpps',options,0);
		plotdata= paraget('-plotdata',options,0);
		trgdata	= lans_group(pdata{1});		% trgdata{c}: samples of class c
		if isempty(classifier)
			pps	= cell(1,C);
		else
			pps	= classifier;		% refine a previously trained model
		end
		for c=1:C
			if showtext
				manifold= paraget('-manifold',options,'?');
				pstr	= sprintf('PPS-%sD:Training class %s (%d)',manifold(1),num2str(trgdata{c}.out(:,1)),c);
				disp(pstr);
			end
			y	= trgdata{c}.in;
			if isempty(classifier)		% train from scratch
				pps{c}	= lans_ppsinit(y,options,state);
			end
			if plotpps
				figure(plotpps);
				if plotdata
					lans_ppsplot(pps{c},'r','',y,'k.');
				else
					lans_ppsplot(pps{c},'r-','-hold 1');
				end
			end
			pps{c}	= lans_pps(pps{c},y,options);
			M	= pps{c}.M;
			L	= pps{c}.L;
		end	% for c
		classifier	= pps;
	otherwise
		error('lans_classifier: unknown -model ''%s''',model);
	end

	%_____	forward pass thru trained network
	if forward
		options	= paraset('-mode','test',options);
		np	= length(pdata);
		result	= cell(1,max(np,3));
		for p=1:np
			if ~isempty(pdata{p})
				if ~isempty(pdata{p}.in)
					result{p}	= lans_classifier(pdata{p},classifier,options);
				end
			end
		end
	else
		result.M	= M;
		result.L	= L;
	end	% if forward
else
	%________________________________________	test
	data	= pdata;
	y	= data.in;
	N	= size(y,2);		% # test samples
	metric	= [];
	rejectidx	= [];

	switch lower(model)
	case 'gmm'
		mog	= classifier;
		C	= length(mog);
		switch lower(cmethod)
		case '1nn'
			% squared-free distance from each sample to the nearest centre of each class
			for c=1:C
				metric(c,:)	= min(lans_dist(mog{c}.centres',y),[],1);
			end
			[dum,cout]	= min(metric,[],1);
		case 'ml'
			% class-conditional log density of every sample (NETLAB gmmprob)
			for c=1:C
				metric(c,:)	= log(gmmprob(mog{c},y'))';
			end
			[dum,cout]	= max(metric,[],1);
		otherwise
			error('lans_classifier: -class ''%s'' is not supported for gmm',cmethod);
		end	% switch lower(cmethod)

	case 'knn'
		k	= str2double(cmethod(1:end-2));	% e.g. '5nn' -> 5
		if isnan(k) | k<1 | k~=round(k)
			error('lans_classifier: -class ''%s'' is not a valid {k}nn method',cmethod);
		end
		C	= length(lans_finduniq(sort(classifier.out)));
		[D,Nc]	= size(classifier.in);
		Nt	= N;
		k	= min(k,Nc);		% cannot vote with more neighbours than exist

		% split the test data into npart parts of at most tsize samples
		npart	= max(1,ceil(Nt*Nc*D/KNN_THRESH));
		tsize	= ceil(Nt/npart);
		if tsize>0
			npart	= ceil(Nt/tsize);	% drop trailing empty parts
		else
			npart	= 0;			% no test samples
		end
		if showtext
			str	= sprintf('knn: testing on %d parts of %d samples each',npart,tsize);
			disp(str);
		end

		ballot	= zeros(C,Nt);		% ballot(c,n): votes for class c on sample n
		cout	= zeros(k,Nt);		% cout(j,:): decision of the j-NN rule
		for part=1:npart
			sidx	= (part-1)*tsize+1;
			eidx	= min(part*tsize,Nt);
			Ntp	= eidx-sidx+1;		% # test samples in this part
			metric	= lans_dist(classifier.in,y(:,sidx:eidx),'-metric euclidean2');	% Nc x Ntp
			offset	= (sidx-1)*C;		% linear-index offset of column sidx in ballot
			for k1=1:k
				% winners = nearest remaining training sample per test sample
				[v,winners]	= min(metric,[],1);
				metric(winners+Nc*(0:(Ntp-1)))	= realmax;	% exclude from later rounds
				widx		= classifier.out(:,winners)+C*(0:(Ntp-1))+offset;
				ballot(widx)	= ballot(widx)+1;
				[dum,cout(k1,sidx:eidx)]	= max(ballot(:,sidx:eidx),[],1);
				if showtext
					pstr	= sprintf('[Part %d of %d] Testing with %d-NN',part,npart,k1);
					disp(pstr);
				end
			end
		end	% for part

	case 'pps'
		pps	= classifier;
		C	= length(pps);
		switch lower(cmethod)
		case '1nn'
			% distance from each sample to the nearest node of each class
			for c=1:C
				metric(c,:)	= min(lans_dist(lans_md2lin(pps{c}.f),y),[],1);
			end
			[dum,cout]	= min(metric,[],1);
		case {'grid','tri'}
			% projres is kept separate so the function's result is not
			% overwritten; projres{1} is used as a per-sample distance
			options	= paraset('-proj',lower(cmethod),options);
			for c=1:C
				[dum1,dum2,projres]	= lans_ppsproj(pps{c},y,options);
				metric(c,:)	= projres{1};
			end
			[dum,cout]	= min(metric,[],1);
		case 'ml'
			metric	= zeros(C,N);
			for c=1:C
				for n=1:N
					[R,like]= lans_ppspost(pps{c},y(:,n));
					% points too far from class c ignored
					if abs(sum(R)-1)<TOL_ML
						metric(c,n)	= like;
					else
						metric(c,n)	= -realmax;
					end
				end	% for n
			end	% for c
			[dum,cout]	= max(metric,[],1);
			% samples ignored by every class are reported as rejected
			rejectidx	= find(all(metric==-realmax,1));
		otherwise
			error('lans_classifier: -class ''%s'' is not supported for pps',cmethod);
		end	% switch lower(cmethod)
	end	% switch lower(model)
	if ~exist('cout','var')
		error('lans_classifier: unknown -model ''%s''',model);
	end

	nk		= size(cout,1);		% one row per k for k-NN, otherwise 1
	result.out	= cout;
	%__________	confusion matrix/misclassified samples' index
	if confuse
		if nk==1
			[result.confuse,result.confuseidx]	= lans_confuse(cout,data.out,1:C);
		else
			for k=1:nk
				[result.confuse{k},result.confuseidx{k}]	= lans_confuse(cout(k,:),data.out,1:C);
			end
		end
	else
		result.confuse		= [];
		result.confuseidx	= [];
	end	% if confuse
	% percentage of misclassified samples for each row of cout
	if N>0
		result.cerr	= 100*sum(cout~=ones(nk,1)*data.out,2)'/N;
	else
		result.cerr	= zeros(1,nk);
	end

	result.rejectidx	= rejectidx;
	if N>0
		result.rejectrate	= 100*length(rejectidx)/N;
	else
		result.rejectrate	= 0;
	end
end
%__________ REGULAR ends _______________________________________________________
else
%__________ DEMO _______________________________________________________________
clf;clc;
disp('running lans_classifier.m in demo mode on iris.lans');

knn_o		= '-model knn -class 1nn';	% alternative: 1-NN classifier
pps_o		= '-model pps -class 1nn -init manifold -manifold 3sphere -alpha 1 -center 0 -debug 0 -L1 5 -Lfac 2 -Lratio 0 -regularize .01 -eint 45 -rint 60';
iter_o		= '-mode train -confuse 1 -iter 20';

options		= [pps_o ' ' iter_o];
%options		= [knn_o ' ' iter_o];

ldata		= lans_load('iris');
ldata.in	= lans_stand(ldata.in);

% 80% training, 0% validation, 20% test, stratified by class
pdata			= lans_partition(ldata,[.8 0 .2],'-part stratify');
[result,classifier]	= lans_classifier(pdata,[],options);
h			= lans_plot(lans_group(pdata{1}),'-legend 1 -colors uniform');

% gather the misclassified test samples from the off-diagonal confusion cells
badclass		= find((result{3}.confuse-diag(diag(result{3}.confuse)))~=0);
badsamples.in		= [];
badsamples.out		= [];
for i=1:length(badclass)
	badidx		= result{3}.confuseidx{badclass(i)};
	badsamples.in	= [badsamples.in pdata{3}.in(:,badidx)];
	badsamples.out	= [badsamples.out pdata{3}.out(:,badidx)];
end

te_confuse	= result{3}.confuse
Error		= 1-sum(diag(te_confuse))./sum(sum(te_confuse));
if ~isempty(badsamples.in)
	b		= lans_plotmd(badsamples.in,'ko-','-hold 1');
	legend([h;b],char({'1','2','3','misclassified'}),2);
	title(sprintf('Classification Error = %2.2f %s',Error*100,'%'));
end

%__________ DEMO ends __________________________________________________________
end

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -