% train_test_validate.m
function run = train_test_validate(D, classifier_wrapper_handle, classifier)
% TRAIN_TEST_VALIDATE  Train a classifier on the leading part of a dataset and
% evaluate it on the remaining examples (a single, ordered hold-out split;
% no shuffling is done here).
%
% Input parameters:
%   D                         - data array including the feature data and the
%                               output class; passed unchanged to
%                               preprocessing(), which returns X and Y.
%   classifier_wrapper_handle - function handle (or function name) called as
%                               feval(handle, X, Y, trainindex, testindex, classifier).
%                               It must return a struct with fields Y_prob,
%                               Y_compute and Y_test, each holding one value
%                               per test example.
%   classifier                - classifier specification, forwarded untouched
%                               to the wrapper.
%
% Global configuration read (set elsewhere):
%   preprocess.TrainTestSplitBoundary
%       > 0 : number of leading examples used for training (fractional values
%             are truncated toward zero).
%       < 0 : the first fix(num_data / -value) examples are used for training,
%             e.g. -2 trains on the first half, -1.5 on the first two thirds.
%       Zero, NaN, Inf and non-scalar values are rejected with an error.
%
% Output:
%   run - the struct returned by the wrapper, extended with run.Y_pred, a
%         length(testindex)-by-4 matrix whose columns are
%         [1-based position within the test set, Y_prob, Y_compute, Y_test].

global preprocess;

% num_feature is not used here; it is still requested so preprocessing() is
% called with the same number of outputs as before.
[X, Y, num_data, num_feature] = preprocessing(D);
clear D;  % only X and Y are used from here on

% Dataset statistics, kept for reference:
% num_class = length(preprocess.ClassSet);
% actual_num_class = length(preprocess.OrgClassSet);
% num_shot = length(preprocess.ShotIDSet);
% class_set = preprocess.ClassSet;

boundary = preprocess.TrainTestSplitBoundary;
if ~isscalar(boundary) || ~isfinite(boundary) || boundary == 0
    % A zero boundary would divide by zero below and yield -Inf/NaN indices.
    error('train_test_validate:badSplit', ...
        'preprocess.TrainTestSplitBoundary must be a finite, non-zero scalar.');
end

if boundary > 0
    % Absolute number of training examples; truncate so indices are integers.
    splitboundary = fix(boundary);
else
    % Negative value k: train on the first 1/|k| of the examples.
    splitboundary = fix(num_data / (-boundary));
end

if splitboundary > num_data
    % Training indices would run past the end of the data.
    error('train_test_validate:badSplit', ...
        'Split boundary %d exceeds the number of examples (%d).', ...
        splitboundary, num_data);
end

trainindex = 1:splitboundary;
testindex = splitboundary+1:num_data;

% Alternative per-class split (first 2/3 of each class for training), kept
% for reference:
% trainindex = []; testindex = [];
% for i = 1:length(preprocess.OrgClassSet)
%     ind = find(Y == preprocess.OrgClassSet(i));
%     sb = fix(2*length(ind)/3);
%     trainindex = [trainindex; ind(1:sb)];
%     testindex = [testindex; ind(sb+1:length(ind))];
% end

run = feval(classifier_wrapper_handle, X, Y, trainindex, testindex, classifier);

% Collect the per-test-example results into one matrix. The column
% assignments require each result field to hold exactly num_test values.
num_test = length(testindex);
run.Y_pred = zeros(num_test, 4);
run.Y_pred(:, 1) = (1:num_test)';
run.Y_pred(:, 2) = run.Y_prob;
run.Y_pred(:, 3) = run.Y_compute;
run.Y_pred(:, 4) = run.Y_test;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -