globalridge.m
function [l, e, L, E] = globalRidge(H, Y, l, options, U)
% [l, e, L, E] = globalRidge(H, Y, l, options, U)
%
% Calculates the best global ridge regression parameter (l) and
% the corresponding predicted error (e) using one of a number of
% prediction methods (UEV, FPE, GCV or BIC). Needs a design matrix
% (H), the training set outputs (Y), and an initial guess (l).
% The termination criterion, maximum number of iterations,
% verbose output and the use of a non-standard weight penalty
% are controlled from the options string. The non-standard
% metric, if used, is given in the fifth argument (U). L and E
% return the evolution of the regularisation parameter and error
% values from the initial to final iterations. If the input l is
% a vector (more than one guess), a corresponding number of
% answers is returned: e is then also a vector, and L and E are
% matrices in which each column holds the iterations that follow
% from the corresponding guess.
%
% Inputs
%
%   H        design matrix (p-by-m)
%   Y        training set outputs (p-by-k)
%   l        initial guess(es) at lambda (vector of length q, default 0.01)
%   options  options (string)
%   U        optional non-standard smoothing metric (m-by-m)
%
% Outputs
%
%   l        final estimate(s) for lambda (1-by-q)
%   e        final estimate(s) for model selection score (1-by-q)
%   L        list(s) of running lambda values (n-by-q)
%   E        list(s) of running error values (n-by-q)
%
% where n - 1 is the largest number of iterations taken by any guess.
%
% The various model selection criteria used are:
%
%   UEV  Unbiased Estimate of Variance
%   FPE  Final Prediction Error
%   GCV  Generalised Cross Validation
%   BIC  Bayesian Information Criterion
%
% The criterion is selected in options by name, e.g. 'FPE' (case is
% ignored); GCV is used if none is given. Other recognised options:
%
%   -v    verbose output
%   -V    verbose output with flop counts (relies on the obsolete
%         flops function, which recent MATLAB releases do not provide)
%   -U    use the non-standard metric given in U
%   -h n  hard limit of n iterations per guess (default 100)
%   -t x  termination threshold (default 1000): if x >= 1, stop when
%         the score changes by less than one part in x; if 0 < x < 1,
%         stop when the absolute change in the score is less than x
%
% The helper functions getNextArg and traceProduct are called but not
% defined in this file.

% defaults
Verbose = 0;
Flops = 0;
Model = 'g';
Threshold = 1000;
Hard = 100;
Standard = 1;

% process options
if nargin > 3

  % initialise
  i = 1;
  [arg, i] = getNextArg(options, i);

  % scan through arguments
  while ~isempty(arg)

    if strcmp(arg, '-v')
      % verbose output required
      Verbose = 1;
    elseif strcmp(arg, '-V')
      % verbose output required with compute cost reporting
      Verbose = 1;
      Flops = 1;
    elseif strcmp(arg, '-U')
      % non-standard penalty matrix
      Standard = 0;
    elseif strcmp(arg, '-h')
      % hard limit to specify
      [arg, i] = getNextArg(options, i);
      hl = str2num(arg);
      if ~isempty(hl)
        if hl >= 1
          Hard = round(hl);
        else
          fprintf('globalRidge: hard limit should be at least 1\n')
          error('globalRidge: bad value in -h option')
        end
      else
        fprintf('globalRidge: value needed for hard limit\n')
        error('globalRidge: missing value in -h option')
      end
    elseif strcmp(arg, '-t')
      % termination criterion to specify
      [arg, i] = getNextArg(options, i);
      te = str2num(arg);
      if ~isempty(te)
        if te >= 1
          Threshold = round(te);
        elseif te > 0
          Threshold = te;
        else
          fprintf('globalRidge: threshold should be positive\n')
          error('globalRidge: bad value in -t option')
        end
      else
        fprintf('globalRidge: value needed for threshold\n')
        error('globalRidge: missing value in -t option')
      end
    elseif strcmpi(arg, 'uev')
      % use UEV (unbiased estimate of variance)
      Model = 'u';
    elseif strcmpi(arg, 'fpe')
      % use FPE (final prediction error)
      Model = 'f';
    elseif strcmpi(arg, 'gcv')
      % use GCV (generalised cross-validation)
      Model = 'g';
    elseif strcmpi(arg, 'bic')
      % use BIC (Bayesian information criterion)
      Model = 'b';
    else
      % echo the options string and mark the offending token with carets
      fprintf('%s\n', options)
      for k = 1:i-length(arg)-1
        fprintf(' ');
      end
      for k = 1:length(arg)
        fprintf('^');
      end
      fprintf('\n')
      error('globalRidge: unrecognised option')
    end

    % get next argument
    [arg, i] = getNextArg(options, i);

  end

end

if nargin < 3
  l = 0.01;   % default initial guess
end

if ~Standard
  if nargin < 5
    fprintf('globalRidge: specify non-standard penalty matrix\n')
    error('globalRidge: -U option implies fifth argument')
  end
else
  U = 1;
end

% initialise
[p, m] = size(H);
[p, k] = size(Y);
[q1, q2] = size(l);
if q1 == 1
  q = q2;
elseif q2 == 1
  q = q1;
else
  error('globalRidge: list of guesses should be vector, not matrix')
end
[u1, u2] = size(U);
if u1 == m && u2 == m
  % transform the problem - equivalent to U'*U metric (H / U is H * inv(U))
  H = H / U;
elseif u1 ~= 1 || u2 ~= 1
  estr = sprintf('%d-by-%d', m, m);
  error(['globalRidge: U should be 1-by-1 or ' estr])
end
HH = H' * H;
HY = H' * Y;
e = zeros(1, q);
if nargout > 2
  L = zeros(Hard+1, q);
end
if nargout > 3
  E = zeros(Hard+1, q);
end
maxcount = 1;
if Verbose
  fprintf('\nglobalRidge\n')
end
if Flops
  % flops is obsolete and absent from recent MATLAB releases
  flops(0)
end

% loop through each guess
for i = 1:q

  if Verbose
    fprintf('pass ')
    fprintf('  lambda  ')
    if Model == 'u'
      fprintf('   UEV    ')
    elseif Model == 'f'
      fprintf('   FPE    ')
    elseif Model == 'g'
      fprintf('   GCV    ')
    else
      fprintf('   BIC    ')
    end
    fprintf(' change ')
    if Flops
      fprintf('    flops\n')
    else
      fprintf('\n')
    end
  end

  notTooMany = 1;
  notDone = 1;
  count = 0;
  A = inv(HH + l(i) * eye(m));   % A = (H'H + lambda I)^-1
  g = m - l(i) * trace(A);       % effective number of parameters
  PY = Y - H * (A * HY);         % residuals
  YPY = traceProduct(PY', PY);   % sum-squared error, trace(PY'*PY)
  if Model == 'u'
    psi = p / (p - g);
  elseif Model == 'f'
    psi = (p + g) / (p - g);
  elseif Model == 'g'
    psi = p^2 / (p - g)^2;
  else
    psi = (p + (log(p) - 1) * g) / (p - g);
  end
  e(i) = psi * YPY / p;
  if Verbose
    fprintf('%4d %9.3e %9.3e    -    ', count, l(i), e(i))
    if Flops
      fprintf('%9d\n', flops)
    else
      fprintf('\n')
    end
  end
  if nargout > 2
    L(1,i) = l(i);
  end
  if nargout > 3
    E(1,i) = e(i);
  end

  % re-estimate until convergence or exhaustion of iterations
  while notDone && notTooMany

    % next iteration
    count = count + 1;

    % get some needed quantities
    A2 = A^2;
    A3 = A * A2;

    % re-estimate lambda
    if Model == 'u'
      eta = 1 / (2 * (p - g));
    elseif Model == 'f'
      eta = p / ((p - g) * (p + g));
    elseif Model == 'g'
      eta = 1 / (p - g);
    else
      eta = p * log(p) / (2 * (p - g) * (p + (log(p) - 1) * g));
    end
    nl = eta * YPY * trace(A - l(i) * A2) / trace(HY' * A3 * HY);

    % store result
    if nargout > 2
      L(count+1,i) = nl;
    end

    % calculate new model selection score
    A = inv(HH + nl * eye(m));
    g = m - nl * trace(A);
    PY = Y - H * (A * HY);
    YPY = traceProduct(PY', PY);
    if Model == 'u'
      psi = p / (p - g);
    elseif Model == 'f'
      psi = (p + g) / (p - g);
    elseif Model == 'g'
      psi = p^2 / (p - g)^2;
    else
      psi = (p + (log(p) - 1) * g) / (p - g);
    end
    ns = psi * YPY / p;

    % store result
    if nargout > 3
      E(count+1,i) = ns;
    end

    % what's the change
    if Threshold >= 1
      % interpret threshold as one part in many
      change = round(abs(e(i) / (e(i) - ns)));
    else
      % interpret threshold as absolute difference
      change = abs(e(i) - ns);
    end

    % time to go home?
    if count >= Hard
      notTooMany = 0;
    elseif Threshold >= 1
      % interpret threshold as one part in many
      if change > Threshold
        notDone = 0;
      end
    else
      % interpret threshold as absolute difference
      if change < Threshold
        notDone = 0;
      end
    end

    % get ready for next iteration (or end)
    l(i) = nl;
    e(i) = ns;
    if Verbose
      fprintf('%4d %9.3e %9.3e ', count, l(i), e(i))
      if Threshold >= 1
        fprintf('%7d ', change)
      else
        fprintf('%7.1e ', change)
      end
      if Flops
        fprintf('%9d\n', flops)
      else
        fprintf('\n')
      end
    end

  end

  if Verbose
    if ~notTooMany
      fprintf('hard limit reached\n')
    else
      if Threshold >= 1
        fprintf('relative ')
      else
        fprintf('absolute ')
      end
      fprintf('threshold in ')
      if Model == 'u'
        fprintf('UEV ')
      elseif Model == 'f'
        fprintf('FPE ')
      elseif Model == 'g'
        fprintf('GCV ')
      else
        fprintf('BIC ')
      end
      fprintf('crossed\n')
    end
  end

  if count > maxcount
    maxcount = count;
  end

end

% truncate L and E
if nargout > 2
  L = L(1:maxcount+1,:);
end
if nargout > 3
  E = E(1:maxcount+1,:);
end
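As a hedged illustration (not part of the original file), the script below evaluates, for a single guess at lambda, the same quantities that globalRidge computes inside its loop: the effective number of parameters g, the sum-squared error, the four model selection scores, and one GCV re-estimation step. To stay self-contained it skips option parsing and uses `trace(PY' * PY)` instead of the `traceProduct` helper. The cubic polynomial design `H`, the target `Y`, and the starting value `lambda = 0.01` are made-up example data.

```matlab
% Hypothetical example data: cubic polynomial design and a smooth target.
x = linspace(0, 1, 50)';
H = [ones(size(x)), x, x.^2, x.^3];   % design matrix (p-by-m)
Y = sin(2*pi*x) + 0.1*cos(37*x);      % training outputs (p-by-1)
[p, m] = size(H);
lambda = 0.01;                        % a single guess at lambda

HH = H' * H;
HY = H' * Y;
A = inv(HH + lambda * eye(m));        % A = (H'H + lambda I)^-1
g = m - lambda * trace(A);            % effective number of parameters
PY = Y - H * (A * HY);                % residuals
YPY = trace(PY' * PY);                % sum-squared error

% Penalty factors psi, as in globalRidge; each score is psi * YPY / p.
psiUEV = p / (p - g);
psiFPE = (p + g) / (p - g);
psiGCV = p^2 / (p - g)^2;
psiBIC = (p + (log(p) - 1) * g) / (p - g);
scores = [psiUEV, psiFPE, psiGCV, psiBIC] * YPY / p;

% One GCV re-estimation step (eta = 1/(p - g) for GCV).
eta = 1 / (p - g);
A2 = A^2;
newLambda = eta * YPY * trace(A - lambda * A2) / trace(HY' * A * A2 * HY);

fprintf('g = %.3f\n', g);
fprintf('UEV %.4g  FPE %.4g  GCV %.4g  BIC %.4g\n', scores);
fprintf('lambda: %.3g -> %.3g\n', lambda, newLambda);
```

Because g > 0, the penalty factors always order the scores as UEV < FPE < GCV. BIC exceeds FPE whenever ln p > 2, that is, p > e^2 (about 7.4). The while loop in globalRidge repeats this step, feeding each new lambda back in. With the default options it stops once the score changes by less than one part in 1000, or after 100 iterations.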