📄 pkadaboost.asv
字号:
%%% Patrick Etyngier
%%% etyngier@certis.enpc.fr
%%% http://cermics.enpc.fr/~etyngier
%%% ENPC - Ecole Nationale des Ponts et Chaussees
%%% http://www.enpc.fr
%%% CERTIS - Centre d'Enseignement en Recherche en Technologie de
%%% l'Information et Systemes
%%% http://cermics.enpc.fr/~certis
%%% Copyright CERTIS
%%%
%%% V.9 pkAdaBoost.m
function [fm1,fm2,linSeparators,alpha_,goodThresholds,goodmysign] = pkAdaBoost(X,y,nbIteration)
%PKADABOOST Boost decision stumps with per-side confidences, then pick an output threshold.
%   [fm1,fm2,linSeparators,alpha_,goodThresholds] = pkAdaBoost(X,y,nbIteration)
%   [...,goodmysign] = pkAdaBoost(...) also returns the threshold orientation.
%
%   Inputs:
%     X            N-by-dimX data matrix (one sample per row).
%     y            N labels, compared against +1 and -1 (assumed to be a
%                  column vector of +1/-1 values -- TODO confirm with callers).
%     nbIteration  number of boosting rounds M.
%
%   Round m asks decisionStumpForRealBoost for a stump (c(m), d(m), w(m))
%   under the current sample weights and stores it as the hyperplane
%   linSeparators(:,m), so that pksign([X, ones(N,1)]*linSeparators(:,m))
%   is the stump's prediction on X.
%
%   Outputs:
%     fm1, fm2        1-by-M confidences given to samples that the stump puts
%                     on its "not -1" side (fm1) or on its "-1" side (fm2):
%                     .5*log((W_pos+beta)/(W_neg+beta)), where W_pos / W_neg
%                     is the current weight of +1 / -1 samples on that side.
%     linSeparators   (dimX+1)-by-M stump hyperplanes; row dimX+1 is the bias.
%     alpha_          1-by-M round weights .5*log((1-err)/err), where err is the
%                     weighted error of the stump's sign prediction (clamped
%                     to [beta, 1-beta]).
%     goodThresholds  threshold on score(x) = sum_m alpha_(m)*f_m(x), with f_m
%                     equal to fm1(m) or fm2(m), that minimises the number of
%                     training misclassifications; Inf if none improves on N.
%     goodmysign      +1 or -1: a sample is labelled +1 when
%                     goodmysign*score >= goodmysign*goodThresholds
%                     (left at 1 when goodThresholds is Inf).

    [N,dimX] = size(X);
    M = nbIteration;
    trueClassOne = (y == 1);
    trueClassTwo = (y == -1);
    beta = 10^-6; % smoothing for the logs below (an earlier choice was 1/(4*N))
    Xh = [X, ones(N,1)]; % extra column of ones multiplies the bias row

    % Preallocate so that M = 0 yields empty outputs instead of unassigned ones.
    D = zeros(N,M+1);
    D(:,1) = 1/N;
    c = zeros(1,M);
    d = zeros(1,M);
    w = zeros(1,M);
    fm1 = zeros(1,M);
    fm2 = zeros(1,M);
    alpha_ = zeros(1,M);
    linSeparators = zeros(dimX+1,M);
    regFunc = zeros(N,1); % running score sum_m alpha_(m)*f_m(x) on the training set

    fprintf('Boosting : Iteration -> ');
    for m = 1:M
        fprintf('\b\b\b\b\b\b\b\b\b\b\b\b\b%.5d / %.5d',m,M)
        [c(m),d(m),w(m)] = decisionStumpForRealBoost(X,y,D(:,m));
        linSeparators(d(m),m) = w(m);
        linSeparators(dimX+1,m) = -w(m)*c(m);

        % Prediction of the current stump. The hyperplane already includes
        % w(m); scaling by w(m) again would flip stumps with w(m) < 0.
        predY = pksign(Xh*linSeparators(:,m));
        onMinusSide = (predY == -1);
        missed = (y.*predY == -1);
        weight = D(:,m);

        % Current weight of each true class on each side of the stump.
        n11 = sum(weight(trueClassOne & ~onMinusSide));
        n21 = sum(weight(trueClassTwo & ~onMinusSide));
        n12 = sum(weight(trueClassOne &  onMinusSide));
        n22 = sum(weight(trueClassTwo &  onMinusSide));
        fm1(m) = .5*log((n11+beta)/(n21+beta));
        fm2(m) = .5*log((n12+beta)/(n22+beta));

        % Confidence-rated output f_m(x) of this round.
        predYFin = fm1(m)*ones(N,1);
        predYFin(onMinusSide) = fm2(m);

        % D(:,m) sums to 1, so this sum already is the weighted error. Clamp it
        % away from 0 and 1 so a perfect (or perfectly wrong) stump does not
        % produce an infinite alpha.
        errM = sum(weight(missed));
        errM = min(max(errM, beta), 1-beta);
        alpha_(m) = .5*log((1-errM)/errM);

        % Reweight samples by exp(-y*f_m(x)) and renormalise.
        D(:,m+1) = weight.*exp(-y.*predYFin);
        D(:,m+1) = D(:,m+1)/sum(D(:,m+1));

        regFunc = regFunc + alpha_(m)*predYFin;
    end

    %%%% computation of the threshold
    fprintf('\b\b\b\b\b\b\b\b\b\b\b\b\bDone \nComputation of the threshold : ');

    % Candidate thresholds: midpoints between consecutive distinct score values
    % (ascending, because regFuncS is sorted).
    regFuncS = sort(regFunc);
    lo = regFuncS(1:end-1);
    hi = regFuncS(2:end);
    keep = (hi - lo) ~= 0;
    thresholds = .5*(lo(keep) + hi(keep));

    % Exhaustive search over thresholds and both orientations; on ties the
    % first candidate found is kept (smaller threshold, then orientation +1).
    min_missclassified = N;
    goodThresholds = Inf;
    goodmysign = 1;
    mysign = [1 -1];
    for i = 1:numel(thresholds)
        for s = mysign
            labelledOne = s*regFunc >= s*thresholds(i);
            labelledTwo = s*regFunc <  s*thresholds(i);
            misclassified = sum(trueClassTwo & labelledOne) + sum(trueClassOne & labelledTwo);
            if misclassified < min_missclassified
                min_missclassified = misclassified;
                goodThresholds = thresholds(i);
                goodmysign = s;
            end
        end
    end
    fprintf('Done\n');
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -