📄 localridge.m
字号:
function [l, V, H, A, W, P] = localRidge(F, Y, l, options)
% [l, V, H, A, W, P] = localRidge(F, Y, l, options)
%
% Calculates the best local ridge regression parameters using
% one of a number of model selection criteria. Starting from an
% initial guess (l), it repeatedly sweeps over the M regularisation
% parameters, re-optimising each one in turn (via localRidgeJ), until
% a termination condition (-t) is met or a hard limit on the number
% of sweeps (-h) is reached. All three have defaults.
%
% Inputs
%
%   F        full design matrix (p-by-M)
%   Y        input training data (p-by-n)
%   l        initial guess at lambdas (1-by-M, M-by-1 or scalar)
%            (default all 0.01; passing [] also selects the default)
%   options  options (string)
%
% Outputs
%
%   l   final estimate for lambdas (1-by-M) with m finite entries
%   V   final estimate for model selection score (scalar)
%   H   final design matrix (p-by-m)
%   A   final partial covariance (m-by-m)
%   W   final partial weight matrix (m-by-n)
%   P   final projection matrix (p-by-p)
%
% Options
%
%   -v                   verbose output
%   -V                   verbose output with compute cost (flops) reporting
%   -r                   visit the lambdas in a random order on each sweep
%   -h limit             hard limit on the number of sweeps (>= 1, default 100)
%   -t [uev|gcv] [thr]   termination criterion (default gcv) and/or threshold:
%                        0 < thr < 1  stop when the absolute decrease in
%                                     score falls below thr
%                        thr >= 1     stop when the relative change
%                                     old_V / (old_V - V) exceeds thr
%                        (default threshold 1000, i.e. relative)
%
% A lambda of Inf removes the corresponding column of F from the model.
%
% The two model selection criteria that can be used are:
%
%   UEV  Unbiased Estimate of Variance
%   GCV  Generalised Cross-Validation
%
% NOTE(review): -V relies on flops(), which was removed from MATLAB in
% version 6, so it will only work on older releases.

% defaults
Verbose = 0;
Flops = 0;
Random = 0;
Term = 'g';
Threshold = 1000;
Hard = 100;

% process options
if nargin > 3
  i = 1;
  [arg, i] = getNextArg(options, i);
  while ~isempty(arg)
    if strcmp(arg, '-v')
      % verbose output required
      Verbose = 1;
    elseif strcmp(arg, '-V')
      % verbose output required with compute cost reporting
      Verbose = 1;
      Flops = 1;
    elseif strcmp(arg, '-r')
      % random ordering
      Random = 1;
    elseif strcmp(arg, '-h')
      % hard limit on the number of sweeps
      % (str2num evaluates its argument; options are assumed to come
      % from the caller, not from untrusted input)
      [arg, i] = getNextArg(options, i);
      hl = str2num(arg);
      if isempty(hl)
        fprintf('localRidge: value needed for hard limit\n')
        error('localRidge: missing value in -h option')
      elseif hl >= 1
        Hard = round(hl);
      else
        fprintf('localRidge: hard limit should be at least 1\n')
        error('localRidge: bad value in -h option')
      end
    elseif strcmp(arg, '-t')
      % termination criterion: an optional method name followed by an
      % optional threshold; at least one of the two must be present
      [arg, ii] = getNextArg(options, i);
      method_given = 1;
      if strcmp(lower(arg), 'uev')
        % use UEV (unbiased expected variance) to terminate
        Term = 'u';
      elseif strcmp(lower(arg), 'gcv')
        % use GCV (generalised cross-validation) to terminate
        Term = 'g';
      else
        % the method wasn't specified, or specified incorrectly
        method_given = 0;
      end
      if method_given
        % consume the method name and look at the following argument
        i = ii;
        [arg, ii] = getNextArg(options, i);
      end
      % is a threshold given?
      nu = str2num(arg);
      value_given = ~isempty(nu);
      good_value = 1;
      if value_given
        if nu > 0 & nu < 1
          Threshold = nu;          % absolute change
        elseif nu >= 1
          Threshold = round(nu);   % relative change
        else
          good_value = 0;
        end
        % consume the threshold (an unconsumed token is re-read below)
        i = ii;
      end
      % error conditions
      if ~method_given & ~value_given
        fprintf('localRidge: terminate with UEV or GCV and give a threshold\n')
        error('localRidge: missing arguments for -t option')
      elseif value_given & ~good_value
        fprintf('localRidge: acceptable thresholds are\n')
        fprintf(' between 0 and 1 (absolute change)\n')
        fprintf(' 1 or greater (relative change)\n')
        error('localRidge: bad value for -t option')
      end
    else
      % echo the option string with the offending argument underlined
      fprintf('%s\n', options)
      fprintf('%s%s\n', repmat(' ', 1, max(0, i-length(arg)-1)), ...
        repmat('^', 1, length(arg)))
      error('localRidge: unrecognised option')
    end
    % get next argument
    [arg, i] = getNextArg(options, i);
  end
end

% default initial guess (also used when l is passed as [])
if nargin < 3
  l = 0.01;
elseif isempty(l)
  l = 0.01;
end

% check dimensions
[p, M] = size(F);
[p2, n] = size(Y);
if p ~= p2
  error('localRidge: inconsistent design matrix and training outputs')
end

% expand / orient the lambdas as a 1-by-M row
[l1, l2] = size(l);
if l1 == 1 & l2 == 1
  l = l * ones(1, M);
elseif l1 == 1
  if l2 ~= M
    error('localRidge: lambda list inconsistent length')
  end
elseif l2 == 1
  if l1 ~= M
    error('localRidge: lambda list inconsistent length')
  end
  l = l';
else
  error('localRidge: lambda list should be a scalar or vector')
end

if Verbose
  fprintf('localRidge\n');
end
if Flops
  flops(0);
end

% initial model from the columns with finite lambda
sweep = 1;
done = 0;
keep = find(l ~= Inf);
if Flops, fprintf('H...'); end
H = F(:, keep);
if Flops, fprintf('HH...'); end
HH = H' * H;
if Flops, fprintf('A...'); end
A = inv(HH + diag(l(keep)));
if Flops, fprintf('HA...'); end
HA = H * A;
if Flops, fprintf('W...'); end
W = HA' * Y;
if Flops, fprintf('P...'); end
P = eye(p) - HA * H';
if Flops, fprintf('PY...'); end
PY = P * Y;
if Flops, fprintf('trP...'); end
trP = trace(P);
if Flops, fprintf('V...'); end

% initial model selection score
% Only 'u' and 'g' can be chosen through the options above; the other
% branches appear to be FPE ('f'), BIC ('b') and leave-one-out CV
% (otherwise) -- presumably kept to mirror localRidgeJ; verify there.
if Term == 'u'
  old_V = traceProduct(PY', PY) / trP;
elseif Term == 'f'
  old_V = (2 * p - trP) * traceProduct(PY', PY) / (p * trP);
elseif Term == 'g'
  old_V = p * traceProduct(PY', PY) / trP^2;
elseif Term == 'b'
  old_V = (p + (log(p) - 1) * (p - trP)) * traceProduct(PY', PY) / (p * trP);
else
  PYP = P * Y ./ dupCol(diag(P), n);
  old_V = traceProduct(PYP', PYP) / p;
end
if Flops, fprintf('(%d)\n', flops); end

if Verbose
  fprintf('pass in out ')
  if Term == 'u'
    fprintf(' UEV ')
  else
    fprintf(' GCV ')
  end
  fprintf(' change ')
  if Flops
    fprintf(' flops ')
  end
  fprintf('\n')
  fprintf('%4d %4d %4d %9.3e - ', ...
    0, length(keep), M-length(keep), old_V);
  if Flops
    fprintf('%8d\n', flops);
  else
    fprintf('\n');
  end
end

% V must exist even if there are no columns to sweep over (M == 0)
V = old_V;

% relative thresholds are >= 1, absolute ones lie in (0, 1)
Relative = Threshold >= 1;

% outer loop: one sweep over all the lambdas per pass
while ~done

  % order in which to update lambdas
  if Random
    ilst = randperm(M);
  else
    ilst = 1:M;   % same order every time (left to right)
  end

  % inner loop: re-optimise each lambda in turn
  num_in = 0;
  num_out = 0;
  for j = ilst
    [V, lj, A, HA, W, PY, trP] = localRidgeJ(j, F, l, A, HA, W, PY, trP, Term);
    if lj == Inf
      num_out = num_out + 1;
    else
      num_in = num_in + 1;
    end
    l(j) = lj;
  end

  % change in score over this sweep
  if Relative
    if V == old_V
      % no change (this also covers 0/0, which would give NaN and
      % never cross the threshold)
      change = Inf;
    else
      change = round(old_V / (old_V - V));
    end
  else
    change = old_V - V;
  end

  if Verbose
    if Relative
      fprintf('%4d %4d %4d %9.3e %7d ', sweep, num_in, num_out, V, change);
    else
      fprintf('%4d %4d %4d %9.3e %9.3e ', sweep, num_in, num_out, V, change);
    end
    if Flops
      fprintf('%8d\n', flops);
    else
      fprintf('\n');
    end
  end

  % done yet?
  if sweep >= Hard
    done = 1;
    if Verbose
      fprintf('hard limit reached\n')
    end
  elseif Relative & change > Threshold
    done = 1;
    if Verbose
      fprintf('relative threshold crossed\n')
    end
  elseif ~Relative & change < Threshold
    done = 1;
    if Verbose
      fprintf('absolute threshold crossed\n')
    end
  else
    old_V = V;
  end

  sweep = sweep + 1;
end

% outputs (A and W are already current from localRidgeJ)
if nargout > 2
  H = F(:, find(l ~= Inf));
  if nargout > 5
    P = eye(p) - H * A * H';
  end
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -