📄 knn1.m

📁 k-nearest neighbor in machine learning
function [eachClass, ensembleClass, nearestSampleIndex, knnmat] = ...
	knn(sampledata, testdata, k)
% KNN	K-nearest neighbor rule for classification
%	Usage:
%	[EACH_CLASS, ENSEMBLE_CLASS, NEAREST_SAMPLE_INDEX, KNNMAT] = KNN(SAMPLE, INPUT, K)
%
%	SAMPLE: Sample data set (The last column is the desired class of each
%		sample vector. The values of the last column are assumed to
%		be integers ranging from 1 to N.)
%	INPUT: Test input matrix (each row is an input vector)
%	K: the "k" in "k nearest neighbor" (defaults to 1 if omitted)
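%	EACH_CLASS: A vector denoting the KNN output class of each input vector
%	ENSEMBLE_CLASS: A scalar denoting the KNN output class of all input
%		vectors that are assumed to be of the same class
%		(A voting mechanism is invoked to determine a scalar value
%		between 1 and N.)
%	NEAREST_SAMPLE_INDEX: Indices into SAMPLE sorted by increasing distance
%		to INPUT (column j lists every sample, nearest first, for the
%		j-th input vector)
%	KNNMAT: K-by-(number of input vectors) matrix; KNNMAT(i,j) is the class
%		of the i-th nearest sample of the j-th input vector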
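%
%	Example (a hedged sketch on made-up data, not from the original author;
%	it assumes the external helpers COUNTELE and VECDIST are on the path
%	and that the function is called by the name of the file it is saved in):
%		sample = [0 0 1; 0 1 1; 5 5 2; 6 5 2];	% last column = class
%		input  = [0.2 0.1; 5.4 5.2];		% two test vectors
%		eachClass = knn(sample, input, 3);
%	With K = 3, the neighbors of the first input carry classes 1, 1, 2 and
%	those of the second carry classes 2, 2, 1, so the majority vote should
%	give class 1 and class 2 respectively.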

%	Roger Jang, 9703xx, 990613, 991215 

if nargin == 0, selfdemo; return; end
if nargin < 3, k = 1; end

featureNum = size(sampledata,2)-1;
sampleInput = sampledata(:, 1:featureNum);
sampleOutput = sampledata(:, featureNum+1);
classLabel = countele(sampleOutput);	% possible output class
classNum = length(classLabel);
testNum = size(testdata, 1);		% no. of test input vectors
testInput = testdata(:, 1:featureNum);	% strip out class info, if any

% Euclidean distance matrix between sampleInput and testInput
distmat = vecdist(sampleInput, testInput);

% knnmat(i,j) = class of i-th nearest point of j-th test input vector
% (The size of knnmat is k by testNum.)
[junk, nearestSampleIndex] = sort(distmat, 1);
% The following "reshape" is necessary if k == 1.
knnmat = reshape(sampleOutput(nearestSampleIndex(1:k,:)), k, testNum);

% class_count(i,j) = number of class-i points in j-th test input's neighborhood
% (The size of class_count is classNum by testNum.)
class_count = zeros(classNum, testNum);
for i = 1:testNum,
	[sortedElement, elementCount] = countele(knnmat(:,i));
	class_count(sortedElement, i) = elementCount;
end

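% Ensemble vote: pool the neighbor counts over all test inputs and pick the
% class with the largest total. On a tie, max returns the smallest class index.
[junk, ensembleClass] = max(sum(class_count, 2));
% Per-input vote: the most frequent class among each input's k neighbors.
[junk, eachClass] = max(class_count, [], 1);
eachClass = eachClass';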

function selfdemo
	load iris.dat
	dataNum = size(iris, 1);
	design = iris(1:2:dataNum, :);
	test   = iris(2:2:dataNum, :);
	design_n = size(design, 1);
	testNum   = size(test, 1);
	fprintf('Use of KNN for Iris data:\n');
	fprintf('\tSize of design set (odd-indexed data)= %d\n', design_n);
	fprintf('\tSize of test set (even-indexed data) = %d\n', testNum);
	fprintf('\tRecognition rates as K varies:\n');
	max_k = 15;
	recog = zeros(1, max_k);	% preallocate recognition rates
	for k = 1:max_k,
		computed = feval(mfilename, design, test, k);
		correct_count = sum(test(:, end) == computed);
		recog(k) = correct_count/testNum;
		fprintf('\t%d-NNR ===> 1-%d/%d = %.2f%%.\n', ...
			k, testNum-correct_count, testNum, recog(k)*100);
	end
	plot(1:max_k, recog*100, 'b-o');
	grid on;
	title('Recognition rates of Iris data using K-NNR');
	xlabel('K');
	ylabel('Recognition rates (%)');
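
% ----------------------------------------------------------------------
% The main function relies on two helpers, VECDIST and COUNTELE, that are
% not defined in this file. The two local functions below are a minimal
% sketch of what they appear to do, inferred only from how they are called
% above. The names vecdistSketch and counteleSketch are hypothetical (chosen
% so they do not shadow the real helpers), and the originals may differ in
% details such as output orientation.
% ----------------------------------------------------------------------

function distmat = vecdistSketch(a, b)
% Assumed behavior of VECDIST: a is m-by-d, b is n-by-d, and
% distmat(i,j) is the Euclidean distance between a(i,:) and b(j,:).
	m = size(a, 1);
	n = size(b, 1);
	distmat = zeros(m, n);
	for j = 1:n
		delta = a - repmat(b(j,:), m, 1);	% differences to j-th row of b
		distmat(:, j) = sqrt(sum(delta.^2, 2));
	end

function [sortedElement, elementCount] = counteleSketch(in)
% Assumed behavior of COUNTELE: return the distinct values of IN in
% ascending order, plus how many times each one occurs (column vectors).
	in = sort(in(:));
	if isempty(in), sortedElement = []; elementCount = []; return; end
	lastOfRun = [find(diff(in) ~= 0); length(in)];	% last index of each run
	sortedElement = in(lastOfRun);
	elementCount = diff([0; lastOfRun]);		% run lengths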
