
knntest.m.svn-base
a function inside machine learning
%Test k-nearest neighbours 
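% This script checks primalNearestNeighbours on randomly generated data:
% it compares its predictions and rankings against a brute-force reference
% implementation, verifies that 1-NN reproduces the training labels, that
% an even value of k is rejected, and that +1/-1 labels, 0/1 labels and
% multi-label targets are all handled.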

numExamples = 15; 
numFeatures = 10; 

noise = 1; 
tol = 10^-5; 

X = rand(numExamples, numFeatures); 
X = centerData(X); 
X = normalise(X); 
q = rand(numFeatures, 1); 
y = sign(X*q + noise*(rand(numExamples, 1)-0.5));

[trainX, trainY, testX, testY] = splitData(X, y, 2/3); 
numTrainExamples = size(trainX, 1); 
numTestExamples = size(testX, 1); 

params.k = 3; 
[trainInfo, testInfo, classifierInfo] = primalNearestNeighbours(trainX, trainY, trainX, params); 

%The training set was also passed as the test set above, so the two sets 
%of predictions and rankings must agree. 
if norm(trainInfo.predictedY - testInfo.predictedY) > tol 
    error('Incorrect predictions on training/test set'); 
end 

if norm(trainInfo.rankings - testInfo.rankings) > tol 
    error('Incorrect rankings on training/test set'); 
end 

params.k = 1; 
[trainInfo, testInfo, classifierInfo] = primalNearestNeighbours(trainX, trainY, trainX, params); 

%With k = 1 each training example is its own nearest neighbour, so the 
%predictions must equal the training labels. 
if norm(trainInfo.predictedY - trainY) > tol 
    error('1-NN predicting incorrectly'); 
end 

%Check that an error is raised when k is even (an even number of 
%neighbours can produce tied votes with binary labels) 
params.k = 2; 

errorRaised = false;
try
   [trainInfo, testInfo, classifierInfo] = primalNearestNeighbours(trainX, trainY, trainX, params); 
catch
   errorRaised = true;
   disp('Correctly identified problem of using an even value of k'); 
end
if ~errorRaised
    error('No error raised for an even value of k');
end

%Now test that the predictions on a held-out test set are correct 
params.k = 3; 
[trainInfo, testInfo, classifierInfo] = primalNearestNeighbours(trainX, trainY, testX, params); 

predictedTestY = zeros(numTestExamples, 1); 
rankings = zeros(numTestExamples, 1); 
distances = zeros(numTrainExamples, 1); 
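%Brute-force reference: for each test example, compute squared Euclidean 
%distances to every training example, predict the sign of the summed 
%labels of the k nearest neighbours, and compute a ranking that weights 
%each neighbour's label by exp(-squared distance). 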

for i=1:numTestExamples 
    for j=1:numTrainExamples 
        distances(j) = norm(trainX(j, :) - testX(i, :))^2; 
    end 
    
    [distances, indices] = sort(distances);
    predictedTestY(i) = sign(sum(trainY(indices(1:params.k)))); 
    rankings(i) = sum(exp(-distances(1:params.k)) .* trainY(indices(1:params.k))); 
end 

if norm(predictedTestY - testInfo.predictedY) > tol 
    error('Test labels predicted incorrectly'); 
end 

if norm(rankings - testInfo.rankings) > tol 
    error('Test rankings predicted incorrectly'); 
end 

%Test if the method can work with 0/1 labels too 
trainY = (trainY+1)/2;
predictedTestY = (predictedTestY+1)/2;
[trainInfo, testInfo, classifierInfo] = primalNearestNeighbours(trainX, trainY, testX, params); 

if norm(predictedTestY - testInfo.predictedY) > tol 
    error('Test labels predicted incorrectly for 0/1 labels'); 
end 

%Test that it picks out the correct examples in the multi-label case 
numLabels = 3; 
trainX2 = rand(numTrainExamples, numFeatures); 
trainY2 = sign(rand(numTrainExamples, numLabels)-0.5); 
testX2 = rand(numTestExamples, numFeatures);  
params.k = 1;

[trainInfo, testInfo, classifierInfo] = primalNearestNeighbours(trainX2, trainY2, testX2, params); 

if norm(trainInfo.predictedY - trainY2) > tol 
    error('Training labels predicted incorrectly for multi label case'); 
end 

predictedTestY = zeros(numTestExamples, numLabels); 
rankings = zeros(numTestExamples, 1); 
distances = zeros(numTrainExamples, 1); 
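%Brute-force reference for the multi-label case with k = 1: the prediction 
%is the label vector of the single nearest training example, and the 
%ranking is exp(-squared distance) times that label vector, summed over 
%the labels. 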

for i=1:numTestExamples 
    for j=1:numTrainExamples 
        distances(j) = norm(trainX2(j, :) - testX2(i, :))^2; 
    end 
    
    [distances, indices] = sort(distances);
    predictedTestY(i, :) = trainY2(indices(1), :); 
    rankings(i) = sum(exp(-distances(1)) .* trainY2(indices(1), :)); 
end

if norm(predictedTestY - testInfo.predictedY) > tol 
    error('Test labels predicted incorrectly for multi label case'); 
end 

if norm(rankings - testInfo.rankings) > tol 
    error('Test rankings predicted incorrectly for multi label case'); 
end 


%Test that the multi-label case also works with 0/1 labels 
trainY2 = (trainY2+1)/2;
predictedTestY = (predictedTestY+1)/2;
[trainInfo, testInfo, classifierInfo] = primalNearestNeighbours(trainX2, trainY2, testX2, params); 

if norm(predictedTestY - testInfo.predictedY) > tol 
    error('Test labels predicted incorrectly for multi label case for 0/1 labels'); 
end 
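
% ---------------------------------------------------------------------
% The local functions below are NOT part of the original test: they are
% minimal sketches of the otherwise undefined helpers this script calls
% (centerData, normalise, splitData), added so the listing can be read
% and run in isolation (local functions in scripts need MATLAB R2016b+).
% The real project may define these helpers differently; in particular,
% "normalise" is assumed here to scale each example to unit norm.
% primalNearestNeighbours, the function under test, is deliberately not
% sketched.
% ---------------------------------------------------------------------

function X = centerData(X)
    % Subtract the per-feature mean so every column has zero mean.
    X = X - repmat(mean(X, 1), size(X, 1), 1);
end

function X = normalise(X)
    % Scale each example (row) to unit Euclidean norm.
    rowNorms = sqrt(sum(X.^2, 2));
    rowNorms(rowNorms == 0) = 1;   % guard against all-zero rows
    X = X ./ repmat(rowNorms, 1, size(X, 2));
end

function [trainX, trainY, testX, testY] = splitData(X, y, trainFraction)
    % Put the first round(trainFraction*n) examples in the training set
    % and the remaining examples in the test set.
    n = size(X, 1);
    numTrain = round(trainFraction * n);
    trainX = X(1:numTrain, :);
    trainY = y(1:numTrain, :);
    testX  = X(numTrain+1:end, :);
    testY  = y(numTrain+1:end, :);
end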
