📄 evaluatebinarypredictions.m.svn-base
字号:
function [trainInfo, testInfo] = evaluateBinaryPredictions(trainData, testData, trainPrediction, testPrediction)
%EVALUATEBINARYPREDICTIONS Compute summary statistics for train and test predictions.
%   [trainInfo, testInfo] = evaluateBinaryPredictions(trainData, testData,
%   trainPrediction, testPrediction) returns one struct of statistics per
%   split. True labels are read with getDataFieldValue(data, 'Y'); each
%   prediction struct must provide the fields 'predictedY' and 'rankings'.
%
%   Fields set on both outputs:
%     accuracy, precision, recall, fMeasure, falsePositiveRate,
%     truePositiveRate, numPredictedPositives (count of predictedY > 0),
%     numPredictedNegatives (count of predictedY <= 0)
%   Extra fields when the training labels have more than one column:
%     indAccuracies, accuracyHist
%   Extra fields when the training labels have one column and
%   binaryLabels(trainY) is true:
%     AUC, averagePrecision, ROC (ROCvector called with 100 as last argument)
%
%   Which extra fields are computed is decided from the TRAINING labels
%   only and applied to both splits, so both outputs have the same fields.

trainY = getDataFieldValue(trainData, 'Y');
testY = getDataFieldValue(testData, 'Y');

% Decide once, from the training labels, which extra measures apply.
isMultiLabel = size(trainY, 2) ~= 1;
isBinary = false;
if ~isMultiLabel
    % binaryLabels is only consulted for single-column labels
    isBinary = binaryLabels(trainY);
end

trainInfo = computeSplitStatistics(trainY, trainPrediction, isMultiLabel, isBinary);
testInfo = computeSplitStatistics(testY, testPrediction, isMultiLabel, isBinary);


function info = computeSplitStatistics(Y, prediction, isMultiLabel, isBinary)
%COMPUTESPLITSTATISTICS Statistics for a single split (train or test).
%   Y            - true labels for the split
%   prediction   - struct with fields 'predictedY' and 'rankings'
%   isMultiLabel - true when the labels have more than one column
%   isBinary     - true when the single-column ranking measures apply
predictedY = prediction.predictedY;
rankings = prediction.rankings;

info = struct;
% Accuracy is computed on the labels as given (possibly multi-column).
info.accuracy = accuracy(Y, predictedY, rankings);

if isMultiLabel
    % Per-label accuracy measures on the original multi-column labels
    info.indAccuracies = individualAccuracy(Y, predictedY, rankings);
    info.accuracyHist = accuracyHist(Y, predictedY, rankings);
    % Flatten labels and predictions into column vectors, presumably so the
    % measures below are pooled over all labels.
    % NOTE(review): rankings are NOT flattened (unchanged from the original
    % code) - confirm precision/recall/fMeasure handle this shape mismatch.
    Y = Y(:);
    predictedY = predictedY(:);
elseif isBinary
    % Ranking-based measures, only computed for single-column binary labels
    info.AUC = AUCWrapper(Y, predictedY, rankings);
    info.averagePrecision = averagePrecision(Y, predictedY, rankings);
    info.ROC = ROCvector(Y, predictedY, rankings, 100);
end

info.precision = precision(Y, predictedY, rankings);
info.recall = recall(Y, predictedY, rankings);
info.fMeasure = fMeasure(Y, predictedY, rankings);
info.falsePositiveRate = falsePositiveRate(Y, predictedY);
info.truePositiveRate = truePositiveRate(Y, predictedY);
info.numPredictedPositives = sum(predictedY > 0);
info.numPredictedNegatives = sum(predictedY <= 0);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -