aggregatepredbyshot.m

来自「一个matlab的工具包,里面包括一些分类器 例如 KNN KMEAN SVM 」· M 代码 · 共 35 行

M
35
字号

function [Y_compute, Y_prob, Y_test] = AggregatePredByShot(Y_compute, Y_prob, Y_test, testindex, class_set, threshold)
% AGGREGATEPREDBYSHOT  Aggregate per-sample predictions into per-shot predictions.
%
%   [Y_compute, Y_prob, Y_test] = AggregatePredByShot(Y_compute, Y_prob, Y_test, testindex)
%   [...] = AggregatePredByShot(..., class_set)
%   [...] = AggregatePredByShot(..., class_set, threshold)
%
%   Groups the test samples by the shot ID stored in preprocess.ShotInfo(testindex)
%   and returns one value per distinct shot ID, in the sorted order produced by
%   UNIQUE:
%     Y_prob    - mean of the input Y_prob values of the samples in that shot.
%     Y_compute - majority vote over the predicted labels of that shot.
%     Y_test    - majority vote over the ground-truth labels of that shot.
%   Only labels that are exact members of class_set receive votes. Ties (and
%   shots with no member label at all) resolve to the smallest class label.
%
%   Inputs:
%     Y_compute, Y_prob, Y_test - per-sample vectors; NOTE(review): assumed to be
%                                 aligned element-wise with testindex - confirm
%                                 against callers.
%     testindex - indices into preprocess.ShotInfo selecting the test samples.
%     class_set - candidate class labels. Defaults to preprocess.OrgClassSet when
%                 omitted or empty.
%     threshold - accepted for backward compatibility only; it is NOT used (it
%                 belonged to an older thresholded decision rule). Default 0.5.
%
%   Outputs have the same orientation as UNIQUE(preprocess.ShotInfo(testindex)).
%   Reads the global struct PREPROCESS (field ShotInfo, and OrgClassSet by default).

global preprocess;

if (nargin < 5 || isempty(class_set)), class_set = preprocess.OrgClassSet; end
% threshold is unused; the default is kept only so old call patterns still work.
if (nargin < 6 || isempty(threshold)), threshold = 0.5; end %#ok<NASGU>

ShotInfo = preprocess.ShotInfo(testindex);
% Distinct class labels as a sorted row vector. Sorting makes MAX's
% first-index tie-break pick the smallest label.
sort_class_set = reshape(unique(class_set), 1, []);

ShotIDTestSet = unique(ShotInfo);
Y_aggregate_compute = zeros(size(ShotIDTestSet));
Y_aggregate_prob = zeros(size(ShotIDTestSet));
Y_aggregate_test = zeros(size(ShotIDTestSet));
for j = 1:numel(ShotIDTestSet)
    % Build the membership mask for shot j once and reuse it for all three outputs.
    in_shot = (ShotInfo == ShotIDTestSet(j));

    Y_aggregate_prob(j) = mean(Y_prob(in_shot));
    Y_aggregate_compute(j) = majority_label(Y_compute(in_shot), sort_class_set);
    Y_aggregate_test(j) = majority_label(Y_test(in_shot), sort_class_set);
end
Y_compute = Y_aggregate_compute;
Y_prob = Y_aggregate_prob;
Y_test = Y_aggregate_test;

function label = majority_label(labels, classes)
% MAJORITY_LABEL  Most frequent member of CLASSES among LABELS (exact matches only).
%   CLASSES must be a sorted row vector. Labels not in CLASSES are ignored,
%   unlike HISTC, which would fold them into the neighbouring lower bin.
%   Ties go to the earliest (smallest) class.
counts = sum(bsxfun(@eq, labels(:), classes), 1);
[~, index] = max(counts);
label = classes(index);

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?