% File: looknn_label.m
% (Web-viewer residue, not part of the program: "Font size:")
function [misclassify, computed, index, nearestIndex] = looknn(sampledata, k, option)
%LOOKNN Leave-one-out error (misclassification count) of KNN
%
%	Usage:
%	[MISCLASSIFY, COMPUTED, INDEX, NEAREST_SAMPLE_INDEX] = LOOKNN(SAMPLEDATA, K, OPTION)
%
%	MISCLASSIFY: No. of misclassified points
%	COMPUTED: Column vector of predicted labels, one per sample
%	INDEX: (optional) Row indices of the misclassified points
%	NEAREST_SAMPLE_INDEX: (optional) Index of the nearest sample of each
%		misclassified point
%	SAMPLEDATA: Sample data set, one sample per row, with the last column
%		being the desired label
%	K: The "k" in k-nearest neighbor rule (default 1)
%	OPTION = 0 for small data set (vectorized operation based)
%	       = 1 for large data set (for-loop based)
%		(default: 0 when there are at most 1500 samples, 1 otherwise)
%
%	In the vectorized mode the labels may be arbitrary numeric values; ties
%	in the vote are resolved in favor of the smallest label.
%
%	Type "looknn" to see a simple example.

%	Roger Jang, 970628, 990613, 011229

if nargin==0, selfdemo; return; end
if nargin<2, k=1; end
if nargin<3,
	if size(sampledata, 1) <= 1500,
		option = 0;	% Small data set, use vectorized operation
	else
		option = 1;	% Large data set, use for-loop operation
	end
end

featureNum = size(sampledata, 2)-1;
sampleNum = size(sampledata, 1);
input = sampledata(:, 1:featureNum);
desired = sampledata(:, featureNum+1);
computed = zeros(size(desired));

if option == 0,	% vectorized operation; suitable for small dataset
	% Each left-out sample only has sampleNum-1 candidates to vote. A larger
	% k would either index out of range or let the sample vote for itself.
	if k > sampleNum-1,
		error(sprintf('looknn: k = %d exceeds the %d samples left after leave-one-out.', k, sampleNum-1));
	end
	% Map every label to a class index 1..classNum so that labels do not have
	% to be positive integers (e.g., 0/1 or -1/+1 labels are fine).
	[classLabel, junk, desiredClass] = unique(desired);
	classNum = length(classLabel);
	% NOTE(review): vecdist is an external helper; it is assumed to return the
	% sampleNum-by-sampleNum pairwise distance matrix of the rows of "input".
	distmat = vecdist(input);
	distmat(1:(sampleNum+1):sampleNum^2) = inf;	% Set diagonal elements to inf (exclude self)
	% Column j of nearestSampleIndex lists the samples sorted by distance to sample j
	[junk, nearestSampleIndex] = sort(distmat, 1);
	% knnmat(i,j) = class index of i-th nearest point of j-th input vector
	knnmat = reshape(desiredClass(nearestSampleIndex(1:k,:)), k, sampleNum);
	% classCount(i,j) = count of class-i points within j-th input vector's neighborhood
	classCount = zeros(classNum, sampleNum);
	for c = 1:classNum,
		classCount(c, :) = sum(knnmat==c, 1);
	end
	% max returns the first maximum, i.e., ties go to the smallest label
	[junk, winner] = max(classCount, [], 1);
	computed = reshape(classLabel(winner), sampleNum, 1);	% class index --> label
else	% for-loop version; suitable for large dataset
	nearestSampleIndex = zeros(1, sampleNum);
	for i = 1:sampleNum,
		looData = sampledata;
		looData(i, :) = [];	% Leave the i-th sample out
		% NOTE(review): knn is an external helper; its first output is used as
		% the predicted label and the first entry of its third output as the
		% index (into looData) of the nearest sample.
		[computed(i), junk, tmp] = knn(looData, sampledata(i, :), k);
		nearestSampleIndex(i) = tmp(1);
		% Convert the index into looData back to an index into sampledata
		if nearestSampleIndex(i)>=i,
			nearestSampleIndex(i) = nearestSampleIndex(i)+1;
		end
	end
end

index = find(desired~=computed);
misclassify = length(index);
% Row 1 of nearestSampleIndex holds the nearest neighbor in both modes
nearestIndex = nearestSampleIndex(1, index)';
% ====== Self demo ======
function selfdemo
%SELFDEMO Demo of leave-one-out KNN on four 2-D Gaussian clusters.
%	Generates 50 random points around each of four centers, discards the
%	points outside the open unit square, plots each class in its own color,
%	and marks the leave-one-out misclassified points with a white "x".

% ====== Create a data set sampledata
data_n = 50;
k = 8;		% Divisor that shrinks the spread of randn
center = [0.125 0.25; 0.625 0.25; 0.375 0.75; 0.875 0.75];
classColor = 'ycmg';	% Color of class 1, 2, 3 and 4
classNum = size(center, 1);
data = [];
label = [];
for c = 1:classNum,
	data = [data; randn(data_n, 2)/k + ones(data_n, 1)*center(c, :)];
	label = [label; c*ones(data_n, 1)];
end
% Keep only the points strictly inside the unit square
inside = all(data>0 & data<1, 2);
sampledata = [data(inside, :), label(inside)];
label = sampledata(:, end);

% ====== Plot the data set in a new figure
figure;
colordef(gcf, 'black');
for c = 1:classNum,
	member = find(label==c);
	if ~isempty(member),
		% One marker-only line object per class
		line(sampledata(member, 1), sampledata(member, 2), ...
			'linestyle', 'none', 'marker', 'o', 'color', classColor(c));
	end
end
set(gca, 'box', 'on');
axis image;

% ====== Find leave-one-out error points
k = 1;
option = 0;	% small dataset
% The second output is the vector of predicted labels, not the indices of
% the misclassified points, so the indices are derived by comparison.
[misclassify, computed] = feval(mfilename, sampledata, k, option);
index = find(computed(:)~=label(:));

% ====== Display these points
if ~isempty(index),
	line(sampledata(index, 1), sampledata(index, 2), ...
		'linestyle', 'none', 'marker', 'x', 'color', 'w');
end
titleString = sprintf('There are %d leave-one-out error points denoted by "x".', length(index));
title(titleString);
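% ------ Web-viewer residue (not part of the program): keyboard shortcuts ------
% Copy code:           Ctrl + C
% Search code:         Ctrl + F
% Full-screen mode:    F11
% Toggle theme:        Ctrl + Shift + D
% Show shortcuts:      ?
% Increase font size:  Ctrl + =
% Decrease font size:  Ctrl + -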