primalnearestneighbourspredict.m.svn-base

A function from a machine learning code collection.
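function [testInfo, predictionInfo] = primalNearestNeighboursPredict(trainX, trainY, testX, modelInfo, params)
%PRIMALNEARESTNEIGHBOURSPREDICT Predict with the primal k-nearest-neighbours (kNN) algorithm.
%   trainX    : numTrainExamples x numFeatures training inputs
%   trainY    : numTrainExamples x numLabels training labels
%   testX     : numTestExamples x numFeatures test inputs
%   modelInfo : not used by this function (kNN needs no trained model)
%   params.k  : number of neighbours
%
%   testInfo.predictedY        : per-label majority vote (mode) of the k neighbours' labels
%   testInfo.rankings          : neighbour labels weighted by exp(-squared distance), summed
%   predictionInfo.predictTime : prediction time in seconds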

numTrainExamples = size(trainX, 1); 
numTestExamples = size(testX, 1); 
numLabels = size(trainY, 2); 

k = params.k; 

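tic;
% Squared Euclidean distances between every training and every test example,
% using ||x - y||^2 = ||x||^2 + ||y||^2 - 2*x*y'.
% The result is numTrainExamples x numTestExamples.
e = ones(numTrainExamples, 1);
e2 = ones(numTestExamples, 1);
diagTrainK = sum(trainX.^2, 2);   % ||x||^2 per training row (diagonal of the linear kernel)
diagTestK = sum(testX.^2, 2);     % ||y||^2 per test row
distanceMatrix = diagTrainK*e2' + e*diagTestK' - 2*(trainX*testX');
% Rounding can make near-zero squared distances slightly negative; clamp them at 0.
distanceMatrix = max(distanceMatrix, 0);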
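% Worked check of the expansion used for distanceMatrix,
% ||x - y||^2 = ||x||^2 + ||y||^2 - 2*x*y'. The vectors x = [1 2] and y = [3 1]
% are hypothetical toy values for illustration only, not from the original source:
%   5 + 10 - 2*5 = 5, which matches sum((x - y).^2) = (-2)^2 + 1^2 = 5.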

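% Select the k nearest training examples for each test example (one column each).
% After each pick, the chosen entry is set to Inf. A finite value such as the
% largest distance could tie with a remaining candidate, and min() would then
% re-select the same neighbour, because it returns the first index on ties.
distances = zeros(numTestExamples, k);
indices = zeros(numTestExamples, k);
neighbours = zeros(numTestExamples, numLabels, k);

for i = 1:k
    % Nearest remaining training example for every test example (column-wise min).
    [distances(:, i), indices(:, i)] = min(distanceMatrix, [], 1);

    % Exclude the chosen neighbours from later iterations.
    distanceMatrix(sub2ind(size(distanceMatrix), indices(:, i), (1:numTestExamples)')) = Inf;
    neighbours(:, :, i) = trainY(indices(:, i), :);
end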

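% Per-label majority vote over the k neighbours (mode returns the smallest value on ties).
testInfo.predictedY = mode(neighbours, 3);
% Score: sum each neighbour's labels, weight by exp(-squared distance), then sum
% over the k neighbours. reshape (rather than squeeze) keeps a
% numTestExamples x k matrix even when there is only one test example.
testInfo.rankings = sum(exp(-distances) .* reshape(sum(neighbours, 2), numTestExamples, k), 2);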
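% Worked example of the rankings score (hypothetical toy numbers, not from the
% original source): one test example, k = 3, a single label, and neighbour
% labels [+1 +1 -1] at squared distances [0.25 0.25 4.25]:
%   predictedY = mode([+1 +1 -1]) = +1
%   rankings   = exp(-0.25) + exp(-0.25) - exp(-4.25), approximately 1.5433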
predictTime = toc; 
fprintf('Prediction time: %f s\n', predictTime); 

predictionInfo = struct;
predictionInfo.predictTime = predictTime; 
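
% Example usage, a sketch with hypothetical toy data (not from the original
% source). modelInfo is not used, so an empty struct is passed:
%   trainX = [0 0; 0 1; 5 5; 5 6];
%   trainY = [-1; -1; 1; 1];
%   testX  = [0 0.5; 5 5.5];
%   params.k = 3;
%   [testInfo, predictionInfo] = primalNearestNeighboursPredict(trainX, trainY, testX, struct, params);
%   testInfo.predictedY   % [-1; 1]
%   testInfo.rankings     % approximately [-1.5576; 1.5576]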
