
crossval.m

This is a pattern recognition toolbox I found (MATLAB M-file).
%CROSSVAL Error/performance estimation by cross validation (rotation)
%
%   [ERR,CERR,NLAB_OUT] = CROSSVAL(A,CLASSF,N,1,TESTFUN)
%   [ERR,STDS]          = CROSSVAL(A,CLASSF,N,NREP,TESTFUN)
%   R                   = CROSSVAL(A,[],N,0)
%
% INPUT
%   A          Input dataset
%   CLASSF     The untrained classifier to be tested.
%   N          Number of dataset divisions (default: N == number of
%              samples, leave-one-out)
%   NREP       Number of repetitions (default: 1)
%   TESTFUN    Mapping/evaluation function (default: classification error)
%
% OUTPUT
%   ERR        Average test error or performance, weighted by class priors.
%   CERR       Unweighted test errors or performances per class
%   NLAB_OUT   Assigned numeric labels
%   STDS       Standard deviation over the repetitions.
%   R          Index array with rotation set
%
% DESCRIPTION
% Cross validation estimation of the error or performance (defined by TESTFUN)
% of the untrained classifier CLASSF using the dataset A. The set is randomly
% permuted and divided into N (almost) equally sized parts. The classifier
% is trained on N-1 parts and the remaining part is used for testing. This
% is rotated over all parts. ERR is their average weighted by the class
% priors. CERR are the per-class error frequencies. A and/or CLASSF may be
% cell arrays of datasets and classifiers. In that case ERR is an array in
% which ERR(i,j) is the error or performance of classifier j on dataset i.
% In this mode CERR and NLAB_OUT are returned in cell arrays.
%
% In case NREP > 1 the mean error(s) over the repetitions are returned in ERR
% and the standard deviations of the observed errors in STDS.
%
% In case NREP == 0 an index array is returned that assigns a fold to every
% object. No training or testing is done. This is useful for handling
% training and testing outside CROSSVAL.
%
% See also DATASETS, MAPPINGS, TESTC

% Copyright: D.M.J. Tax, R.P.W. Duin, r.p.w.duin@prtools.org
% Faculty EWI, Delft University of Technology
% P.O. Box 5031, 2600 GA Delft, The Netherlands

% $Id: crossval.m,v 1.10 2008/07/03 09:05:50 duin Exp $

function [err,cerr,nlabout] = crossval(data,classf,n,nrep,testf,fid)

  prtrace(mfilename);

  if nargin < 6, fid = []; end
  if nargin < 5, testf = []; end
  if nargin < 4, nrep = []; end
  if nargin < 3, n = []; end

  if ~ismapping(testf) & isempty(fid)   % correct for old call without testf
    fid = testf; testf = [];
  end

  if iscell(data)                       % generate prior warnings now
    for j=1:length(data)
      data{j} = setprior(data{j},getprior(data{j}));
    end
  else
    data = setprior(data,getprior(data));
  end

  warnlevel = prwarning;
  prwarning(0);

  % datasets or classifiers are cell arrays
  if iscell(classf) | iscell(data)
    seed = rand('state');
    if ~iscell(classf), classf = {classf}; end
    if ~iscell(data), data = {data}; end
    if isdataset(classf{1}) & ismapping(data{1}) % correct for old order
      dd = data; data = classf; classf = dd;
    end
    numc = length(classf);
    numd = length(data);
    cerr = cell(numd,numc);
    nlabout = cell(numd,numc);
    s1 = sprintf('crossval: %i classifiers: ',numc);
    prwaitbar(numc,s1);
    e = zeros(numd,numc);

    for jc = 1:numc
      prwaitbar(numc,jc,[s1 getname(classf{jc})]);
      s2 = sprintf('crossval: %i datasets: ',numd);
      prwaitbar(numd,s2);

      for jd = 1:numd
        prwaitbar(numd,jd,[s2 getname(data{jd})]);
        rand('state',seed);
        [ee,cc,nn] = feval(mfilename,data{jd},classf{jc},n,nrep,testf);
        e(jd,jc) = ee;
        cerr(jd,jc) = {cc};
        nlabout(jd,jc) = {nn};
      end
      prwaitbar(0);
    end
    prwaitbar(0);
    if nrep > 1, cerr = cell2mat(cerr); nlabout = NaN; end

    if nargout == 0
      fprintf('\n  %i-fold cross validation result for',n);
      disperror(data,classf,e);
    end
    if nargout > 0, err = e; end
  else

    data = setprior(data,getprior(data)); % just to generate a warning when needed

    if isempty(nrep), nrep = 1; end
    if nrep > 1

      s3 = sprintf('crossval: %i repetitions: ',nrep);
      prwaitbar(nrep,s3);
      ee = zeros(1,nrep);
      for j=1:nrep
        prwaitbar(nrep,j,[s3 int2str(j)]);
        [ee(j),ss,nlabout] = feval(mfilename,data,classf,n,1,testf);
      end
      prwaitbar(0);
      err = mean(ee);
      cerr = std(ee);
      nlabout = NaN;
      prwarning(warnlevel);
      return
    end
    if isdataset(classf) & ismapping(data) % correct for old order
      dd = data; data = classf; classf = dd;
    end
    isdataset(data);
    if nrep > 0, ismapping(classf); end
    [m,k,c] = getsize(data);
    lab = getlab(data);
    if isempty(n), n = m; end
    if n == m & ~isempty(testf)
      error('No external error routine allowed in case of leave-one-out cross validation')
    end
    if n > m
      warning('Number of batches too large: reset to leave-one-out')
      n = m;
    elseif n < 1
      error('Wrong size for number of cross-validation batches')
    end
    if (nrep > 0 & ~isuntrained(classf))
      error('Classifier should be untrained')
    end
    J = randperm(m);
    N = classsizes(data);

    % attempt to find a more equal distribution over the classes
    if all(N >= n)
      K = zeros(1,m);
      for i = 1:length(N)
        L = findnlab(data(J,:),i);
        M = mod(0:N(i)-1,n)+1;
        K(L) = M;
      end
    else
      K = mod(1:m,n)+1;
    end
    nlabout = zeros(m,1);

    if nrep == 0 % trick to return the rotation set
      err = zeros(1,m);
      err(J) = K;
      prwarning(warnlevel);
      return
    end

    f = zeros(1,n);
    s4 = sprintf('crossval, %i-folds: ',n);
    prwaitbar(n,s4);

    for i = 1:n
      prwaitbar(n,i,[s4 int2str(i)]);
      OUT = find(K==i);
      JOUT = J(OUT);
      JIN = J; JIN(OUT) = [];
      train_data = data(JIN,:);
      %train_data = setprior(train_data,getprior(train_data));
      w = train_data*classf;              % training
      testres = data(JOUT,:)*w;           % testing
      if ~isempty(testf)
        f(i) = testres*testf;
      end
      testout = testres*maxc;
      [mx,nlabout(JOUT)] = max(+testout,[],2); % nlabout contains class assignments
    end
    prwaitbar(0);

    % correct for a different order of the lablist and the labels assigned
    % by the classifier; assume this order is constant over the folds
    nlist = renumlab(getfeatlab(testout),getlablist(data));
    nlabout = nlist(nlabout);
    if isempty(testf)
      e = zeros(1,c);
      for j=1:c
        J = findnlab(data,j);
        e(j) = sum(nlabout(J)~=j)/length(J);
      end
      cerr = e;                 % unweighted per-class error frequencies
      e = e*getprior(data,0)';  % average weighted by the class priors
    else
      e = mean(f); % f already weighted by class priors inside testf
    end

    if nargout > 0
      err = e;
      if ~isempty(testf)
        cerr = [];
        nlabout = [];
      end
    else
      disp([num2str(n) '-fold cross validation error on ' num2str(size(data,1)) ' objects: ' num2str(e)])
    end
  end

  prwarning(warnlevel);

  return
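
For quick reference, here is a minimal usage sketch of the routine above. It assumes PRTools is installed and on the MATLAB path; GENDATB (banana-shaped data generator) and LDC (linear normal-density classifier) are standard PRTools functions, chosen here only for illustration.

% Minimal usage sketch -- assumes PRTools is on the path; GENDATB and LDC are
% standard PRTools functions, used here only as illustrative choices.
A = gendatb([50 50]);             % two-class banana set, 50 objects per class
w = ldc;                          % untrained linear normal-density classifier
[err,cerr] = crossval(A,w,10)     % 10-fold: prior-weighted and per-class errors
[err,stds] = crossval(A,w,10,5)   % 5 repetitions: mean error and its std. dev.
R = crossval(A,[],10,0);          % only the fold index assigned to every object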
