⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 inst_mi_svm.m

📁 Multiple INstance Learning Library
💻 M
字号:
function [test_bag_label, test_inst_label, test_bag_prob, test_inst_prob] = inst_MI_SVM(para, train_bags, test_bags)
% INST_MI_SVM  Instance-level multiple-instance SVM (mi-SVM style iterative relabelling).
%
%   [test_bag_label, test_inst_label, test_bag_prob, test_inst_prob] = ...
%       inst_MI_SVM(para, train_bags, test_bags)
%
%   Every training instance starts with the label of its bag. An SVM is then
%   trained repeatedly through the project's LibSVM wrapper. After each round:
%     - instances of negative bags (label == 0) are reset to 0;
%     - instances of every other bag adopt the SVM's predicted labels, and if
%       none of them was predicted positive, only the instance with the
%       largest predicted probability is set to 1.
%   Training stops when the labelling no longer changes, or when it repeats a
%   labelling seen in an earlier round.
%
%   Inputs
%     para       - options forwarded unchanged to LibSVM (format defined by
%                  LibSVM, not visible here).
%     train_bags - struct array; each element has .instance (one instance per
%                  row) and .label (0 = negative, anything else = positive).
%                  If empty, no training is done: the model file named by the
%                  global preprocess.model_file is copied to the global
%                  temp_model_file and used for prediction.
%     test_bags  - struct array; each element has .instance (one per row).
%
%   Outputs
%     test_bag_label  - num_test_bag x 1; 1 if any instance of the bag is
%                       predicted non-zero, otherwise 0.
%     test_inst_label - per-instance labels predicted by LibSVM.
%     test_bag_prob   - 1 x num_test_bag; the largest instance probability of
%                       each bag (0 for a bag with no instances).
%     test_inst_prob  - per-instance probabilities predicted by LibSVM.
%
%   NOTE(review): the per-bag slicing below assumes bag2instance stacks the
%   instances of the bags in bag order; confirm against bag2instance.

global preprocess;
global temp_train_file temp_test_file temp_output_file temp_model_file libSVM_dir; 

num_train_bag = length(train_bags);
num_test_bag = length(test_bags);

% Initial instance labels: every instance inherits the label of its bag.
idx = 0;
for i = 1:num_train_bag
    num_inst = size(train_bags(i).instance, 1);
    train_label(idx+1 : idx+num_inst) = repmat(train_bags(i).label, num_inst, 1);
    idx = idx + num_inst;
end

[train_instance, dummy] = bag2instance(train_bags);
[test_instance, dummy] = bag2instance(test_bags);

num_train_inst = size(train_instance, 1);
num_test_inst = size(test_instance, 1);

if isempty(train_bags)
    % Test-only setting: reuse a previously trained model.
    if ~isfield(preprocess, 'model_file') || isempty(preprocess.model_file)
        error('The model file must be provided in the test_only setting!');
    end
    % copyfile is used instead of a shell "copy" command. It works on every
    % platform, does not split paths on spaces, and raises an error on failure.
    if ~strcmp(preprocess.model_file, temp_model_file)
        copyfile(preprocess.model_file, temp_model_file);
    end
    % Empty training set; the ones(...) vector supplies placeholder test labels
    % (presumably ignored by LibSVM when predicting - confirm).
    [test_label_predict, test_prob_predict] = LibSVM(para, [], [], test_instance, ones(num_test_inst, 1));
else
    step = 1;
    past_train_label(step, :) = train_label;

    while true
        % Disabled class-balancing option, kept for reference:
        %num_pos_label = sum(train_label == 1);
        %num_neg_label = sum(train_label == 0);
        %new_para = sprintf(' -NegativeWeight %.10g', (num_pos_label / num_neg_label));

        % Train on the current labelling, then predict training and test
        % instances in a single call. The test predictions kept after the loop
        % therefore come from the model of the last round, which was trained on
        % the labelling in effect before that round's relabelling.
        [all_label_predict, all_prob_predict] = LibSVM(para, train_instance, train_label, ...
            [train_instance; test_instance], ones(num_train_inst + num_test_inst, 1));
        train_rows = 1 : num_train_inst;
        test_rows = num_train_inst + (1 : num_test_inst);
        train_label_predict = all_label_predict(train_rows);
        train_prob_predict = all_prob_predict(train_rows);
        test_label_predict = all_label_predict(test_rows);
        test_prob_predict = all_prob_predict(test_rows);

        % Relabel training instances bag by bag.
        idx = 0;
        for i = 1:num_train_bag
            num_inst = size(train_bags(i).instance, 1);
            inst_idx = idx+1 : idx+num_inst;
            if train_bags(i).label == 0
                % Negative bag: all of its instances are negative.
                train_label(inst_idx) = 0;
            elseif any(train_label_predict(inst_idx))
                % Positive bag with at least one predicted positive: adopt predictions.
                train_label(inst_idx) = train_label_predict(inst_idx);
            elseif num_inst > 0
                % Positive bag with no predicted positive: a positive bag must
                % contain a positive instance, so mark only the instance with the
                % largest predicted probability (last after an ascending sort).
                [sort_prob, sort_idx] = sort(train_prob_predict(inst_idx));
                train_label(inst_idx) = 0;
                train_label(idx + sort_idx(num_inst)) = 1;
            end
            idx = idx + num_inst;
        end

        difference = sum(past_train_label(step, :) ~= train_label);
        fprintf('Number of label changes is %d\n', difference);
        if difference == 0
            break;
        end

        % Stop if this labelling already occurred earlier, i.e. the
        % relabelling has entered a cycle.
        if ismember(train_label, past_train_label, 'rows')
            fprintf('Repeated training labels found, quit...\n');
            break;
        end

        step = step + 1;
        past_train_label(step, :) = train_label;
    end
end

% Instance-level predictions.
test_inst_label = test_label_predict;
test_inst_prob = test_prob_predict;

% Bag-level predictions: a bag is positive if any of its instances is, and its
% probability is the largest instance probability. test_bag_prob stays a row
% vector (its original shape); it is preallocated so that it is also defined
% when there are no test bags.
test_bag_label = zeros(num_test_bag, 1);
test_bag_prob = zeros(1, num_test_bag);
idx = 0;
for i = 1:num_test_bag
    num_inst = size(test_bags(i).instance, 1);
    if num_inst > 0
        % Guard: max([]) is [], and assigning [] to an element deletes it.
        inst_idx = idx+1 : idx+num_inst;
        test_bag_label(i) = any(test_inst_label(inst_idx));
        test_bag_prob(i) = max(test_inst_prob(inst_idx));
    end
    idx = idx + num_inst;
end

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -