lssvcbay_predict.m
function [Y_resu, Y_conf, ytst0, varztstp, varztstn] = lssvcBay_predict(X_test, param, idx_feat, X_train, Y_train)
% [Y_resu, Y_conf, ytst0, varztstp, varztstn] = lssvcBay_predict(X_test, param, idx_feat, X_train, Y_train)
% Make classification predictions with a trained lssvcBay model.
% Inputs:
% X_test   -- Test data matrix of dim (num test examples, num features).
% param    -- Trained classifier, see lssvcBay_train.
% idx_feat -- Indices of the selected features (optional; default: all features).
% X_train  -- Training data matrix of dim (num training examples, num features).
%             Used by some predictors (not by lssvcBay).
% Y_train  -- Training labels of dim (num training examples, 1).
%             Used by some predictors (not by lssvcBay).
% Returns:
% Y_resu   -- +-1 predictions on the test data, one per test example
%             (sign() yields 0 on an exact tie, i.e. probability 0.5).
% Y_conf   -- Confidence values; for lssvcBay, the probability output for the positive class.
% ytst0    -- Latent output of the lssvcBay model for the test data.
% varztstp -- Sample-specific variance resulting from the uncertainty in the model
%             parameters, associated with the + class.
% varztstn -- Sample-specific variance resulting from the uncertainty in the model
%             parameters, associated with the - class.
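%
% Example (a sketch only; assumes `model` is the struct returned by
% lssvcBay_train and that X_test has the same feature columns as the
% training data; `model`, `idx` and `Y_strict` are hypothetical names):
%   [Y_resu, Y_conf] = lssvcBay_predict(X_test, model);       % all features
%   [Y_resu, Y_conf] = lssvcBay_predict(X_test, model, idx);  % feature subset idx
%   % Hypothetical stricter decision rule on the positive-class probability:
%   Y_strict = 2*(Y_conf >= 0.7) - 1;                        % +1 only if Y_conf >= 0.7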
% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
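lssvcB=param; clear param;   % trained lssvcBay model (see lssvcBay_train)
% Default: use all features when no feature subset is given.
if nargin < 3 || isempty(idx_feat)
    idx_feat = 1:size(X_test,2);
end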
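% Latent outputs and sample-specific variances for the + and - classes.
[varztstp, varztstn, ytst0] = lssvcmodoutb2_test(X_test(:,idx_feat), lssvcB);
% Class-specific parameters stored in the trained model.
zmp=lssvcB.zmp; zmn=lssvcB.zmn; zetap=lssvcB.zetap; zetan=lssvcB.zetan;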
% Class priors: prior(1) is used for the - class, prior(2) for the + class.
if isfield(lssvcB,'prior')
pip=lssvcB.prior(2); pin=lssvcB.prior(1);
else
pip=lssvcB.pip; pin=lssvcB.pin;
end
% Alternative (disabled): swap the priors to give a higher prior to the class
% with fewer examples.
%pin=lssvcB.pip; pip=lssvcB.pin;
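% Probability outputs for the + class (Pyptst) and the - class (Pyntst).
[classtst, Pyptst, Pyntst] = lssvmoutclass2(ytst0, zmp, zmn, varztstp, varztstn, zetap, zetan, pip, pin);
% Predict +1 when Pyptst > 0.5 and -1 when Pyptst < 0.5; sign() returns 0 on an exact tie.
Y_score=Pyptst-0.5;
Y_resu=sign(Y_score);
Y_conf=Pyptst;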
return,
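% Performance on the training set (unreachable after the return above; kept for diagnostics).
t=Y_train;   % +-1 training labels, needed by the metrics below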
ytrn0=lssvcB.ztrn; varztrnp=lssvcB.varztrnp; varztrnn=lssvcB.varztrnn;
[classtrn, Pyptrn, Pyntrn] = lssvmoutclass2(ytrn0, zmp, zmn, varztrnp, varztrnn, zetap, zetan, pip, pin);
[TN1,TP1,FP1,FN1,auctrn,se]=roc_tf(Pyptrn(t<=0),Pyptrn(t>0));
acctrn=sum(sign(classtrn)==t)/length(t);
bertrn=0.5*(sum(classtrn(t>0)<=0)/sum(t>0)+sum(classtrn(t<=0)>0)/sum(t<=0));
fprintf('Train: ACC=%.2f%% AUC=%.4f BER=%.2f%% \n', acctrn*100, auctrn, bertrn*100);