⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 train_test_multiple_class.m

📁 一个matlab的工具包,里面包括一些分类器 例如 KNN KMEAN SVM NETLAB 等等有很多.
💻 M
字号:
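% Input parameters:
%   X, Y: feature matrix and class-label matrix (one row per instance)
%   trainindex, testindex: rows of X/Y used for training and for testing
%   classifier: classifier specification passed on to Classify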

function run = train_test_multiple_class(X, Y, trainindex, testindex, classifier)
% TRAIN_TEST_MULTIPLE_CLASS  Train and evaluate a multi-class classifier by
% decomposing it into binary sub-problems defined by an output coding matrix.
%
% One binary classifier is trained per column of the coding matrix returned by
% GenerateCodeMatrix. The test outputs are decoded with loss-based decoding:
% every class k receives the summed loss of the margins
% f .* coding_matrix(k, :), and the class with the smallest total loss is
% predicted (the lowest class index wins on ties).
%
% Inputs:
%   X          - feature matrix, one row per instance.
%   Y          - label matrix, one row per instance and one column per original
%                class; Y(i, k) == preprocess.ClassSet(1) marks instance i as a
%                member of class k.
%   trainindex - row indices (or logical mask) selecting training instances.
%   testindex  - row indices (or logical mask) selecting test instances; also
%                forwarded to AggregatePredByShot.
%   classifier - classifier specification, forwarded unchanged to Classify.
%
% Output:
%   run - struct with fields
%     Y_compute         - predicted class index (column of Y) per test instance,
%                         before shot aggregation.
%     Y_test            - true class index per test instance, before shot
%                         aggregation (assumes one label per row).
%     Y_prob            - minimum total decoding loss per test instance.
%     Micro_Prec, Micro_Rec, Micro_F1, Err
%                       - from CalculatePerformance on the (optionally
%                         shot-aggregated) class indices.
%     Macro_Prec, Macro_Rec
%                       - means of the per-class precision / recall.
%     Macro_F1          - 2*P*R / (P+R) of the macro values, via NormalizeRatio.
%
% Reads the global struct preprocess: ClassSet, OrgClassSet,
% MultiClass.CodeType, MultiClass.LossFuncType, ShotAvailable.
% Raises an error for an unsupported MultiClass.LossFuncType.

global preprocess;

% Statistics of the dataset.
num_class = length(preprocess.ClassSet);            % number of labels in ClassSet
actual_num_class = length(preprocess.OrgClassSet);  % number of original classes
class_set = preprocess.ClassSet;
pos_label = class_set(1);   % label for "member of the class" / code +1
neg_label = class_set(2);   % label for code -1

% Row k of the coding matrix is the code word of class k; column j defines
% binary sub-problem j with targets +1, -1 or 0 (0 = instance not used).
coding_matrix = GenerateCodeMatrix(preprocess.MultiClass.CodeType, actual_num_class);
coding_len = size(coding_matrix, 2);
Y_coding_matrix = (Y == pos_label) * coding_matrix;

X_train = X(trainindex, :);
Y_train_coding_matrix = Y_coding_matrix(trainindex, :);
X_test = X(testindex, :);
Y_test_matrix = Y(testindex, :);
Y_test_coding_matrix = Y_coding_matrix(testindex, :);

test_len = size(Y_test_coding_matrix, 1);

% Train/test one binary classifier per column of the coding matrix and keep
% its signed confidence for every test instance.
Y_compute_coding_matrix = zeros(test_len, coding_len);
for j = 1:coding_len
    Y_train_coding = Y_train_coding_matrix(:, j);
    Y_test_coding = Y_test_coding_matrix(:, j);

    % Drop training instances whose code is 0: they take no part in this
    % sub-problem.
    in_problem = (Y_train_coding ~= 0);
    X_train_norm = X_train(in_problem, :);
    Y_train_coding_norm = Y_train_coding(in_problem);

    % Convert codes back to the label set: +1 -> class_set(1), -1 -> class_set(2).
    conv_Y_train_coding_norm = pos_label * (Y_train_coding_norm == 1) + neg_label * (Y_train_coding_norm == -1);
    conv_Y_test_coding = pos_label * (Y_test_coding == 1) + neg_label * (Y_test_coding == -1);
    [Y_compute, Y_prob] = Classify(classifier, X_train_norm, conv_Y_train_coding_norm, X_test, conv_Y_test_coding, num_class);

    % Per-sub-problem performance. Compare against the converted labels so
    % both arguments use class_set labels (the raw codes are +1/-1/0).
    CalculatePerformance(Y_compute, conv_Y_test_coding, class_set);

    % Signed confidence: 2*Y_prob - 1 when the prediction is class_set(1),
    % 1 - 2*Y_prob when it is class_set(2), and -1 otherwise.
    % Assumes Y_prob lies in [0, 1] so the result is in [-1, 1] -- TODO confirm
    % against Classify.
    Y_compute_coding_matrix(:, j) = 2 * (Y_prob .* (Y_compute == pos_label) + (1 - Y_prob) .* (Y_compute == neg_label)) - 1;
end

% Loss applied to each margin z = f_j * coding_matrix(k, j). The loss type is
% loop-invariant, so it is resolved once here.
loss_type = preprocess.MultiClass.LossFuncType;
switch (loss_type)
    case 0
        loss_fun = @(z) 1 ./ (1 + exp(2 * z));
    case 1
        loss_fun = @(z) exp(-z);
    case 2
        loss_fun = @(z) (z <= 1) .* (1 - z);   % hinge: max(0, 1 - z)
    otherwise
        % Fail loudly instead of leaving the loss undefined.
        error(['train_test_multiple_class: unsupported MultiClass.LossFuncType ', num2str(loss_type)]);
end

% Total decoding loss of every class for every test instance.
Y_loss_matrix = zeros(test_len, actual_num_class);
for k = 1:actual_num_class
    margins = Y_compute_coding_matrix .* repmat(coding_matrix(k, :), test_len, 1);
    Y_loss_matrix(:, k) = sum(loss_fun(margins), 2);
end

% Predict the class with the smallest total loss. min returns the first
% (lowest-index) class on ties, so every instance gets exactly one class.
[Y_loss, Y_loss_index] = min(Y_loss_matrix, [], 2);

% One column per class: class_set(1) for the predicted class, class_set(2)
% everywhere else.
Y_pred_matrix = repmat(neg_label, test_len, actual_num_class);
for j = 1:test_len
    Y_pred_matrix(j, Y_loss_index(j)) = pos_label;
end

% Per-class (one-vs-rest) performance.
for k = 1:actual_num_class
    [run_class.yy(k), run_class.yn(k), run_class.ny(k), run_class.nn(k), ...
        run_class.prec(k), run_class.rec(k), run_class.F1(k), run_class.err(k)] = ...
        CalculatePerformance(Y_pred_matrix(:, k), Y_test_matrix(:, k), class_set);
end

% Class indices (columns of Y). Using Y_loss_index keeps exactly one
% prediction per instance; matching pos_label (rather than any nonzero entry)
% keeps Y_test correct when class_set(2) is nonzero.
Y_compute = Y_loss_index;
[Y_test, junk] = find(Y_test_matrix' == pos_label);

run.Y_compute = Y_compute;
run.Y_prob = Y_loss;
run.Y_test = Y_test;

% Aggregate the predictions within a shot.
% NOTE(review): Y_prob here is still the output of the last sub-problem's
% Classify call, not a multi-class score (run.Y_prob stores Y_loss); confirm
% which score AggregatePredByShot should receive.
if (preprocess.ShotAvailable == 1)
    [Y_compute, Y_prob, Y_test] = AggregatePredByShot(Y_compute, Y_prob, Y_test, testindex);
end

% Micro-averaged metrics and error over the class indices.
% NOTE(review): Y_compute / Y_test hold column indices 1..actual_num_class
% while the class set passed here is preprocess.OrgClassSet; this presumably
% assumes the original labels are 1..K -- confirm.
[junk, junk, junk, junk, run.Micro_Prec, run.Micro_Rec, run.Micro_F1, run.Err] = ...
    CalculatePerformance(Y_compute, Y_test, preprocess.OrgClassSet);

% Macro averages of the per-class metrics.
run.Macro_Prec = mean(run_class.prec);
run.Macro_Rec = mean(run_class.rec);
run.Macro_F1 = NormalizeRatio(2 * run.Macro_Prec * run.Macro_Rec, run.Macro_Prec + run.Macro_Rec);
% Alternative micro averages pooled from the per-class counts (disabled):
%run.Micro_Prec = NormalizeRatio(sum(run_class.yy), sum(run_class.yy) + sum(run_class.ny));
%run.Micro_Rec = NormalizeRatio(sum(run_class.yy), sum(run_class.yy) + sum(run_class.yn));
%run.Micro_F1 = NormalizeRatio(2 * run.Micro_Prec * run.Micro_Rec, run.Micro_Prec + run.Micro_Rec);
%run.Err = 1 - NormalizeRatio(sum(run_class.yy), size(Y_test, 1));

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -