📄 cross_validate.m
字号:
% Input parameters:
% D: data array containing the feature data and the output class
function run = cross_validate(D, classifier_wrapper_handle, classifier)
% CROSS_VALIDATE  Fold-based cross-validation driven by the global PREPROCESS settings.
%
%   run = cross_validate(D, classifier_wrapper_handle, classifier)
%
%   D                         - raw data array; passed to preprocessing(), which
%                               returns the features X, the labels Y and the
%                               number of data rows.
%   classifier_wrapper_handle - function (handle or name) invoked through feval as
%                               wrapper(X, Y, trainindex, testindex, classifier).
%                               Its result must have fields Y_prob, Y_compute and
%                               Y_test (written into the test rows of run.Y_pred),
%                               and may have any of the metric fields averaged below.
%   classifier                - passed through unchanged to the wrapper.
%
%   run.Y_pred is a num_data-by-4 matrix whose columns are
%   [row index, Y_prob, Y_compute, Y_test].
%   For every metric field present in the first fold's result (Err, Prec, Rec,
%   F1, Micro_*, Macro_*, AvgPrec, BaseAvgPrec), run.<field> is the mean of the
%   concatenated per-fold values.
%
%   Folds are contiguous blocks of rows. When preprocess.ShotAvailable and
%   preprocess.ValidateByShot are both 1, each fold's test set is instead every
%   row whose preprocess.ShotInfo is in a contiguous block of preprocess.ShotIDSet.
global preprocess;
% num_feature is not used here; it is still requested so preprocessing() sees
% the same number of outputs as before.
[X, Y, num_data, num_feature] = preprocessing(D);
clear D;  % the raw data is not needed after preprocessing

num_folder = preprocess.NumCrossFolder;
if ~isscalar(num_folder) || num_folder < 1
    % Without at least one fold, run_class is never assigned below.
    error('cross_validate:NumCrossFolder', ...
        'preprocess.NumCrossFolder must be a scalar >= 1.');
end

% Column 1 records the row index; columns 2-4 are filled in fold by fold.
run.Y_pred = zeros(num_data, 4);
run.Y_pred(:, 1) = (1:num_data)';

for i = 1:num_folder
    fprintf('Iteration %d ......\n', i);
    if (preprocess.ShotAvailable == 1) & (preprocess.ValidateByShot == 1)
        % Shot-level folds: test on every row of this fold's shots, grouped
        % shot by shot in ShotIDSet order.
        num_shot = length(preprocess.ShotIDSet);
        ValidateTestShot = preprocess.ShotIDSet(floor((i-1) * num_shot / num_folder) + 1 : floor(i * num_shot / num_folder));
        shot_rows = cell(1, length(ValidateTestShot));
        for j = 1:length(ValidateTestShot)
            shot_rows{j} = find(preprocess.ShotInfo == ValidateTestShot(j));
        end
        testindex = vertcat(shot_rows{:});  % [] when the fold has no shots
    else
        % Row-level folds: a contiguous block of row indices.
        testindex = floor((i-1) * num_data / num_folder) + 1 : floor(i * num_data / num_folder);
    end
    trainindex = setdiff(1:num_data, testindex);
    %%% RemoveConstraints;   (constraint removal is disabled)

    % Classification
    run_class(i) = feval(classifier_wrapper_handle, X, Y, trainindex, testindex, classifier);
    run.Y_pred(testindex, 2) = run_class(i).Y_prob;
    run.Y_pred(testindex, 3) = run_class(i).Y_compute;
    run.Y_pred(testindex, 4) = run_class(i).Y_test;
end

% Average every metric the wrapper reported (checked on the first fold).
metric_fields = {'Err', 'Prec', 'Rec', 'F1', ...
    'Micro_Prec', 'Micro_Rec', 'Micro_F1', ...
    'Macro_Prec', 'Macro_Rec', 'Macro_F1', ...
    'AvgPrec', 'BaseAvgPrec'};
for k = 1:length(metric_fields)
    field_name = metric_fields{k};
    if isfield(run_class(1), field_name)
        run.(field_name) = mean([run_class(:).(field_name)]);
    end
end
function RemoveConstraints(trainindex)
% REMOVECONSTRAINTS  Update preprocess.constraintUsed for one training split.
%
%   RemoveConstraints(trainindex)
%
%   trainindex - indices into preprocess.ShotInfo of the training rows. It must
%                be passed in: this is a local (non-nested) function, so it
%                cannot see the caller's variables.
%
%   Does nothing unless preprocess.ConstraintAvailable and
%   preprocess.ShotAvailable are both 1. Otherwise, for each row j of
%   preprocess.constraintMap, sets preprocess.constraintUsed(j) to true exactly
%   when neither shot ID constraintMap(j,1) nor constraintMap(j,2) appears among
%   the training rows' shots.
global preprocess;
if (preprocess.ConstraintAvailable == 1) & (preprocess.ShotAvailable == 1)
    % The training rows' shot IDs are the same for every constraint.
    train_shots = preprocess.ShotInfo(trainindex);
    for j = 1:size(preprocess.constraintMap, 1)
        preprocess.constraintUsed(j) = all(train_shots ~= preprocess.constraintMap(j,1)) && ...
            all(train_shots ~= preprocess.constraintMap(j,2));
    end
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -