⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 sampling.m

📁 最详尽的神经网络源码
💻 M
📖 第 1 页 / 共 2 页
字号:
function [Itrain , Itest , Ivalid] = sampling(X , y , options);

% Various data sampling methods for evaluating classifier performance.
% 
% X                           : data (d x N)
% y                           : labels(1 x N)
% options.valid               : 1 = 3 sets in output : train, test, valid. 0 = 2 output : train, test
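% options.randomize           : 1 = apply a random permutation before splitting the set
% options.fraction            : 0<x<1 : build the sets using only this fraction of the entire set (for huge datasets)
% options.maxperclass         : Max samples per class. A scalar or a (1 x m) vector
% options.seed                : -1 = leave the random generator untouched, otherwise value passed to rand('state' , seed)
% options.cv.K                : number of folds for the cross-validation methods (10 if not set)
% options.holding.rho         : hold out (method 1) : fraction of samples used for training (scalar, or vector [train , test , ...] when options.valid = 1)
% options.holding.K           : hold out (method 1) : number of random splits
% options.bootstraping.rho    : bootstrap (method 2) : same meaning as options.holding.rho
% options.bootstraping.K      : bootstrap (method 2) : number of resamplings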
% options.method              : 0 = No split (all samples in train, empty test)
%                               1 = Hold out
%                               2 = Bootstrap
%                               3 = K Cross-validation
%                               4 = Leave One Out
%                               5 = Stratified Cross Validation
%                               6 = Balanced Stratified Cross Validation
%                               7 = Stratified Hold out
%                               8 = Stratified Boot Strap
%
%
% load iris
% [Itrain , Itest , Ivalid] = sampling(X , y);
% 
% Author : Sébastien PARIS : sebastien.paris@lsis.org
% -------

% Resolve the options structure : every field the caller did not supply gets
% its default value, so the rest of the function can read options.* freely.
%
% Defaults : valid = 0, randomize = 1, fraction = 1, maxperclass = inf,
%            seed = -1 (do not reseed), method = 3 (K-fold CV), cv.K = 10.

if ((nargin < 3) || isempty(options))

    % No usable options given : start from an empty struct and let the
    % checks below fill in every default. The short-circuit || keeps
    % isempty from touching 'options' when it was not passed at all.
    options = struct();

end

if (~isfield(options , 'valid'))

    options.valid         = 0;

end

if (~isfield(options , 'randomize'))

    options.randomize     = 1;

end

if (~isfield(options , 'fraction'))

    options.fraction      = 1;

end

if (~isfield(options , 'maxperclass'))

    options.maxperclass   = inf;

end

if (~isfield(options , 'seed'))

    options.seed          = -1;

end

if (~isfield(options , 'method'))

    options.method        = 3;

end

% The fold count lives in the nested struct options.cv. Comparing
% fieldnames(options) against the string 'cv.K' can never match, because
% fieldnames only returns top-level names; such a test resets a
% caller-supplied K to 10 every time. Check both nesting levels instead.
if (~isfield(options , 'cv') || ~isfield(options.cv , 'K'))

    options.cv.K          = 10;

end

% --- Preparation : dimensions, generator seeding, shuffling, sub-sampling ---

% d = first dimension of X, N = second dimension (documented as d x N data).
[d , N] = size(X);

% A seed of -1 means "do not reseed"; any other value is handed to rand.
if (options.seed ~= -1)
    rand('state' , options.seed);
end

% Optional shuffle : the same permutation is applied to the columns of X
% and to the entries of y so that samples and labels stay aligned.
if (options.randomize)
    indN = randperm(N);
    y    = y(indN);
    X    = X(: , indN);
end

% Optional truncation : keep only the leading round(fraction*N) samples.
% Fractions outside the open interval (0 , 1) leave the data untouched.
if ((0 < options.fraction) && (options.fraction < 1))
    N    = round(N*options.fraction);
    indN = (1 : N);
    y    = y(1 : N);
    X    = X(: , 1 : N);
end

% --- Per-class cap : keep at most options.maxperclass(i) samples of class i ---
% Skipped entirely when every entry of options.maxperclass is infinite.
if (~all(isinf(options.maxperclass)))

    label  = unique(y);
    L      = length(label);

    % A scalar cap applies to every class : expand it to one entry per class.
    if (numel(options.maxperclass) == 1)
        options.maxperclass = repmat(options.maxperclass , 1 , L);
    end

    ind    = cell(1 , L);
    indice = [];

    for i = 1 : L

        % Positions of the samples carrying the i-th label ...
        ind{i} = find(y == label(i));

        % ... of which only the first nKeep are retained.
        nKeep  = min(length(ind{i}) , options.maxperclass(i));
        indice = [indice , ind{i}(1 : nKeep)];

    end

    % Restrict data and labels to the retained samples (grouped class by
    % class, in the order of 'label') and refresh the dimensions.
    y       = y(: , indice);
    X       = X(: , indice);
    [d , N] = size(X);

end

% Validation indices default to empty; indN lists every retained sample.
Ivalid = [];
indN   = (1 : N);

% Method 0 : no split, every sample goes to the training set.
if (options.method == 0)

    Itest  = [];
    Itrain = indN;

end

% Method 1 : hold out (sampling without replacement).
% Produces K independent random splits; row i of Itrain / Itest / Ivalid
% holds the sample indices of the i-th split.
if (options.method == 1)

    K   = options.holding.K;
    rho = options.holding.rho;

    % With a validation set and a scalar rho, the remaining (1 - rho) share
    % is divided evenly between the test and validation sets.
    if options.valid
        if (numel(rho) == 1)
            rest = (1 - rho)/2;
            rho  = [rho , rest , rest];
        end
    end

    % Set sizes. The last set always absorbs the rounding remainder so that
    % the sizes add up to N.
    NN     = round(rho*N);
    Ntrain = NN(1);

    if (options.valid)
        Ntest  = NN(2);
        Nvalid = N - (Ntrain + Ntest);
        Ivalid = zeros(K , Nvalid);
    else
        Ntest  = N - Ntrain;
    end

    Itest  = zeros(K , Ntest);
    Itrain = zeros(K , Ntrain);

    for i = 1 : K

        % One fresh permutation per split, cut into consecutive slices.
        temp          = randperm(N);

        Itrain(i , :) = temp(1 : Ntrain);
        Itest(i , :)  = temp((Ntrain + 1) : (Ntrain + Ntest));

        if (options.valid)
            Ivalid(i , :) = temp((Ntrain + Ntest + 1) : N);
        end

    end

end

% Method 2 : bootstrap (sampling with replacement).
% For each of the K resamplings, N indices are drawn uniformly from 1..N
% with replacement and cut into consecutive train / test (/ valid) slices,
% so a given sample index may show up several times and in several sets.
if (options.method == 2)

    K   = options.bootstraping.K;
    rho = options.bootstraping.rho;

    % With a validation set and a scalar rho, the remaining (1 - rho) share
    % is divided evenly between the test and validation sets.
    if options.valid
        if (numel(rho) == 1)
            rest = (1 - rho)/2;
            rho  = [rho , rest , rest];
        end
    end

    % Set sizes. The last set always absorbs the rounding remainder so that
    % the sizes add up to N.
    NN     = round(rho*N);
    Ntrain = NN(1);

    if (options.valid)
        Ntest  = NN(2);
        Nvalid = N - (Ntrain + Ntest);
        Ivalid = zeros(K , Nvalid);
    else
        Ntest  = N - Ntrain;
    end

    Itest  = zeros(K , Ntest);
    Itrain = zeros(K , Ntrain);

    for i = 1 : K

        % N draws with replacement from 1..N.
        temp          = ceil(N*rand(1 , N));

        Itrain(i , :) = temp(1 : Ntrain);
        Itest(i , :)  = temp((Ntrain + 1) : (Ntrain + Ntest));

        if (options.valid)
            Ivalid(i , :) = temp((Ntrain + Ntest + 1) : N);
        end

    end

end

% Method 3 : K-fold cross-validation.
% The samples 1..K*S (S = floor(N/K)) are cut into K consecutive folds of S
% samples. Fold i is the test set of row i; everything else, including the
% N - K*S left-over samples, is used for training. With options.valid, the
% first S of those remaining indices become the validation set instead.
if (options.method == 3)

    K      = options.cv.K;
    S      = floor(N/K);
    indN   = (1 : N);

    Itest  = zeros(K , S);
    Itrain = zeros(K , N - (1 + options.valid)*S);

    if (options.valid)
        Ivalid = zeros(K , S);
    end

    for i = 1 : K

        foldStart    = (i - 1)*S + 1;
        foldStop     = i*S;

        Itest(i , :) = indN(foldStart : foldStop);

        % Remaining indices, starting right after the test fold and
        % wrapping around to the beginning.
        temp         = indN([(foldStop + 1) : N , 1 : (foldStart - 1)]);

        if (options.valid)
            Ivalid(i , :) = temp(1 : S);
            Itrain(i , :) = temp((S + 1) : end);
        else
            Itrain(i , :) = temp;
        end

    end

end


% Method 4 : leave one out.
% Same scheme as K-fold cross-validation with K = N folds of S = 1 sample :
% row i tests on sample i and trains on all the others. With options.valid,
% the sample following i (wrapping around) is held out for validation.
if (options.method == 4)

    S      = 1;
    K      = N;
    indN   = (1 : N);

    Itest  = zeros(K , S);
    Itrain = zeros(K , N - (1 + options.valid)*S);

    if (options.valid)
        Ivalid = zeros(K , S);
    end

    for i = 1 : K

        Itest(i , :) = indN(i);

        % All other indices, starting right after i and wrapping around.
        temp         = indN([(i + 1) : N , 1 : (i - 1)]);

        if (options.valid)
            Ivalid(i , :) = temp(1);
            Itrain(i , :) = temp(2 : end);
        else
            Itrain(i , :) = temp;
        end

    end

end

if(options.method == 5) % Stratified Cross Validation

    K                              = options.cv.K;

    label                          = unique(y);

    select                         = histc(y , label);



    L                               = length(label);

    ind                             = cell(1 , L);

    n                               = zeros(1 , L);

    for i = 1 : L

        ind{i}                      = find(y == label(i));

    end

    Ntestc                          = floor((1/K).*select);

    if(options.valid)

        Nvalidc                       = Ntestc;

        Ntrainc                       = select  - 2*Ntestc;

        Ntrain                        = sum(Ntrainc);

        Ntest                         = sum(Ntestc);

        Nvalid                        = sum(Nvalidc);

        Ivalid                        = zeros(K , Nvalid);


    else

        Ntrainc                         = select  - Ntestc;

        Ntrain                          = sum(Ntrainc);

        Ntest                           = sum(Ntestc);

    end


    Itrain                          = zeros(K , Ntrain);

    Itest                           = zeros(K , Ntest);


    for j = 1 : K

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -