⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mil_train_validate.m

📁 包含了很多分类算法
💻 M
字号:
function run = MIL_Train_Validate(data_file, classifier)

global preprocess;
clear run;

% The statistics of dataset
% [X, Y, num_data, num_feature] = Preprocessing(D);
% num_class = length(preprocess.ClassSet);
% class_set = preprocess.ClassSet;
[bags, num_data, num_feature] = MIL_Data_Load(data_file);

% Extract the training and testing data
% X_train = X;
% Y_train = Y;
% Y_test = Y_train;
testindex = 1:num_data;

% Classify with Ensemble 
[test_bag_label, test_inst_label, test_bag_prob, test_inst_prob] = MIL_Classify(classifier, bags, bags);      
  
%run.Y_compute = Y_compute; run.Y_prob = Y_prob; run.Y_test = Y_train;
run.bag_label = test_bag_label;
run.inst_label = test_inst_label; 
run.bag_prob = test_bag_prob;
run.inst_prob = test_inst_prob;

% Aggregate the predictions in a shot
 
% Report the performance
run.BagAccu = MIL_Bag_Evaluate(bags(testindex), test_bag_label);
if ~isempty(test_inst_label)
    run.InstAccu = MIL_Inst_Evaluate(bags(testindex), test_inst_label);
end;

run.bag_pred = zeros(length(testindex), 3);
run.bag_pred(:, 1) = (1:length(testindex))';
run.bag_pred(:, 2) = run.bag_prob; 
run.bag_pred(:, 3) = run.bag_label; 
run.bag_pred(:, 4) = [bags(testindex).label]';

if (isfield(preprocess, 'EnforceDistrib') && preprocess.EnforceDistrib == 1)
   num_pos = 0;
   for i = 1:num_data, num_pos = num_pos + bags(i).label; end;
   [sort_ret, sort_idx ] = sort(run.bag_pred(:,2));
   threshold = sort_ret(num_data - num_pos + 1);   
   run.bag_pred(:, 3) = (run.bag_pred(:,2) >= threshold);
   run.BagAccu = sum(run.bag_pred(:,3) == run.bag_pred(:,4)) / num_data;
end

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -