evaluatebinarypredictions.m.svn-base

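function [trainInfo, testInfo] = evaluateBinaryPredictions(trainData, testData, trainPrediction, testPrediction)
%EVALUATEBINARYPREDICTIONS Compute evaluation statistics for the predicted labels.
%   [trainInfo, testInfo] = evaluateBinaryPredictions(trainData, testData, trainPrediction, testPrediction)
%   returns one struct of statistics for the training set and one for the
%   test set. The true labels are read from the 'Y' field of trainData and
%   testData; each prediction struct must provide the fields predictedY and
%   rankings. For multi-label Y (more than one column), per-label accuracies
%   and an accuracy histogram are added, and the labels are then flattened so
%   the remaining statistics treat every entry as one binary prediction.
%   AUC, averagePrecision and ROC are computed only for single-column,
%   binary labels.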

trainY = getDataFieldValue(trainData, 'Y'); 
testY = getDataFieldValue(testData, 'Y'); 

trainInfo = struct; 
testInfo = struct; 

% Accuracy can be computed for both single-label and multi-label Y
trainInfo.accuracy = accuracy(trainY, trainPrediction.predictedY, trainPrediction.rankings);
testInfo.accuracy = accuracy(testY, testPrediction.predictedY, testPrediction.rankings);

if size(trainY, 2) ~= 1
    % Multi-label case: accuracy per label and the accuracy histogram
    trainInfo.indAccuracies = individualAccuracy(trainY, trainPrediction.predictedY, trainPrediction.rankings);
    testInfo.indAccuracies = individualAccuracy(testY, testPrediction.predictedY, testPrediction.rankings);

    trainInfo.accuracyHist = accuracyHist(trainY, trainPrediction.predictedY, trainPrediction.rankings);
    testInfo.accuracyHist = accuracyHist(testY, testPrediction.predictedY, testPrediction.rankings);

    % Flatten labels and predictions so that the statistics computed after
    % this block treat every (example, label) entry as one binary prediction
    trainY = trainY(:);
    testY = testY(:);
    trainPrediction.predictedY = trainPrediction.predictedY(:);
    testPrediction.predictedY = testPrediction.predictedY(:);
else
    % Single-label case: AUC, average precision and ROC also need the rankings
    if binaryLabels(trainY)
        trainInfo.AUC = AUCWrapper(trainY, trainPrediction.predictedY, trainPrediction.rankings);
        testInfo.AUC = AUCWrapper(testY, testPrediction.predictedY, testPrediction.rankings);

        trainInfo.averagePrecision = averagePrecision(trainY, trainPrediction.predictedY, trainPrediction.rankings);
        testInfo.averagePrecision = averagePrecision(testY, testPrediction.predictedY, testPrediction.rankings);

        trainInfo.ROC = ROCvector(trainY, trainPrediction.predictedY, trainPrediction.rankings, 100);
        testInfo.ROC = ROCvector(testY, testPrediction.predictedY, testPrediction.rankings, 100);
    end
end

trainInfo.precision = precision(trainY, trainPrediction.predictedY, trainPrediction.rankings);
testInfo.precision = precision(testY, testPrediction.predictedY, testPrediction.rankings);

trainInfo.recall = recall(trainY, trainPrediction.predictedY, trainPrediction.rankings);
testInfo.recall = recall(testY, testPrediction.predictedY, testPrediction.rankings);

trainInfo.fMeasure = fMeasure(trainY, trainPrediction.predictedY, trainPrediction.rankings);
testInfo.fMeasure = fMeasure(testY, testPrediction.predictedY, testPrediction.rankings);

trainInfo.falsePositiveRate = falsePositiveRate(trainY, trainPrediction.predictedY);
testInfo.falsePositiveRate = falsePositiveRate(testY, testPrediction.predictedY);

trainInfo.truePositiveRate = truePositiveRate(trainY, trainPrediction.predictedY);
testInfo.truePositiveRate = truePositiveRate(testY, testPrediction.predictedY);

% Predictions greater than 0 count as positive; all others count as negative
trainInfo.numPredictedPositives = sum(trainPrediction.predictedY > 0);
testInfo.numPredictedPositives = sum(testPrediction.predictedY > 0);

trainInfo.numPredictedNegatives = sum(trainPrediction.predictedY <= 0);
testInfo.numPredictedNegatives = sum(testPrediction.predictedY <= 0);
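
None of the helpers called above (accuracy, precision, recall, fMeasure, falsePositiveRate, truePositiveRate, AUCWrapper, averagePrecision and the rest) are defined in this file. For orientation, here is a minimal MATLAB/Octave sketch of the textbook definitions such helpers usually implement, computed by hand on hypothetical toy data. It assumes labels coded as +1/-1, that rankings are real-valued scores where a larger value means "more likely positive", and an arbitrary threshold of 0.5 for the hard predictions. The variable names are illustrative only, and the real helpers may differ, for example in how they break ties or use the rankings.

```matlab
% Hypothetical toy data (assumption): +1/-1 labels and real-valued scores,
% where a higher score means "more likely positive".
y = [ 1;  1; -1; -1;  1; -1];
s = [0.9; 0.4; 0.35; 0.6; 0.8; 0.1];
yHat = 2 * (s > 0.5) - 1;          % hard +1/-1 predictions at an assumed threshold of 0.5

% Confusion-matrix counts; values > 0 count as positive, mirroring
% numPredictedPositives / numPredictedNegatives in the function above.
tp = sum(yHat > 0  & y > 0);
fp = sum(yHat > 0  & y <= 0);
fn = sum(yHat <= 0 & y > 0);
tn = sum(yHat <= 0 & y <= 0);

acc  = (tp + tn) / numel(y);
prec = tp / (tp + fp);
rec  = tp / (tp + fn);             % recall is the same quantity as the true positive rate
f1   = 2 * prec * rec / (prec + rec);
fpr  = fp / (fp + tn);

% AUC: fraction of (positive, negative) pairs in which the positive example
% scores higher, with ties counted as 1/2 (Mann-Whitney formulation).
d   = bsxfun(@minus, s(y > 0), s(y <= 0)');
auc = mean((d(:) > 0) + 0.5 * (d(:) == 0));

% Average precision: mean of the precision at each rank where a positive occurs.
[~, order] = sort(s, 'descend');
hits    = y(order) > 0;
precAtK = cumsum(hits) ./ (1:numel(hits))';
ap      = sum(precAtK(hits)) / sum(hits);

fprintf('acc=%.3f prec=%.3f rec=%.3f F=%.3f FPR=%.3f AUC=%.3f AP=%.3f\n', ...
        acc, prec, rec, f1, fpr, auc, ap);
```

For this toy data the counts are TP = 2, FP = 1, FN = 1, TN = 2, so precision, recall and F-measure are all 2/3 and the false positive rate is 1/3. The AUC is 8/9, because 8 of the 9 positive/negative pairs are ordered correctly, and the average precision is 11/12. `bsxfun` is used instead of implicit expansion so the snippet also runs on MATLAB releases before R2016b and on Octave.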
