⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 looknn_label.m

📁 matlab 源码 KNN classification
💻 M
字号:
function [misclassify, computed, index, nearestIndex] = looknn(sampledata, k, option)
%LOOKNN Leave-one-out error (misclassification count) of KNN
%
%	Usage:
%	[MISCLASSIFY, COMPUTED, INDEX, NEAREST_INDEX] = LOOKNN(SAMPLEDATA, K, OPTION)
%
%	MISCLASSIFY: No. of misclassified points
%	COMPUTED: Column vector of leave-one-out predicted labels, one per row
%		of SAMPLEDATA
%	INDEX: Row indices (into SAMPLEDATA) of the misclassified points
%	NEAREST_INDEX: For each misclassified point, the row index of its
%		nearest other sample
%	K: The "k" in k-nearest neighbor rule (default: 1)
%	SAMPLEDATA: Sample data set, with the last column being
%		the desired label
%	OPTION = 0 for small data set (vectorized operation based)
%	       = 1 for large data set (for-loop based)
%	  Default: 0 if SAMPLEDATA has at most 1500 rows, 1 otherwise.
%
%	The first two outputs are unchanged from earlier versions; INDEX and
%	NEAREST_INDEX were computed before but never returned.
%
%	Type "looknn" to see a simple example.

%	Roger Jang, 970628, 990613, 011229

if nargin==0, selfdemo; return; end
if nargin<2, k=1; end
if nargin<3,
	if size(sampledata, 1) <= 1500,
		option=0;	% Small data set, use vectorized operation
	else
		option=1;	% Large data set, use for-loop operation
	end
end

featureNum = size(sampledata, 2)-1;
sampleNum = size(sampledata, 1);
features = sampledata(:, 1:featureNum);
desired = sampledata(:, featureNum+1);
computed = zeros(size(desired));

if option == 0,	% vectorized operation; suitable for small dataset
	% Pairwise distances between all samples. vecdist is external; assumed to
	% return a sampleNum-by-sampleNum distance matrix -- TODO confirm.
	distmat = vecdist(features);
	% Exclude each point from its own neighborhood (leave-one-out).
	distmat(1:(sampleNum+1):sampleNum^2) = inf;
	% Column j of nearestSampleIndex lists all samples ordered by distance to sample j.
	[junk, nearestSampleIndex] = sort(distmat, 1);
	% Map labels to 1..classNum so any label values (0, negative, non-integer)
	% can be used as vote-count row indices; classLabel is sorted ascending.
	[classLabel, junk, labelIndex] = unique(desired);
	classNum = length(classLabel);
	% knnmat(i,j) = class index of the i-th nearest neighbor of sample j
	knnmat = reshape(labelIndex(nearestSampleIndex(1:k,:)), k, sampleNum);
	% classCount(c,j) = number of class-c points among sample j's k neighbors;
	% sparse() sums duplicate (row, col) subscripts, which yields the vote count.
	colIndex = repmat(1:sampleNum, k, 1);
	classCount = full(sparse(knnmat(:), colIndex(:), 1, classNum, sampleNum));
	% Majority vote; max returns the first maximum, i.e. the smallest label on ties.
	[junk, winner] = max(classCount, [], 1);
	computed = classLabel(winner(:));
else		% for-loop version; suitable for large dataset
	nearestSampleIndex = zeros(1, sampleNum);
	for i = 1:sampleNum,
		% Remove the i-th sample, then classify it with the remaining ones.
		looData = sampledata;
		looData(i, :) = [];
		% knn is external; its third output is assumed to list neighbor row
		% indices into looData, nearest first -- TODO confirm.
		[computed(i), junk, tmp] = knn(looData, sampledata(i, :), k);
		% Convert the looData row index back to a sampledata row index:
		% rows at or after i were shifted up by one when row i was removed.
		nearestSampleIndex(i) = tmp(1);
		if nearestSampleIndex(i)>=i,
			nearestSampleIndex(i)=nearestSampleIndex(i)+1;
		end
	end
end
index = find(desired~=computed);
misclassify = length(index);
% Row 1 of nearestSampleIndex holds each sample's nearest other sample.
nearestIndex = nearestSampleIndex(1, index)';

% ====== Self demo ======
function selfdemo
% SELFDEMO Plot four noisy 2-D clusters (labels 1..4) inside the unit square
% and mark the points misclassified by leave-one-out 1-NN with a white "x".
	% Create the data set: one Gaussian cluster per center, class c = row c
	pointNum = 50;		% points generated per cluster
	spread = 8;			% noise standard deviation is 1/spread
	centers = [0.125 0.25; 0.625 0.25; 0.375 0.75; 0.875 0.75];
	classNum = size(centers, 1);
	sampledata = zeros(0, 3);
	for c = 1:classNum,
		clusterData = randn(pointNum, 2)/spread + ones(pointNum, 1)*centers(c, :);
		sampledata = [sampledata; clusterData c*ones(pointNum, 1)];
	end
	% Keep only points strictly inside the unit square
	coords = sampledata(:, 1:2);
	inside = (min(coords, [], 2) > 0) & (max(coords, [], 2) < 1);
	sampledata = sampledata(inside, :);
	% Plot each class with its own color
	featureNum = size(sampledata, 2)-1;
	label = sampledata(:, featureNum+1);
	colors = {'y', 'c', 'm', 'g'};
	figure;
	colordef(gcf, 'black');
	for c = 1:classNum,
		members = find(label==c);
		line(sampledata(members, 1), sampledata(members, 2), ...
			'linestyle', 'none', 'marker', 'o', 'color', colors{c});
	end
	set(gca, 'box', 'on');
	axis image;
	% Find leave-one-out error points with the 1-NN rule
	k = 1;
	option = 0;	% small dataset
	% The second output is the vector of predicted labels (not indices);
	% misclassified points are those whose prediction differs from the label.
	[junk, computed] = feval(mfilename, sampledata, k, option);
	index = find(computed ~= label);
	% Display these points
	line(sampledata(index, 1), sampledata(index, 2), ...
		'linestyle', 'none', 'marker', 'x', 'color', 'w');
	titleString = sprintf('There are %d leave-one-out error points denoted by "x".', length(index));
	title(titleString);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -