% nearest_neighbor_vc.m
% Learns classifier and classifies test set using the k-NN rule.
%
% Usage:
% [trainError, testError, estTrainLabels, estTestLabels] = ...
%     Nearest_Neighbor_VC(trainFeatures, trainLabels, Knn, testFeatures, testLabels)
%
% Inputs:
%   trainFeatures - the training set vectors, one vector per column
%   trainLabels   - the labels of the above
%   Knn           - number of nearest neighbors used in the majority vote
%   testFeatures  - test set, one column per vector
%   testLabels    - labels for the test set
%
% Outputs:
%   trainError     - error rate on the training set (one entry per class + total)
%   testError      - error rate on the test set (one entry per class + total)
%   estTrainLabels - predicted labels for the training set
%   estTestLabels  - predicted labels for the test set
function [trainError, testError, estTrainLabels, estTestLabels] = ...
Nearest_Neighbor_VC(trainFeatures, trainLabels, Knn, testFeatures, testLabels)
% K-nearest-neighbor classifier: labels the training set and the test set
% by majority vote among the Knn nearest training samples (squared
% Euclidean distance), then computes per-class and total error rates
% via the project helper computeError.
%
% Inputs:
%   trainFeatures - training vectors, one column per sample (Dim x Ntrain)
%   trainLabels   - labels of the training samples
%   Knn           - number of nearest neighbors to vote over
%   testFeatures  - test vectors, one column per sample
%   testLabels    - labels of the test samples
% Outputs:
%   trainError     - training-set error rate (per class + total), from computeError
%   testError      - test-set error rate (per class + total), from computeError
%   estTrainLabels - predicted labels for the training set
%   estTestLabels  - predicted labels for the test set
%
% Date     Name               Change
% 03/18/02 Vittorio Castelli  Now works for any number of features,
%                             any number of classes. Added documentation.
% (review)                    Fixed function-name typo (Nearest_Neigbor_VC)
%                             and the bug where the declared outputs
%                             estTrainLabels/estTestLabels were never
%                             assigned (locals hatTrainLabels/hatTestLabels
%                             were filled instead); factored the duplicated
%                             distance/vote loop into a subfunction.
hm = findobj('Tag', 'Messages');
fprintf('%d-Nearest-Neighbor: Computing Training Set Error Rate\n', Knn);
if ~isempty(hm)
    set(hm, 'String', sprintf('%d-Nearest-Neighbor: Computing Training Set Error Rates', Knn));
    refresh;
end
% All classes appearing in either label vector.
[Nclasses, classes] = find_classes([trainLabels(:); testLabels(:)]);
% Classify the training set against itself.  NOTE: each sample is its own
% nearest neighbor, so for Knn = 1 the training error is trivially optimistic.
Ntrain = size(trainFeatures, 2);
estTrainLabels = zeros(size(trainLabels));
for sam = 1:Ntrain,
    estTrainLabels(sam) = vote_knn(trainFeatures(:, sam), ...
        trainFeatures, trainLabels, Knn, Nclasses, classes);
end
% ================================================================
fprintf('%d-Nearest-Neighbor: Computing Test Set Error Rate\n', Knn);
if ~isempty(hm)
    set(hm, 'String', sprintf('%d-Nearest-Neighbor: Computing Test Set Error Rates', Knn));
    refresh;
end
% Classify the test set.
Ntest = size(testFeatures, 2);
estTestLabels = zeros(size(testLabels));
for sam = 1:Ntest,
    estTestLabels(sam) = vote_knn(testFeatures(:, sam), ...
        trainFeatures, trainLabels, Knn, Nclasses, classes);
end
trainError = computeError(classes, trainLabels, estTrainLabels);
testError  = computeError(classes, testLabels,  estTestLabels);

% ----------------------------------------------------------------
function label = vote_knn(x, trainFeatures, trainLabels, Knn, Nclasses, classes)
% Majority-vote label of the Knn training samples nearest to column
% vector x (squared Euclidean distance).  Ties are broken in favor of
% the class appearing first in `classes`, matching the original code.
Ntrain = size(trainFeatures, 2);
% Squared distances from x to every training sample, sorted ascending.
diff = x * ones(1, Ntrain) - trainFeatures;
[sortedD, order] = sort(sum(diff .* diff));
% Small-sample case: fewer training samples than neighbors requested.
if length(trainLabels) <= Knn
    k_nearest = trainLabels;
else
    k_nearest = trainLabels(order(1:Knn));
end
% Find the class with the largest number of neighbors.
bestClass = 0;
maxNN = 0;
for cl = 1:Nclasses,
    nrNN = sum(k_nearest == classes(cl));
    if nrNN > maxNN
        maxNN = nrNN;
        bestClass = cl;
    end
end
label = classes(bestClass);