📄 decisionstumpforrealboost.m
字号:
%%% Patrick Etyngier
%%% etyngier@certis.enpc.fr
%%% http://cermics.enpc.fr/~etyngier
%%% ENPC - Ecole Nationale des Ponts et Chaussees
%%% http://www.enpc.fr
%%% CERTIS - Centre d'Enseignement en Recherche en Technologie de
%%% l'Information et Systemes
%%% http://cermics.enpc.fr/~certis
%%% Copyright CERTIS
%%%
%%% V.9 decisionStumpForRealBoost.m
function [c_,d_,w_,classOne_,classTwo_] = decisionStumpForRealBoost(X,y,weight,h)
%DECISIONSTUMPFORREALBOOST Choose the single-feature threshold stump that
%minimises a Real-AdaBoost-style criterion.
%
%   [c_,d_,w_,classOne_,classTwo_] = decisionStumpForRealBoost(X,y,weight,h)
%
%   Inputs
%     X      : N x dimFS matrix, one sample per row
%                  | X1^T |        | y1 |             | w1 |
%              X = | ...  |    y = | ...|    weight = | ...|
%                  | XN^T |        | yN |             | wN |
%     y      : N labels. Samples with y == 1 form class one and samples
%              with y == -1 form class two; only these contribute to the
%              criterion below.
%     weight : N sample weights (optional; default ones(N,1)/N, also used
%              when weight is passed as []).
%     h      : not used; accepted only so existing callers keep working.
%
%   Outputs
%     c_        : selected threshold (-1 if no candidate threshold exists).
%     d_        : selected feature, i.e. column of X (-1 if none exists).
%     w_        : polarity of the stump. Only the first entry of sign_ is
%                 searched, so this is 1 when a stump is found, 0 otherwise.
%     classOne_ : indices where pksign(w_*(X(:,d_)-c_)) ~= -1 ([] if none).
%     classTwo_ : indices where pksign(w_*(X(:,d_)-c_)) == -1 ([] if none).
%
%   Candidate thresholds are the distinct midpoints between consecutive
%   sorted values of each feature. For each candidate the weighted class
%   masses on the two sides of the threshold give
%       Z = sqrt(n12*n22) + sqrt(n11*n21)
%   (the Real AdaBoost normaliser up to a factor 2), and the stump with the
%   smallest Z is kept. Ties in Z go to the stump with fewer misclassified
%   samples; remaining ties go to whichever candidate is visited first.
%   Candidates within a dimension are visited in random order (randperm),
%   so such ties are broken randomly.
%
%   NOTE(review): pksign is defined elsewhere in the project; it is assumed
%   to map its input element-wise to {-1, 1} (TODO confirm how it treats 0).

if nargin < 2
    error('[c_,d_,w_,classOne_,classTwo_] = decisionStumpForRealBoost(X,y,weight,h) : Argument error');
end
[N, dimFS] = size(X);
if nargin < 3 || isempty(weight)
    weight = ones(N,1)/N;
end
% Force column vectors so the element-wise products below never broadcast
% a row against the column predY into an N x N matrix.
y = y(:);
weight = weight(:);

sign_ = [1 -1];
s = 1;                  % only the first polarity is searched

% Sentinel outputs, returned unchanged when no candidate threshold exists
% (previously w_/classOne_/classTwo_ were left unassigned in that case).
c_ = -1;
d_ = -1;
w_ = 0;
classOne_ = [];
classTwo_ = [];

trueClassOne = find(y == 1);
trueClassTwo = find(y == -1);
bestZ = Inf;            % smallest criterion found so far
bestErrors = N;         % misclassification count of the current best stump

header = 'Decision Stump / dimension : ';
fprintf('\t');
fprintf('%s', header);
for d = 1:dimFS
    progress = sprintf('%.4d', d);
    fprintf('%s', progress);

    % Distinct midpoints between consecutive sorted feature values. The
    % values are sorted, so unique() only drops adjacent duplicates; unlike
    % the former circshift filter it keeps a lone midpoint (e.g. N == 2).
    vs = sort(X(:,d));
    thresholds = unique((vs(2:end) + vs(1:end-1)) / 2);
    % Random visiting order: randomises the final tie-break between
    % candidates with equal Z and equal error counts.
    thresholds = thresholds(randperm(numel(thresholds)));

    for c = 1:numel(thresholds)
        % Equivalent to [X ones(N,1)] * w for a weight vector that is zero
        % except at feature d and the bias, but O(N) instead of
        % O(N*dimFS), and immune to Inf/NaN in the other columns.
        predY = pksign(sign_(s) * (X(:,d) - thresholds(c)));

        % true where the stump does not contradict the label
        correct = (y .* predY ~= -1);
        n11 = sum(weight(trueClassOne) .*  correct(trueClassOne));  % class one, predY ~= -1
        n12 = sum(weight(trueClassOne) .* ~correct(trueClassOne));  % class one, predY == -1
        n21 = sum(weight(trueClassTwo) .* ~correct(trueClassTwo));  % class two, predY ~= -1
        n22 = sum(weight(trueClassTwo) .*  correct(trueClassTwo));  % class two, predY == -1
        Z = sqrt(n22*n12) + sqrt(n11*n21);
        nErrors = sum(~correct);

        if Z < bestZ || (Z == bestZ && nErrors < bestErrors)
            bestZ = Z;
            bestErrors = nErrors;
            c_ = thresholds(c);
            d_ = d;
            w_ = sign_(s);
            classOne_ = find(predY ~= -1);
            classTwo_ = find(predY == -1);
        end
    end
    % erase the dimension counter, whatever its width
    fprintf(repmat('\b', 1, numel(progress)));
end
% erase the header and the leading tab
fprintf(repmat('\b', 1, numel(header) + 1));
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -