knn_classify.m
function [Y_compute, Y_prob] = kNN_classify(para, X_train, Y_train, X_test, Y_test, num_class)
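% kNN_classify  k-nearest-neighbour classification of X_test against X_train.
%
%   [Y_compute, Y_prob] = kNN_classify(para, X_train, Y_train, X_test, Y_test, num_class)
%
%   para       option string parsed by ParseParameter, e.g. '-k 5 -d 0'
%              (-k: number of neighbours, -d: distance type; see vecdist below)
%   X_train    num_train x num_feature training matrix, one sample per row
%   Y_train    num_train x 1 training labels
%   X_test     num_test x num_feature test matrix
%   Y_test     used only to pre-size the outputs; num_class is currently unused
%
%   Y_compute  num_test x 1 predicted labels (majority vote of the k neighbours)
%   Y_prob     num_test x 1 distance to the nearest neighbour (not a probability)
%
%   Relies on the external helpers ParseParameter and GetClassSet, which are
%   not defined in this file and are assumed to be on the MATLAB path.
%
%   Hypothetical usage (a sketch, assuming those helpers are available and
%   that para takes the form shown above):
%       [yhat, d] = kNN_classify('-k 3 -d 0', X_train, Y_train, X_test, Y_test, ...
%                                numel(unique(Y_train)));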

% Pre-allocate the outputs so they are defined even if we return early.
Y_compute = zeros(size(Y_test));
Y_prob = zeros(size(Y_test));

if isempty(X_train)
    fprintf('Error: The training set is empty!\n');
    return;
end
% Distinct class labels present in the training set.
class_set = GetClassSet(Y_train);

% Parse options: -k = number of neighbours (default 1),
%                -d = distance type (default 0; see vecdist below).
p = str2num(char(ParseParameter(para, {'-k'; '-d'}, {'1'; '0'})));
k = p(1);
disttype = p(2);

num_test = size(X_test, 1);
num_train = size(X_train, 1);
num_feature = size(X_test, 2);

Y_compute = zeros(num_test, 1);
Y_prob = zeros(num_test, 1);
fprintf('Iter: %4d', 0);

% Row-wise L2 norms, pre-computed once; used by the cosine distance (disttype 2).
X_train_sqr = sqrt(sum(X_train .* X_train, 2));
X_test_sqr = sqrt(sum(X_test .* X_test, 2));
for i = 1:num_test
    % Distances from the i-th test sample to every training sample
    % (vectorised vecdist call; replaces an earlier per-row loop over X_train).
    sumDistance = vecdist(X_train, X_test(i, :), disttype, X_train_sqr, X_test_sqr(i));

    if rem(i, 100) == 0
        fprintf('\b\b\b\b\b\b\b\b\b\bIter: %4d', i);
    end

    % Majority vote among the k nearest neighbours.
    [sortDis, Index] = sort(sumDistance);
    n = hist(Y_train(Index(1:k)), class_set);
    [junk, index] = max(n);
    Y_compute(i) = class_set(index);

    % Y_prob holds the distance to the single nearest neighbour, not a probability.
    Y_prob(i) = sortDis(1);
end
fprintf('\n');
% Y_prob = Y_compute;
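
% -------------------------------------------------------------------------
% vecdist: distance from one test row vector to every row of the training
% matrix, selected by disttype:
%   0 - squared Euclidean:   sum((a - b).^2)
%   1 - chi-squared:         sum((a - b).^2 ./ (a + b)), with the denominator
%                            nudged by 1e-8 where it would be zero
%   2 - negative cosine similarity:  -(a . b) / (|a| |b|), using the
%       pre-computed norms passed in as X_train_sqr_vec and X_test_sqr
% -------------------------------------------------------------------------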
function dist = vecdist(X_train_vec, X_test_vec, disttype, X_train_sqr_vec, X_test_sqr)
switch (disttype)
    case 0
        % Squared Euclidean distance.
        X_diff = X_train_vec - repmat(X_test_vec, size(X_train_vec, 1), 1);
        dist = sum(X_diff .* X_diff, 2);
    case 1
        % Chi-squared distance: sum((a - b).^2 ./ (a + b)).
        plusdist = X_train_vec + repmat(X_test_vec, size(X_train_vec, 1), 1);
        plusdist = plusdist + (plusdist == 0) * 1e-8;  % guard against division by zero
        minusdist = X_train_vec - repmat(X_test_vec, size(X_train_vec, 1), 1);
        dist = sum(minusdist .* minusdist ./ plusdist, 2);
    case 2
        % Cosine similarity, negated so that smaller values still mean closer.
        dist = sum(X_train_vec .* repmat(X_test_vec, size(X_train_vec, 1), 1), 2);
        dist = (dist ./ X_train_sqr_vec) / X_test_sqr;
        dist = -dist;
end
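
% Sanity check for disttype 0 (a sketch, assuming the Statistics and Machine
% Learning Toolbox is installed for pdist2): the squared Euclidean distances
% returned above should match
%
%   d_ref = pdist2(X_train, X_test(i, :)).^2;
%
% up to floating-point rounding, for any test row i.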