
train.m

This file contains three AdaBoost algorithms.
%   The algorithms implemented by Alexander Vezhnevets aka Vezhnick
%   <a href="mailto:vezhnick@gmail.com">vezhnick@gmail.com</a>
%
%   Copyright (C) 2005, Vezhnevets Alexander
%   vezhnick@gmail.com
%   
%   This file is part of GML Matlab Toolbox
%   For conditions of distribution and use, see the accompanying License.txt file.
%
%   train Implements training of a classification tree
%~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
%
%    nodes = train(node, dataset, labels, weights)
%    ---------------------------------------------------------------------------------
%    Arguments:
%           node      - object of tree_node_w class (initialized properly)
%           dataset   - training data
%           labels    - training labels
%           weights   - weights of training data
%    Return:
%           nodes     - tree is represented as a cell array of its nodes
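%    ---------------------------------------------------------------------------------
%    Example (illustrative sketch, not part of the original toolbox code; it
%    assumes the tree_node_w constructor takes the maximum number of splits,
%    that dataset stores one sample per column, as suggested by the
%    dataset(:,mask) indexing below, and that labels are +1/-1 as in AdaBoost):
%
%        weak_learner = tree_node_w(3);                          % tree with up to 3 splits
%        weights      = ones(1, size(data, 2)) / size(data, 2);  % uniform sample weights
%        tree_nodes   = train(weak_learner, data, labels, weights);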

function nodes = train(node, dataset, labels, weights)

% Maximum number of splits allowed for this tree
max_split = node.max_split;

% Learn the root split; do_learn_nu returns the two child nodes and the
% weighted error of the split
[left, right, split_error] = do_learn_nu(node, dataset, labels, weights);

nodes = {left, right};

% For each child node, accumulate the total weight of training samples whose
% node output matches the label (pos) and whose output matches the negated
% label (neg)
left_pos  = sum((calc_output(left , dataset) == labels) .* weights);
left_neg  = sum((calc_output(left , dataset) == -labels) .* weights);
right_pos = sum((calc_output(right, dataset) == labels) .* weights);
right_neg = sum((calc_output(right, dataset) == -labels) .* weights);

% Weighted error of each child node when it is assigned its best label
errors = [min(left_pos, left_neg), min(right_pos, right_neg)];

% If either child captures no training weight, the split is degenerate and
% the tree is not grown any further
if(right_pos == 0 && right_neg == 0)
  return;
end

if(left_pos == 0 && left_neg == 0)
  return;
end

% Order the nodes by decreasing error so the worst node is examined first
[errors, IDX] = sort(errors, 'descend');
nodes = nodes(IDX);


% Candidate splits found so far, their leaf errors, and their error reductions
splits = {};
split_errors = [];
deltas = [];


% Greedily grow the tree: at each iteration try to split every current leaf
% and keep the split that reduces the weighted error the most
for i = 2 : max_split
    for j = 1 : length(errors)

        % Leaves that already have a candidate split are skipped
        if(length(deltas) >= j)
            continue;
        end

        max_node = nodes{j};
        max_node_out = calc_output(max_node, dataset);

        % Samples that fall into this leaf (its output is 1 for them)
        mask = find(max_node_out == 1);

        % Learn a split of this leaf on the samples it captures
        [left, right, split_error] = do_learn_nu(node, dataset(:,mask), labels(mask), weights(mask), max_node);
              
        
        % Weighted agreement of each new child with the labels, evaluated on
        % the full dataset
        left_pos  = sum((calc_output(left , dataset) == labels) .* weights);
        left_neg  = sum((calc_output(left , dataset) == -labels) .* weights);
        right_pos = sum((calc_output(right, dataset) == labels) .* weights);
        right_neg = sum((calc_output(right, dataset) == -labels) .* weights);

        % Store the candidate children for this leaf
        splits{end+1} = left;
        splits{end+1} = right;

        % Error reduction achieved by splitting this leaf; zero if either
        % child is empty, so the split can never be selected
        if( (right_pos + right_neg) == 0 || (left_pos + left_neg) == 0)
          deltas(end+1) = 0;
        else
          deltas(end+1) = errors(j) - split_error;
        end

        split_errors(end+1) = min(left_pos, left_neg);
        split_errors(end+1) = min(right_pos, right_neg);
    end  
    
    % Stop if no candidate split improves the error
    if(max(deltas) == 0)
        return;
    end

    % Pick the leaf whose split gives the largest error reduction
    best_split = find(deltas == max(deltas));
    best_split = best_split(1);

    % Remove the chosen leaf from the current set of leaves...
    cut_vec = [1 : (best_split-1)  (best_split + 1) : length(errors)];
    nodes   = nodes(cut_vec);
    errors  = errors(cut_vec);
    deltas  = deltas(cut_vec);

    % ...and replace it with its two children
    nodes{end+1} = splits{2 * best_split - 1};
    nodes{end+1} = splits{2 * best_split};

    errors(end+1) = split_errors(2 * best_split - 1);
    errors(end+1) = split_errors(2 * best_split);

    % Drop the chosen pair from the candidate lists
    cut_vec = [1 : 2 * (best_split-1)  2 * (best_split)+1 : length(split_errors)];
    split_errors = split_errors(cut_vec);
    splits       = splits(cut_vec);

end
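
% Note (illustrative, not part of the original file): the returned cell array
% holds the leaf nodes of the grown tree. As in the mask computation above,
% calc_output(nodes{k}, dataset) is assumed to be 1 for the samples routed to
% leaf k, so a caller can inspect how the training data was partitioned:
%
%     for k = 1 : length(nodes)
%         in_leaf = (calc_output(nodes{k}, dataset) == 1);   % logical mask of samples in leaf k
%     end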
