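function [Y_compute, Y_prob, Y_test] = AggregatePredByShot(Y_compute, Y_prob, Y_test, testindex, class_set, threshold)
%AGGREGATEPREDBYSHOT Aggregate per-sample predictions into per-shot predictions.
%   Test samples are grouped by shot ID (preprocess.ShotInfo(testindex)). Each
%   output holds one entry per unique shot, in the sorted order returned by UNIQUE:
%     Y_compute - majority vote of the predicted labels in the shot
%     Y_prob    - mean of the predicted scores in the shot
%     Y_test    - majority vote of the ground-truth labels in the shot
%   class_set defaults to preprocess.OrgClassSet. threshold (default 0.5) is
%   currently unused; it only feeds the disabled thresholding rule in the loop.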
global preprocess;   % struct with fields ShotInfo and OrgClassSet
if (nargin < 5), class_set = preprocess.OrgClassSet; end;
if (nargin < 6), threshold = 0.5; end;
ShotInfo = preprocess.ShotInfo(testindex);   % shot ID of each test sample
sort_class_set = sort(class_set);            % histc requires increasing edges
ShotIDTestSet = unique(ShotInfo);            % distinct shot IDs, sorted
Y_aggregate_compute = zeros(size(ShotIDTestSet));
Y_aggregate_prob = zeros(size(ShotIDTestSet));
Y_aggregate_test = zeros(size(ShotIDTestSet));
for j = 1:length(ShotIDTestSet)
    in_shot = (ShotInfo == ShotIDTestSet(j));   % samples belonging to shot j
    % Score: mean of the per-sample scores in the shot.
    Y_aggregate_prob(j) = mean(Y_prob(in_shot));
    % Disabled alternative: threshold the mean score instead of voting.
    % Y_aggregate_compute(j) = class_set(1) * (Y_aggregate_prob(j) >= threshold) + class_set(2) * (Y_aggregate_prob(j) < threshold);
    % Predicted label: majority vote. histc counts each class label (labels are
    % assumed to be members of class_set); max keeps the first maximum, so ties
    % go to the smallest label.
    labelhist = histc(Y_compute(in_shot), sort_class_set);
    [~, index] = max(labelhist);
    Y_aggregate_compute(j) = sort_class_set(index);
    % Ground truth: majority vote, same rule. Disabled alternative: mark the
    % shot as class_set(1) if any of its samples has that label.
    %Y_test_shotID = Y_test(in_shot);
    %Y_aggregate_test(j) = class_set(1) * (sum(Y_test_shotID == class_set(1)) > 0) + class_set(2) * (sum(Y_test_shotID == class_set(1)) == 0);
    labelhist = histc(Y_test(in_shot), sort_class_set);
    [~, index] = max(labelhist);
    Y_aggregate_test(j) = sort_class_set(index);
end;
Y_compute = Y_aggregate_compute;
Y_prob = Y_aggregate_prob;
Y_test = Y_aggregate_test;