function [err,raw,prob,sse,distmat] = pnncv(train,ncls,clsinfo,clsize,smooth,choice,prntopt,distmat)
% PNNCV  probabilistic neural network leave-one-out cross-validation
% Author:	Ron Shaffer
% Revisions:	4/10/96 Version 1.0 Original code (based on PROBNN 1.4)
% 		4/15/96 Version 1.1 Added ability to pass in distance matrix
%		4/16/96 Version 1.2 Added the ability to suppress printout
%		4/25/95 Version 1.3 Compute continuous error criterion 
%				    using sum-of-squared-errors
%		4/26/95 Version 1.4 Removed possibility of sse returning NaN
%
% [err,raw,prob,sse,distmat] = pnncv(train,ncls,clsinfo,clsize,smooth,choice,prntopt,distmat)
%
% err:		number of misclassified patterns in cross-validation
% raw:		raw PNN outputs (per-class kernel sums) for each left-out pattern
% prob:		Bayes posterior probabilities (assuming equal class priors)
% sse:		sum of squared errors (set to 9999 if the result is NaN)
% distmat:	matrix of distance values (number of patterns x number of patterns)
% train:	training set patterns (number of patterns x number of sensors)
% ncls:		number of classes (number of outputs for PNN)
% clsinfo:	vector containing the class number (1..ncls) of each training pattern
% clsize:	row vector of class sizes (1 x ncls)
% smooth:	smoothing factor sigma, used as exp(-distance/sigma^2)
% choice:	distance measure: 1 = dot product (1 - x*y' on unit-length rows),
%		2 = squared Euclidean
% prntopt:	printout control (1 = full) (0 = min) [optional, default 0]
% distmat:	precomputed distance matrix [optional]; if supplied, choice is ignored
% NOTE:  Use this code at your own risk because the author assumes no liability! 
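%
% Example (a sketch on hypothetical random data; assumes DELSAMPS from the
% PLS_Toolbox is on the MATLAB path):
%   train   = [randn(10,3)+1; randn(10,3)-1];  % 20 patterns, 3 sensors
%   clsinfo = [ones(10,1); 2*ones(10,1)];      % class number of each pattern
%   clsize  = [10 10];                         % patterns per class (row vector)
%   [err,raw,prob,sse,D] = pnncv(train,2,clsinfo,clsize,0.5,1,1);
%   % reuse D to scan smoothing factors without recomputing distances
%   sigmas = [0.1 0.2 0.5 1.0];
%   errs = zeros(size(sigmas));
%   for k = 1:numel(sigmas)
%       errs(k) = pnncv(train,2,clsinfo,clsize,sigmas(k),1,0,D);
%   end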

%
% Note: this m-file calls DELSAMPS from the PLS_Toolbox package from Eigenvector
% technologies; the row normalization is adapted from NORMR in the MATLAB
% Neural Network Toolbox.
%

%
% set constants
%
err = 0;
sse = 0;
misclassed = zeros(1,ncls);
[npat_t,ndim] = size(train);
nhcel = npat_t;
smooth_sqr = smooth * smooth;
%
% if only 6 arguments are passed in (prntopt omitted) then suppress printout
%
if (nargin == 6)
	prntopt = 0;
end
%
% move training data to hidden units (i.e., training) and normalize.
% Normalization method based on Mark Beale's normr routine from the
% neural network toolbox
%
% scale each pattern (row) to unit length, e.g. [3 4] becomes [0.6 0.8]
hcel = train ./ (sqrt(sum(train.^2, 2)) * ones(1,ndim));
%
% if no distance matrix is passed in (nargin < 8) then it must be computed,
% otherwise skip this time-consuming step
%
if (nargin <= 7)
%
% 	compute distance matrix using dot product calculation (much faster method!)
% 	or squared Euclidean distance calculation
%
	if choice == 1
		if prntopt == 1
			fprintf('Computing distance matrix using Dot Product calculation \n');
		end
		distmat = 1-(hcel * hcel');
	else
		if prntopt == 1
			fprintf('Computing distance matrix using Euclidean Distance calculation \n');
		end
		distmat = zeros(nhcel);
		for i = 1:nhcel
			for j = 1:i
				distmat(i,j) = sum((hcel(j,1:ndim) - hcel(i,1:ndim)).^2);
				distmat(j,i) = distmat(i,j);
			end
		end
	end
end
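%
% Because every row of hcel has unit length, the squared Euclidean distance
% equals 2*(1 - a*b'), so choice 2 yields twice the choice 1 distances (up to
% rounding). Since exp(-2*d/smooth^2) = exp(-d/(smooth/sqrt(2))^2), the two
% options differ only by a rescaled smoothing factor. A loop-free sketch of
% the Euclidean branch (sq is a hypothetical helper variable):
%   sq = sum(hcel.^2, 2);
%   distmat = sq*ones(1,nhcel) + ones(nhcel,1)*sq' - 2*(hcel*hcel');
%   distmat = max(distmat, 0);   % clip small negative rounding errors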
%
% now perform cross-validation
%
for i = 1:npat_t
%
	distvect = distmat(i,:)';
	distvect = delsamps(distvect,i);
	weight = exp((-distvect)/smooth_sqr);
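%
%	The weight is a Gaussian kernel. For unit-length a and b the dot-product
%	distance d = 1 - a*b' equals ||a-b||^2/2, so exp(-d/smooth^2) equals
%	exp(-||a-b||^2/(2*smooth^2)), the usual Parzen window form used in PNNs.
%	Example: smooth = 0.5 and d = 0.1 give exp(-0.1/0.25) = exp(-0.4), about 0.670.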
%
	newclsinfo = delsamps(clsinfo,i);
%
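%	sum the kernel weights of the remaining patterns into each class
%	(output) unit; the labels come from newclsinfo so that they stay
%	aligned with weight after pattern i has been deleted
%
	output = zeros(1,ncls);
	for j = 1:nhcel-1
		output(newclsinfo(j)) = output(newclsinfo(j)) + weight(j);
	end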
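%
%	Loop-free sketch for current MATLAB releases (assumes class numbers are
%	integers 1..ncls): with newclsinfo holding the labels of the remaining
%	patterns, aligned with weight, the per-class sums are
%		output = accumarray(newclsinfo(:), weight(:), [ncls 1])';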
%
%	store the raw (unnormalized) class sums
%
	raw(i,1:ncls) = output(1:ncls);
%
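% 	Compute the mean output for each class by dividing by the class size,
% 	which corrects for unequal class sizes. The left-out pattern's class
% 	has one member fewer during cross-validation.
%
	cvclsize = clsize;
	cvclsize(clsinfo(i)) = clsize(clsinfo(i)) - 1;
	output(1:ncls) = output(1:ncls) ./ cvclsize;
%
%	the class with the highest mean output is the winner; normalizing the
%	means gives the Bayes posterior probabilities (equal class priors)
%
	[junk,winner(i)] = max(output);
	sumout = sum(output);
	prob(i,1:ncls) = output(1:ncls)./sumout;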
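%
%	Worked example (hypothetical numbers): with raw class sums [3.0 1.0] and
%	cvclsize = [10 5], the means are [0.30 0.20], winner = 1 and
%	prob = [0.60 0.40]. Normalizing the raw sums instead would give
%	[0.75 0.25], which would favor the larger class.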
%
%	accumulate the sum of squared errors between prob(i,:) and the 0/1
%	target vector of the true class (Masters's book, pages 197-201)
%
	sse = sse + (((1-prob(i,clsinfo(i)))^2) + (sum((prob(i,:).^2)) - prob(i,clsinfo(i))^2));
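%
%	The increment above equals the squared distance to a one-hot target.
%	A sketch with an explicit target vector (target and sse_i are
%	hypothetical local names):
%		target = zeros(1,ncls);
%		target(clsinfo(i)) = 1;
%		sse_i = sum((target - prob(i,:)).^2);
%	Example: prob(i,:) = [0.60 0.40] with true class 1 adds
%	(1-0.6)^2 + (0.6^2 + 0.4^2 - 0.6^2) = 0.16 + 0.16 = 0.32.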
%
%	count misclassified patterns (overall and per class)
%
	if winner(i) ~= clsinfo(i)
		err = err + 1;
		misclassed(clsinfo(i)) = misclassed(clsinfo(i)) + 1;
	end
end
%
% Convert per-class error counts to percentage correct, print results and exit
%
misclassed(1:ncls) = 100 .* (clsize(1:ncls) - misclassed(1:ncls)) ./ clsize(1:ncls);
if prntopt == 1
	for i = 1:ncls
		fprintf('Class  %d  Percentage Correct  %7.4f \n',i,misclassed(i));
	end
end
overallerr = 100*(err/npat_t);
if isnan(sse) == 1
	sse = 9999;
end
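%
% When can sse be NaN? (explanatory note) If smooth is small relative to the
% distances, every kernel weight underflows to zero, e.g.
%   exp(-1/0.01^2)      % = exp(-10000), which is 0 in double precision
% so sumout = 0 and prob = 0/0 = NaN. A class with a single pattern also
% gives cvclsize = 0 for that class when its only member is left out, and
% 0/0 = NaN again. The 9999 value above replaces such NaN results.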
if prntopt == 1
	fprintf('Overall Error %7.4f \n',overallerr);
	fprintf('Sum of Squared Errors %7.4f \n', sse);
end
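
% If the PLS_Toolbox is not installed, a minimal stand-in for DELSAMPS as it
% is used above (deleting one row from a column vector) could be saved as a
% separate file delsamps.m. This sketch assumes delsamps(x,i) returns x with
% row i removed; it is not the PLS_Toolbox implementation.
%
%   function y = delsamps(x, samps)
%   % DELSAMPS  (minimal stand-in) remove rows SAMPS from X
%   y = x;
%   y(samps,:) = [];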
