📄 calc_posprob.m
字号:
% This file computes probabilistic outputs for an SVM using the sigmoid
% method of J. Platt.
% Used in conjunction with D2CMatlab.
% Written by D. Lai, 2006.
% Requires the name of an input file containing the SVM decision values f(x)
% (column 1) and the class labels (column 2). f(x) can be obtained using
% leave-one-out, cross-validation, or hold-out methods.
%
function calc_posprob(h,handles)
%CALC_POSPROB Fit Platt's sigmoid to SVM decision values and save probabilities.
%   CALC_POSPROB(H, HANDLES) loads the numeric file named by
%   HANDLES.SVMdecisionfile. Column 1 must hold the SVM decision values f(x)
%   and column 2 the class labels. A row is treated as a positive example
%   when its label is > 0 (identical to the old behaviour for +1/-1 labels).
%
%   The sigmoid P(y=1|f) = 1/(1+exp(A*f+B)) is fitted with Platt's
%   regularised Newton iteration, using the corrected (natural-log,
%   per-example-target) likelihood of Lin, Lin & Weng. Afterwards:
%     * one line "f label p" per example is written to HANDLES.SVMposfile;
%     * [label p], sorted by label, is stored in HANDLES.SVMPP and HANDLES
%       is saved back to the GUI with GUIDATA(H, HANDLES).
%
%   Errors with id 'calc_posprob:badInput' when the decision file has fewer
%   than two columns, and 'calc_posprob:cannotOpen' when the output file
%   cannot be opened.
F = load(handles.SVMdecisionfile);
if size(F,2) < 2
    error('calc_posprob:badInput', ...
        'Decision file %s must contain at least two columns (decision value, label).', ...
        handles.SVMdecisionfile);
end
decisionValues = F(:,1);
labels = F(:,2);

pp = fit_platt_sigmoid(decisionValues, labels > 0);

write_probabilities(handles.SVMposfile, decisionValues, labels, pp);

handles.SVMPP = sortrows([labels pp], 1);
guidata(h, handles);

function [pp, A, B] = fit_platt_sigmoid(f, isPositive)
% Fit A,B of p = 1/(1+exp(A*f+B)) by Platt's regularised Newton method and
% return the fitted probability pp(i) of every example (plus A and B).
f = f(:);
isPositive = logical(isPositive(:));
numPos = sum(isPositive);
numNeg = numel(isPositive) - numPos;

% Platt's smoothed targets (instead of 0/1) reduce overfitting.
hiTarget = (numPos+1)/(numPos+2);
loTarget = 1/(numNeg+2);
t = repmat(loTarget, size(f));
t(isPositive) = hiTarget;

A = 0;
B = log((numNeg+1)/(numPos+1));   % natural log, matching the exp() sigmoid
lambda = 1e-3;                     % Levenberg-Marquardt style stabiliser
olderr = 1e300;
pp = repmat((numPos+1)/(numel(f)+2), size(f));
count = 0;

for it = 1:100
    % Hessian [a c; c b] and (negated) gradient [d; e] of the negative
    % log-likelihood at the probabilities of the current A,B.
    d1 = pp - t;
    d2 = pp.*(1-pp);
    a = sum(f.*f.*d2);
    b = sum(d2);
    c = sum(f.*d2);
    d = sum(f.*d1);
    e = sum(d1);

    % Exit if the gradient is negligible.
    if abs(d) < 1e-9 && abs(e) < 1e-9
        break;
    end

    oldA = A;
    oldB = B;
    oldPP = pp;
    err = 0;
    accepted = false;
    % Increase the stabiliser until the step reduces the error.
    while ~accepted
        detH = (a+lambda)*(b+lambda) - c*c;
        if detH == 0
            lambda = lambda*10;
            continue;
        end
        A = oldA + ((b+lambda)*d - c*e)/detH;
        B = oldB + ((a+lambda)*e - c*d)/detH;

        z = A*f + B;
        pp = 1./(1+exp(z));
        % Cross-entropy -[t*log(p) + (1-t)*log(1-p)] with p = 1/(1+exp(z)),
        % written in a form that never evaluates log(0).
        err = sum(max(z,0) + log1p(exp(-abs(z))) - (1-t).*z);

        if err < olderr*(1+1e-7)
            lambda = lambda*0.1;
            accepted = true;
        else
            lambda = lambda*10;
            if lambda >= 1e6
                break;
            end
        end
    end

    if ~accepted
        % Line search failed: keep the last accepted fit and stop.
        A = oldA;
        B = oldB;
        pp = oldPP;
        break;
    end

    % Stop after three consecutive iterations with negligible improvement.
    change = err - olderr;
    scale = 0.5*(err + olderr + 1);
    if change > -1e-3*scale && change < 1e-7*scale
        count = count + 1;
    else
        count = 0;
    end
    olderr = err;
    if count == 3
        break;
    end
end

function write_probabilities(fileName, decisionValues, labels, pp)
% Write one "f label p" line per example to fileName.
fid = fopen(fileName, 'w+');
if fid < 0
    error('calc_posprob:cannotOpen', ...
        'Cannot open output file %s for writing.', fileName);
end
% Transposed so fprintf consumes the matrix one example (row) at a time.
fprintf(fid, '%f %f %f \n', [decisionValues(:) labels(:) pp(:)]');
fclose(fid);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -