% knntest.m
%Test k-nearest neighbours
% Build a small random binary problem: labels are the sign of a random linear
% function of the features plus uniform noise in [-noise/2, noise/2]. The
% features pass through the project helpers centerData and normalise first.
% The data are then split with splitData (presumably 2/3 of the examples go
% to training -- confirm against splitData).
numExamples = 15;
numFeatures = 10;
noise = 1;
tol = 10^-5;

% Random draws happen in this order: features, weight vector, label noise.
X = normalise(centerData(rand(numExamples, numFeatures)));
q = rand(numFeatures, 1);
labelNoise = noise*(rand(numExamples, 1) - 0.5);
y = sign(X*q + labelNoise);

[trainX, trainY, testX, testY] = splitData(X, y, 2/3);
numTrainExamples = size(trainX, 1);
numTestExamples = size(testX, 1);
% Consistency check: the training set is passed as the test set, so the
% outputs reported for the training set and for the "test" set should agree.
params.k = 3;
[trainInfo, testInfo, classifierInfo] = primalNearestNeighbours(trainX, trainY, trainX, params);
if norm(trainInfo.predictedY - testInfo.predictedY) > tol
    error('Incorrect predictions on training/test set');
end
% Compare the training rankings against the test rankings. Comparing
% trainInfo.rankings with itself would always give zero and never fail.
if norm(trainInfo.rankings - testInfo.rankings) > tol
    error('Incorrect rankings on training/test set');
end
% 1-NN evaluated on its own training data should reproduce trainY. Each
% point is its own nearest neighbour, assuming no duplicate rows in trainX.
params.k = 1;
[trainInfo, testInfo, classifierInfo] = ...
    primalNearestNeighbours(trainX, trainY, trainX, params);
oneNNResidual = norm(trainInfo.predictedY - trainY);
if oneNNResidual > tol
    error('1-NN predicting incorrectly');
end
%Check if it catches error when try to use 2 neighbours (an even k)
% primalNearestNeighbours is expected to raise an error for even k. Track
% whether it did, so the test fails when no error is raised instead of
% passing silently.
params.k = 2;
evenKRejected = false;
try
    [trainInfo, testInfo, classifierInfo] = primalNearestNeighbours(trainX, trainY, trainX, params);
catch
    evenKRejected = true;
    disp('Correctly identified problem of using even number of labels');
end
if ~evenKRejected
    error('Even value of k was not rejected');
end
%Now test if the predictions are correct
% Recompute k-NN predictions and rankings for the test set by brute force and
% compare them with primalNearestNeighbours. The reference uses params.k
% rather than a hard-coded 3, so it stays in sync if params.k changes.
params.k = 3;
[trainInfo, testInfo, classifierInfo] = primalNearestNeighbours(trainX, trainY, testX, params);
k = params.k;
predictedTestY = zeros(numTestExamples, 1);
rankings = zeros(numTestExamples, 1);
for i = 1:numTestExamples
    % Squared Euclidean distance from test example i to every training example.
    distances = sum(bsxfun(@minus, trainX, testX(i, :)).^2, 2);
    [distances, indices] = sort(distances);
    nearestY = trainY(indices(1:k));
    % Majority vote. With +/-1 labels and odd k the sum is never zero.
    predictedTestY(i) = sign(sum(nearestY));
    % Ranking: neighbour labels weighted by exp(-squared distance).
    rankings(i) = sum(exp(-distances(1:k)) .* nearestY);
end
if norm(predictedTestY - testInfo.predictedY) > tol
    error('Test labels predicted incorrectly');
end
if norm(rankings - testInfo.rankings) > tol
    error('Test rankings predicted incorrectly');
end
%Test if the method can work with 0/1 labels too
% Map the labels from {-1, +1} to {0, 1}; the expected predictions map the
% same way.
toZeroOne = @(v) (v + 1)/2;
trainY = toZeroOne(trainY);
predictedTestY = toZeroOne(predictedTestY);
[trainInfo, testInfo, classifierInfo] = ...
    primalNearestNeighbours(trainX, trainY, testX, params);
zeroOneResidual = norm(predictedTestY - testInfo.predictedY);
if zeroOneResidual > tol
    error('Test labels predicted incorrectly for 0/1 labels');
end
%Test that it can pick up the correct examples in the multi label case
% Fresh random data with numLabels independent sign-valued label columns.
% Random draws happen in this order: training features, training labels,
% test features.
numLabels = 3;
trainX2 = rand(numTrainExamples, numFeatures);
trainY2 = sign(rand(numTrainExamples, numLabels) - 0.5);
testX2 = rand(numTestExamples, numFeatures);

% With k = 1 on the training data, every label row should be reproduced.
params.k = 1;
[trainInfo, testInfo, classifierInfo] = ...
    primalNearestNeighbours(trainX2, trainY2, testX2, params);
multiTrainResidual = norm(trainInfo.predictedY - trainY2);
if multiTrainResidual > tol
    error('Training labels predicted incorrectly for multi label case');
end
% Brute-force 1-NN reference for the multi-label test set. Each test row takes
% the full label row of its nearest training example. The ranking is that
% label row weighted by exp(-squared distance), summed into a single scalar.
% NOTE(review): this collapses all labels into one ranking per test example;
% confirm that testInfo.rankings has this shape in the multi-label case.
predictedTestY = zeros(numTestExamples, numLabels);
rankings = zeros(numTestExamples, 1);
distances = zeros(numTrainExamples, 1);
for testIdx = 1:numTestExamples
    for trainIdx = 1:numTrainExamples
        distances(trainIdx) = norm(trainX2(trainIdx, :) - testX2(testIdx, :))^2;
    end
    % min returns the first index among ties, as the first element of sort does.
    [nearestDist, nearestIdx] = min(distances);
    nearestLabels = trainY2(nearestIdx, :);
    predictedTestY(testIdx, :) = nearestLabels;
    rankings(testIdx) = sum(exp(-nearestDist) .* nearestLabels);
end

multiLabelResidual = norm(predictedTestY - testInfo.predictedY);
if multiLabelResidual > tol
    error('Test labels predicted incorrectly for multi label case');
end
multiRankingResidual = norm(rankings - testInfo.rankings);
if multiRankingResidual > tol
    error('Test rankings predicted incorrectly for multi label case');
end
%Test if the method can work with 0/1 labels too
% Multi-label version: shift every label column from {-1, +1} to {0, 1} and
% expect the 1-NN predictions to shift the same way.
toZeroOne = @(v) (v + 1)/2;
trainY2 = toZeroOne(trainY2);
predictedTestY = toZeroOne(predictedTestY);
[trainInfo, testInfo, classifierInfo] = ...
    primalNearestNeighbours(trainX2, trainY2, testX2, params);
multiZeroOneResidual = norm(predictedTestY - testInfo.predictedY);
if multiZeroOneResidual > tol
    error('Test labels predicted incorrectly for multi label case for 0/1 labels');
end
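% (Residue of the code-viewer web page this file was copied from -- keyboard
%  shortcuts: copy code Ctrl + C; search code Ctrl + F; full screen F11;
%  toggle theme Ctrl + Shift + D; show shortcuts ?; increase font size
%  Ctrl + =; decrease font size Ctrl + -)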