primalnearestneighbourspredict.m.svn-base

From "a function inside machine learning" · SVN-BASE code · 35 lines in total

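function [testInfo, predictionInfo] = primalNearestNeighboursPredict(trainX, trainY, testX, modelInfo, params)
%PRIMALNEARESTNEIGHBOURSPREDICT Predict using the primal k-nearest-neighbours (kNN) algorithm.
%   For each row of testX, finds the params.k rows of trainX with the smallest
%   squared Euclidean distance and predicts each label column by majority vote.
%
%   trainX, testX - examples as rows (numTrainExamples/numTestExamples x numFeatures)
%   trainY        - numTrainExamples x numLabels training labels
%   modelInfo     - unused (kNN needs no trained model)
%   params.k      - number of neighbours, at most numTrainExamples
%
%   testInfo.predictedY        - numTestExamples x numLabels predicted labels
%   testInfo.rankings          - numTestExamples x 1 scores: sum over neighbours of
%                                exp(-distance) times the neighbour's summed labels
%   predictionInfo.predictTime - prediction time in seconds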
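%
%   Example (a sketch on made-up toy data; assumes class labels in {-1, +1}):
%       trainX = [0 0; 0 1; 5 5; 5 6];
%       trainY = [-1; -1; 1; 1];
%       testX  = [0 0.5; 5 5.5];
%       params.k = 3;
%       testInfo = primalNearestNeighboursPredict(trainX, trainY, testX, [], params);
%       testInfo.predictedY   % [-1; 1]: two of each point's three neighbours agree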

numTrainExamples = size(trainX, 1); 
numTestExamples = size(testX, 1); 
numLabels = size(trainY, 2); 

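k = params.k;
if k > numTrainExamples
    error('params.k (%d) exceeds the number of training examples (%d).', k, numTrainExamples);
end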

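tic;
% Squared Euclidean distances via ||x - z||^2 = ||x||^2 + ||z||^2 - 2*x'*z.
% distanceMatrix(j, i) is the distance from training example j to test example i.
e = ones(numTrainExamples, 1);
e2 = ones(numTestExamples, 1);
diagTrainK = sum(trainX.^2, 2);
diagTestK = sum(testX.^2, 2);
distanceMatrix = diagTrainK*e2' + e*diagTestK' - 2*(trainX*testX');
% Floating-point rounding can leave tiny negative values; clamp them to zero.
distanceMatrix = max(distanceMatrix, 0);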

distances = zeros(numTestExamples, k);
indices = zeros(numTestExamples, k);
neighbours = zeros(numTestExamples, numLabels, k);

% Select the k nearest training examples one at a time: take the column-wise
% minimum, then mask the chosen entries with Inf so they are not picked again.
for i = 1:k
    [distances(:, i), indices(:, i)] = min(distanceMatrix, [], 1);

    distanceMatrix(sub2ind(size(distanceMatrix), indices(:, i), (1:numTestExamples)')) = Inf;
    neighbours(:, :, i) = trainY(indices(:, i), :);
end

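% Majority vote over the k neighbours, separately for each label column.
testInfo.predictedY = mode(neighbours, 3);
% reshape (not squeeze) keeps the numTestExamples x k shape when numTestExamples == 1.
neighbourLabelSums = reshape(sum(neighbours, 2), numTestExamples, k);
testInfo.rankings = sum(exp(-distances) .* neighbourLabelSums, 2);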
predictTime = toc; 
fprintf('Prediction time: %f s\n', predictTime); 

predictionInfo = struct;
predictionInfo.predictTime = predictTime; 