⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 test_h2mglvq.m

📁 最详尽的神经网络源码
💻 M
字号:
% Example 1: train an H2M-GLVQ prototype classifier on a synthetic 2-D,
% two-class Gaussian problem. Figure 1 shows the labels predicted on a regular
% grid together with the training samples, the learned prototypes and the
% Voronoi diagram of the prototypes; figure 2 plots the E_H2MGLVQ output of
% h2m_glvq_model.
clear, clc, close all hidden
disp('Running example 1 ....')

% --- Problem definition ---------------------------------------------------
d                                     = 2;      % number of rows of each randn draw (sample dimension)
Ntrain                                = 100;    % total number of training samples, half per class
m                                     = 2;      % not referenced anywhere else in this script
M0                                    = [0 ; 0];
R0                                    = [1 0 ; 0 1];
M1                                    = [2 ; 3];
% chol reads only the diagonal and upper triangle of its argument, so the
% previous non-symmetric literal [0.5 0.1 ; 0.2 1] was silently treated as the
% matrix below. Writing it symmetrically states the covariance actually used
% and leaves the generated samples unchanged.
R1                                    = [0.5 0.1 ; 0.1 1];
vect_test                             = (-4:0.1:8);   % grid coordinates used on both axes

% --- Options passed to h2m_glvq_model and NN_predict ----------------------
% (the meaning of each field is defined by those functions, not shown here)
options.epsilonk                      = 0.01;
options.epsilonl                      = 0.01;
options.tmax                          = 1;
options.tmin                          = 0.5;
options.xi                            = 0.1;
options.nb_iterations                 = 5000;
options.shuffle                       = 1;

% --- Training set: one sample per column, class 0 first, then class 1 -----
Nclass                                = Ntrain/2;
X0                                    = M0(: , ones(1 , Nclass)) + chol(R0)'*randn(d , Nclass);
X1                                    = M1(: , ones(1 , Nclass)) + chol(R1)'*randn(d , Nclass);
Xtrain                                = [X0 , X1];
ytrain                                = [zeros(1 , Nclass) , ones(1 , Nclass)];

% --- Test set: every node of the vect_test x vect_test grid ---------------
[X , Y]                               = meshgrid(vect_test);
Xtest                                 = [X(:)' ; Y(:)'];

% Four prototypes for each distinct label found in ytrain.
Nproto_pclass                         = 4*ones(1 , length(unique(ytrain)));

% --- Initialisation, training and prediction ------------------------------
% lambda is returned by ini_proto but not used in this example.
[Wproto , yproto , lambda]            = ini_proto(Xtrain , ytrain , Nproto_pclass);
[Wproto_est , yproto_est , E_H2MGLVQ] = h2m_glvq_model(Xtrain , ytrain , Wproto , yproto , options);
ytest_est                             = NN_predict(Xtest , Wproto_est , yproto_est , [] , options);

% Logical masks selecting the samples / prototypes of each class.
indtrain0                             = (ytrain == 0);
indtrain1                             = (ytrain == 1);

indproto0                             = (yproto_est == 0);
indproto1                             = (yproto_est == 1);

% --- Figure 1: predicted labels on the grid, samples and prototypes -------
figure(1)
ngrid                                 = length(vect_test);
imagesc(vect_test , vect_test , reshape(ytest_est , ngrid , ngrid))
axis ij
hold on
plot(Xtrain(1 , indtrain0) , Xtrain(2 , indtrain0) , 'k+' , ...
     Xtrain(1 , indtrain1) , Xtrain(2 , indtrain1) , 'm+' , ...
     Wproto_est(1 , indproto0) , Wproto_est(2 , indproto0) , 'ko' , ...
     Wproto_est(1 , indproto1) , Wproto_est(2 , indproto1) , 'mo')
h = voronoi(Wproto_est(1 , :) , Wproto_est(2 , :));
set(h , 'color' , 'y' , 'linewidth' , 2)
hold off
% This figure previously carried the title of figure 2 ('E_{H2MGLVQ}(t)').
title('H2MGLVQ predicted labels, training data and prototypes' , 'fontsize' , 12)
colorbar

% --- Figure 2: E_H2MGLVQ as returned by h2m_glvq_model --------------------
figure(2)
plot(E_H2MGLVQ);
title('E_{H2MGLVQ}(t)' , 'fontsize' , 12)



% Wait for a key press before the next example clears the figures.
pause
% Example 2: repeated train/test evaluation of the H2M-GLVQ classifier on the
% data loaded from "ionosphere". For every split the script records the
% train/test classification performance and an ROC curve, then displays the
% mean performances, the E_H2MGLVQ curve of the last split (figure 1) and the
% ROC curves averaged over the splits (figure 2).
disp('Running example 2 ....')

clear, close all hidden

load ionosphere        % expected to define X and y, which are used below -- TODO confirm
Nproto_pclass                      = 4*ones(1 , length(unique(y)));   % four prototypes per distinct label


% --- Options passed to h2m_glvq_model and NN_predict ----------------------
% (the meaning of each field is defined by those functions, not shown here)
options.epsilonk                   = 0.005;
options.epsilonl                   = 0.001;
options.tmax                       = 1;
options.tmin                       = 0.2;
options.xi                         = 10;
options.nb_iterations              = 5000;
options.shuffle                    = 1;


% --- Options read by sampling (also carried in the same struct) -----------
options.method                     = 7;
options.holding.rho                = 0.7;
options.holding.K                  = 20;


% NOTE(review): presumably the toolbox's own normalize; MATLAB R2018a and later
% also ship a built-in normalize -- verify which one is found first on the path.
X                                  = normalize(X);
[Itrain , Itest]                   = sampling(X , y , options);

nfolds                             = size(Itrain , 1);   % one row of Itrain per split
nroc                               = 100;                % columns reserved per ROC curve; presumably the
                                                         % number of points basicroc returns -- TODO confirm

Perftrain                          = zeros(1 , nfolds);
Perftest                           = zeros(1 , nfolds);
tptrain                            = zeros(nfolds , nroc);
fptrain                            = zeros(nfolds , nroc);
tptest                             = zeros(nfolds , nroc);
fptest                             = zeros(nfolds , nroc);

for i = 1 : nfolds

    [Xtrain , ytrain , Xtest , ytest]        = samplingset(X , y , Itrain , Itest , i);

    % lambda is returned by ini_proto but not used in this example.
    [Wproto , yproto , lambda]               = ini_proto(Xtrain , ytrain , Nproto_pclass);
    [Wproto_est , yproto_est , E_H2MGLVQ]    = h2m_glvq_model(Xtrain , ytrain , Wproto , yproto , options);

    [ytest_est , disttest]                   = NN_predict(Xtest , Wproto_est , yproto_est , [] , options);
    [ytrain_est , disttrain]                 = NN_predict(Xtrain , Wproto_est , yproto_est , [] , options);

    Perftrain(i)                             = perf_classif(ytrain , ytrain_est);
    Perftest(i)                              = perf_classif(ytest , ytest_est);

    % The distance matrices come from NN_predict called with yproto_est, so
    % their rows are selected with yproto_est as well (the initial labels
    % yproto were used before). The explicit dimension argument keeps min
    % column-wise even when a class owns a single prototype row.
    isclass0                                 = (yproto_est == 0);

    dktrain                                  = min(disttrain(isclass0 , :) , [] , 1);
    dltrain                                  = min(disttrain(~isclass0 , :) , [] , 1);
    nutrain                                  = (dktrain - dltrain)./(dktrain + dltrain);
    [tptrain(i , :) , fptrain(i , :)]        = basicroc(ytrain , nutrain);

    dktest                                   = min(disttest(isclass0 , :) , [] , 1);
    dltest                                   = min(disttest(~isclass0 , :) , [] , 1);
    nutest                                   = (dktest - dltest)./(dktest + dltest);
    [tptest(i , :) , fptest(i , :)]          = basicroc(ytest , nutest);

    % Progress counter; the denominator is the real loop bound.
    fprintf('%d/%d\n' , i , nfolds)
    drawnow

end

disp('Performances Train/Test')
disp([mean(Perftrain , 2) , mean(Perftest , 2)])

% E_H2MGLVQ still holds the value returned for the last split only.
figure(1)
plot(E_H2MGLVQ);
title('E_{H2MGLVQ}(t)' , 'fontsize' , 12)

% ROC curves averaged over the splits (train: default colour, test: red).
figure(2)
plot(mean(fptrain , 1) , mean(tptrain , 1) , mean(fptest , 1) , mean(tptest , 1) , 'r' , 'linewidth' , 2)
xlabel('false positive rate');
ylabel('true positive rate');
title('ROC curve' , 'fontsize' , 12);
legend('Train' , 'Test')

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -