% mil_train_test_validate.m (Multiple Instance Learning Library)
%
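% Train a multiple-instance classifier on a training split of the bags and
% evaluate it on the remaining (test) bags.
%
% Input parameters:
%   data_file                 : bag data file, read with MIL_data_load
%   classifier_wrapper_handle : handle of the wrapper, called as
%                               run = wrapper(bags, trainindex, testindex, classifier)
%   classifier                : classifier specification, passed unchanged to the wrapper
%
% Output:
%   run : result struct returned by the wrapper, extended with run.bag_pred, an
%         N-by-4 matrix [test index, predicted probability, predicted label, true label]
%
% Fields of the global struct "preprocess" used here: test_file,
% TrainTestSplitBoundary, EnforceDistrib.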
function run = MIL_train_test_validate(data_file, classifier_wrapper_handle, classifier)

global preprocess;

% [X, Y, num_data, num_feature] = preprocessing(D);
% clear D;
[bags, num_data, num_feature] = MIL_data_load(data_file);

if isfield(preprocess, 'test_file') && ~isempty(preprocess.test_file)
    % Separate test file: append its bags after the training bags
    [test_bags, num_test_data, num_feature] = MIL_data_load(preprocess.test_file);
    bags = [bags test_bags];
    splitboundary = num_data;
    num_data = num_data + num_test_data;
else
    if preprocess.TrainTestSplitBoundary > 0
        % Positive boundary: absolute number of training bags
        splitboundary = preprocess.TrainTestSplitBoundary;
    else
        % Negative boundary -k: the first fix(num_data/k) bags are used for training
        splitboundary = fix(num_data / (-preprocess.TrainTestSplitBoundary));
    end
end
testindex = splitboundary+1:num_data;
trainindex = 1:splitboundary;

run = feval(classifier_wrapper_handle, bags, trainindex, testindex, classifier);
  
% Per-bag test predictions: [test index, predicted probability, predicted label, true label]
run.bag_pred = zeros(length(testindex), 4);
run.bag_pred(:, 1) = (1:length(testindex))';
run.bag_pred(:, 2) = run.bag_prob; 
run.bag_pred(:, 3) = run.bag_label; 
run.bag_pred(:, 4) = [bags(testindex).label]';

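if isfield(preprocess, 'EnforceDistrib') && preprocess.EnforceDistrib == 1
    % Expected number of positive bags in the test set, estimated from the
    % positive rate over all num_data loaded bags (this includes the test bags' labels)
    num_pos = sum([bags(1:num_data).label]);
    num_pos = round((num_pos / num_data) * length(testindex));

    % Threshold the probabilities so that (ties aside) the num_pos
    % highest-scoring test bags are labeled positive
    sort_ret = sort(run.bag_pred(:, 2));
    if num_pos > 0
        threshold = sort_ret(length(testindex) - num_pos + 1);
    else
        threshold = Inf;   % no positives expected; avoids indexing past the end of sort_ret
    end
    run.bag_pred(:, 3) = (run.bag_pred(:, 2) >= threshold);
    run.BagAccu = sum(run.bag_pred(:, 3) == run.bag_pred(:, 4)) / length(testindex);
end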
