function [computedOutput, combinedComputedOutput, nearestIndex, knnrMat] = knnr(DS, TS, k)
% knnr: K-nearest neighbor rule for classification
%	Usage:
%	[computedOutput, combinedComputedOutput, nearestIndex, knnrMat] = knnr(DS, TS, k)
%
%	DS: design set
%		DS.input: input part (each column is a feature vector)
%		DS.output: output part (class labels, integers ranging from 1 to N)
%	TS: test set
%		TS.input: input part (each column is a feature vector)
%	k: the "k" in "k nearest neighbor"
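%	computedOutput: row vector of class labels assigned by KNNR, one per column of TS.input
%	combinedComputedOutput: a single class label for the whole test set, assuming all test vectors belong to the same class
%		(Majority vote over the k nearest neighbors of every test vector, pooled together; the result is a scalar between 1 and N.)
%	nearestIndex: matrix whose j-th column lists the indices of the columns of DS.input sorted by distance to the j-th column of TS.input, nearest first
%	knnrMat(i,j) = class of the i-th nearest design point to the j-th test vector (k rows, one column per test vector)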

%	Roger Jang, 19970331, 20040928

if nargin<1, selfdemo; return; end
if nargin<3, k=1; end

[dim, designNum]=size(DS.input);
classLabel=elementCount(DS.output);		% distinct class labels present in the design set
classNum=length(classLabel);
[dim, testNum]=size(TS.input);

% Squared Euclidean distances: distMat(i,j) = ||DS.input(:,i)-TS.input(:,j)||^2 (size designNum by testNum)
distMat = pairwiseSqrDist(DS.input, TS.input);

% knnrMat(i,j) = class of i-th nearest point of j-th test input vector (size = k by testNum.)
[junk, nearestIndex] = sort(distMat, 1);
%knnrMat=DS.output(nearestIndex(1:k,:));	% Wrong shape when k=1 or testNum=1 (a vector index follows DS.output's orientation), hence the reshape
knnrMat=reshape(DS.output(nearestIndex(1:k,:)), k, testNum);


% classCount(i,j) = number of class-i points in j-th test input's neighborhood
classCount = zeros(classNum, testNum);
for i=1:testNum,
	[sortedElement, elementCnt]=elementCount(knnrMat(:,i));
	classCount(sortedElement, i)=elementCnt;
end
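% Worked example (illustrative numbers, not from any data set): with k=3 and
% knnrMat(:,j)=[2;1;2], elementCount reports classes 1 and 2 with counts 1 and 2,
% so classCount(1:2,j)=[1;2] and the max below assigns class 2 to test vector j.
% On a tie, max returns the first maximum, i.e. the smaller class label.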

[junk, combinedComputedOutput]=max(sum(classCount, 2));
[junk, computedOutput]=max(classCount, [], 1);
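
% ====== Hypothetical helper sketch: pairwiseSqrDist
% knnr calls pairwiseSqrDist, which is not defined in this file (it is assumed to
% come from the author's toolbox). This minimal sketch assumes one data point per
% column and returns distMat(i,j) = ||A(:,i)-B(:,j)||^2, computed via the identity
% ||a-b||^2 = ||a||^2 + ||b||^2 - 2*a'*b. As a local function it shadows any
% toolbox version on the path; delete it if that toolbox is installed.
function distMat = pairwiseSqrDist(A, B)
m = size(A, 2);		% number of points (columns) in A
n = size(B, 2);		% number of points (columns) in B
% Outer products instead of implicit expansion, so older MATLAB versions also work
distMat = sum(A.^2, 1)'*ones(1, n) + ones(m, 1)*sum(B.^2, 1) - 2*(A'*B);
distMat = max(distMat, 0);	% clip tiny negative values caused by round-off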

% ====== Self demo
function selfdemo
[DS, TS]=prData('iris');
designNum=size(DS.input, 2);
testNum  =size(TS.input, 2);
fprintf('Use of KNNR for Iris data:\n');
fprintf('\tSize of design set (odd-indexed data)= %d\n', designNum);
fprintf('\tSize of test set (even-indexed data) = %d\n', testNum);
fprintf('\tRecognition rates as K varies:\n');
kMax=15;
recog=zeros(1, kMax);		% recognition rate for each K
for k=1:kMax,
	computed=feval(mfilename, DS, TS, k);
	correctCount=sum(TS.output==computed);
	recog(k)=correctCount/testNum;
	fprintf('\t%d-NNR ===> 1-%d/%d = %.2f%%.\n', k, testNum-correctCount, testNum, recog(k)*100);
end
plot(1:kMax, recog*100, 'b-o'); grid on;
title('Recognition rates of Iris data using K-NNR');
xlabel('K'); ylabel('Recognition rates (%)');
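
% ====== Hypothetical helper sketch: elementCount
% knnr also calls elementCount, which is not defined in this file either. This
% minimal sketch assumes the behavior implied by its two call sites: sortedElement
% lists the distinct values of x in ascending order and elementCnt(i) counts how
% often sortedElement(i) occurs in x. As a local function it shadows any toolbox
% version on the path; delete it if that toolbox is installed.
function [sortedElement, elementCnt] = elementCount(x)
x = x(:);			% treat row and column inputs alike
sortedElement = unique(x);	% distinct values, ascending
elementCnt = zeros(size(sortedElement));
for i=1:length(sortedElement)
	elementCnt(i) = sum(x==sortedElement(i));	% occurrences of this value
end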
