⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 get_threshold.m

📁 BPMLL is a package for training multi-label BP neural networks. The package includes the MATLAB code
💻 M
字号:
function Threshold=get_threshold(train_data,train_target,test_data,net)
%get_threshold gets the threshold value for each testing instance
%    The same strategy as described in [1] and [2] is used to train the threshold predictor 
%
%    Syntax
%
%       Threshold=get_threshold(train_data,train_target,test_data,net)
%
%    Description
%
%       get_threshold takes,
%           train_data       - An M1xN array, the ith instance of training instance is stored in train_data(i,:)
%           train_target     - A QxM1 array, if the ith training instance belongs to the jth class, then train_target(j,i) equals +1, otherwise train_target(j,i) equals -1
%           test_data        - An M2xN array, the ith instance of testing instance is stored in test_data(i,:)
%           net              - The trained multi-label neural network
%      and returns,
%           Threshold        - The threshold of the ith testing instance for assessing class membership is stored in Threshold(1,i)
%
%[1] Elisseeff A, Weston J. Kernel methods for multi-labelled classfication and categorical regression problems. Technical Report, BIOwulf Technologies, 2001.
%[2] Elisseeff A,Weston J. A kernel method for multi-labelled classification. In: Dietterich T G, Becker S, Ghahramani Z, eds. Advances in Neural Information Processing Systems 14, Cambridge, MA: MIT Press, 2002, 681-687.

    [num_class,num_training]=size(train_target);
    [num_testing,tempvalue]=size(test_data);
    
    Label=cell(num_training,1);
    not_Label=cell(num_training,1);
    for i=1:num_training
        temp=train_target(:,i);
        for j=1:num_class
            if(temp(j)==1)
                Label{i,1}=[Label{i,1},j];
            else
                not_Label{i,1}=[not_Label{i,1},j];
            end
        end
    end

    Left=sim(net,train_data')';
    Right=zeros(num_training,1);
    for i=1:num_training
        temp=Left(i,:);
        [temp,index]=sort(temp);
        candidate=zeros(1,num_class+1);
        candidate(1,1)=temp(1)-0.1;
        for j=1:num_class-1
            candidate(1,j+1)=(temp(j)+temp(j+1))/2;
        end
        candidate(1,num_class+1)=temp(num_class)+0.1;
        miss_class=zeros(1,num_class+1);
        for j=1:num_class+1
            temp_notlabels=index(1:j-1);
            temp_labels=index(j:num_class);
            [tempvalue,false_neg]=size(setdiff(temp_notlabels,not_Label{i,1}));
            [tempvalue,false_pos]=size(setdiff(temp_labels,Label{i,1}));
            miss_class(1,j)=false_neg+false_pos;
        end
        [temp_minimum,temp_index]=min(miss_class);
        Right(i,1)=candidate(1,temp_index);
    end
    Left=[Left,ones(num_training,1)];
    tempvalue=(Left\Right)';
    Weights_sizepre=tempvalue(1:num_class);
    Bias_sizepre=tempvalue(num_class+1);
    
    Outputs=sim(net,test_data');
    Threshold=([Outputs',ones(num_testing,1)]*[Weights_sizepre,Bias_sizepre]')';

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -