lssvcbay_predict_noprior.m

The goal of SPID is to provide the user with tools to simulate, preprocess, process and clas…
function [Y_resu, Y_conf, ytst0, varztstp, varztstn] = lssvcBay_predict(X_test, param, idx_feat, X_train, Y_train)
% [Y_resu, Y_conf, ytst0, varztstp, varztstn] = lssvcBay_predict(X_test, param, idx_feat, X_train, Y_train)
% Make classification predictions with the lssvcBay model.
% Inputs:
% X_test -- Test data matrix of dim (num test examples, num features).
% param -- Classifier, see lssvcBay_train.
% idx_feat -- Indices of the selected features (default: all columns of X_test).
% X_train -- Training data matrix of dim (num training examples, num features).
%         -- Unused by lssvcBay; accepted for compatibility with other predictors.
% Y_train -- Training labels (num training examples, 1).
%         -- Unused by lssvcBay; accepted for compatibility with other predictors.
% Returns:
% Y_resu -- +-1 predictions on the test data of dim (num test examples).
% Y_conf -- Confidence values, i.e. the probability output for the positive class in the case of lssvcBay.
% ytst0 -- Latent output of the lssvcBay model for the test data.
% varztstp -- Sample-specific variance resulting from the uncertainty in the model parameters, associated with the + class.
% varztstn -- Sample-specific variance resulting from the uncertainty in the model parameters, associated with the - class.
% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

lssvcB=param; clear param; %t=Y_train; 

if nargin < 3 || isempty(idx_feat)   % default: use every feature
    idx_feat=1:size(X_test,2);
end

[varztstp, varztstn, ytst0] = lssvcmodoutb2_test(X_test(:,idx_feat), lssvcB);

zmp=lssvcB.zmp; zmn=lssvcB.zmn; zetap=lssvcB.zetap; zetan=lssvcB.zetan;

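pip=0.5; pin=0.5;   % equal class priors (the "noprior" variant)

% Alternative (disabled): swap the stored priors so that the class with
% fewer training examples gets the higher prior.
%pin=lssvcB.pip; pip=lssvcB.pin;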

[classtst, Pyptst, Pyntst] = lssvmoutclass2(ytst0, zmp, zmn, varztstp, varztstn, zetap, zetan, pip, pin);
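Y_score=Pyptst-0.5;    % > 0 favours the positive class
Y_resu=sign(Y_score);  % note: sign() returns 0 when Pyptst is exactly 0.5
Y_conf=Pyptst;         % posterior probability of the positive class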

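return;

% Diagnostic: performance on the training set. This block is unreachable while
% the return above is in place; remove the return (and pass Y_train) to enable it.
t=Y_train;
ytrn0=lssvcB.ztrn; varztrnp=lssvcB.varztrnp; varztrnn=lssvcB.varztrnn; 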
[classtrn, Pyptrn, Pyntrn] = lssvmoutclass2(ytrn0, zmp, zmn, varztrnp, varztrnn, zetap, zetan, pip, pin);

[TN1,TP1,FP1,FN1,auctrn,se]=roc_tf(Pyptrn(t<=0),Pyptrn(t>0));   
acctrn=sum(sign(classtrn)==t)/length(t);
bertrn=0.5*(sum(classtrn(t>0)<=0)/sum(t>0)+sum(classtrn(t<=0)>0)/sum(t<=0));  
fprintf('Train: ACC=%.2f%% AUC=%.4f BER=%.2f%% \n', acctrn*100, auctrn, bertrn*100);  
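The call to `lssvmoutclass2` turns the latent output `ytst0` and the variances `varztstp`/`varztstn` into the class probability `Pyptst`, but its body is not part of this file. The sketch below shows one plausible reading, under stated assumptions: each class has a Gaussian density for the latent output centred on `zmp` or `zmn`, whose variance is a noise term `1/zeta` (treating `zetap`/`zetan` as precisions is an assumption) plus the sample-specific model variance, and Bayes' rule combines the two densities with the priors `pip`/`pin`. The anonymous functions `gauss` and `posterior_pos` are hypothetical and are not the toolbox's code; all numbers are made up.

```matlab
% Hypothetical sketch; NOT the actual body of lssvmoutclass2.
% Assumption: zetap/zetan act as noise precisions, so 1/zeta is a variance
% that is added to the sample-specific model variance.
gauss = @(z, m, v) exp(-(z - m).^2 ./ (2 * v)) ./ sqrt(2 * pi * v);

% P(y = +1 | x) by Bayes' rule over the two class-conditional densities.
posterior_pos = @(z, zmp, zmn, varp, varn, zetap, zetan, pip, pin) ...
    (pip * gauss(z, zmp, 1 / zetap + varp)) ./ ...
    (pip * gauss(z, zmp, 1 / zetap + varp) + pin * gauss(z, zmn, 1 / zetan + varn));

% Made-up values for three test samples (illustration only).
z    = [-1.5; 0; 0.8];    % latent outputs, playing the role of ytst0
varp = [0.1; 0.2; 0.1];   % model-parameter variance, + class (as varztstp)
varn = varp;              % model-parameter variance, - class (as varztstn)

Pyp = posterior_pos(z, 1, -1, varp, varn, 4, 4, 0.5, 0.5);  % equal priors
Y   = sign(Pyp - 0.5);    % same rule as Y_resu; the tie at z = 0 gives 0

% The disabled prior swap, assuming lssvcB.pip = 0.3 and lssvcB.pin = 0.7:
% the minority (positive) class then gets prior 0.7.
Pyp_sw = posterior_pos(z, 1, -1, varp, varn, 4, 4, 0.7, 0.3);
```

With equal priors, a latent output exactly halfway between `zmp` and `zmn` (with equal variances) gets probability 0.5, which is why `sign` can return 0 there; raising the positive prior moves every probability towards the positive class.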

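The disabled training-set diagnostic calls `roc_tf`, whose source is not shown here. The sketch below recomputes the same three figures on made-up labels: accuracy and BER use the same expressions as the diagnostic block, and AUC is computed with the Mann-Whitney (pairwise ranking) statistic. That statistic is a standard way to get the area under the ROC curve, but it is a swapped-in technique and not necessarily what `roc_tf` does. The names `acc_fn`, `ber_fn` and `auc_fn` are hypothetical.

```matlab
% Hypothetical helpers for the training-set diagnostics (not the SPID roc_tf code).
% t: true labels in {-1,+1}; cls: predicted labels; p: positive-class scores.

% Accuracy: fraction of predictions whose sign matches the label.
acc_fn = @(cls, t) mean(sign(cls) == t);

% Balanced error rate: mean of the false-negative and false-positive rates,
% written with the same expression as bertrn.
ber_fn = @(cls, t) 0.5 * (sum(cls(t > 0) <= 0) / sum(t > 0) + ...
                          sum(cls(t <= 0) > 0) / sum(t <= 0));

% AUC via the Mann-Whitney statistic: fraction of (positive, negative) pairs
% in which the positive scores higher, counting ties as 1/2.
% Argument order mirrors roc_tf(scores_of_negatives, scores_of_positives).
auc_fn = @(pneg, ppos) mean(mean(bsxfun(@gt, ppos(:), pneg(:).') + ...
                                 0.5 * bsxfun(@eq, ppos(:), pneg(:).')));

% Made-up labels and probabilities (illustration only).
t   = [ 1;   1;  -1;  -1;   1];
p   = [0.9; 0.4; 0.3; 0.6; 0.7];
cls = sign(p - 0.5);

acc = acc_fn(cls, t);
ber = ber_fn(cls, t);
auc = auc_fn(p(t <= 0), p(t > 0));
fprintf('Toy: ACC=%.2f%% AUC=%.4f BER=%.2f%%\n', acc*100, auc, ber*100);
```

The pairwise form costs O(P*N) memory for P positives and N negatives, which is fine for small training sets; for large ones a rank-based formula is the usual alternative.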