📄 lwpr.m
字号:
rf = template_rf;
else
  rf.D     = lwprs(ID).init_D;
  rf.M     = lwprs(ID).init_M;
  rf.alpha = lwprs(ID).init_alpha;
  rf.b0    = y;                     % the weighted mean of output
end

% if more than univariate input, start with two projections such that
% we can compare the reduction of residual error between two projections
n_in  = lwprs(ID).n_in;
n_out = lwprs(ID).n_out;
if (n_in > 1)
  n_reg = 2;
else
  n_reg = 1;
end

rf.B           = zeros(n_reg,lwprs(ID).n_out);         % the regression parameters
rf.c           = c;                                    % the center of the RF
rf.SXresYres   = zeros(n_reg,n_in);                    % needed to compute projections
rf.ss2         = ones(n_reg,1)/lwprs(ID).init_P;       % variance per projection
rf.SSYres      = zeros(n_reg,n_out);                   % needed to compute linear model
rf.SSXres      = zeros(n_reg,n_in);                    % needed to compute input reduction
rf.W           = eye(n_reg,n_in);                      % matrix of projections vectors
rf.Wnorm       = zeros(n_reg,1);                       % normalized projection vectors
rf.U           = eye(n_reg,n_in);                      % reduction of input space
rf.H           = zeros(n_reg,n_out);                   % trace matrix
rf.r           = zeros(n_reg,1);                       % trace vector
rf.h           = zeros(size(rf.alpha));                % a memory term for 2nd order gradients
rf.b           = log(rf.alpha+1.e-10);                 % a memory term for 2nd order gradients
rf.sum_w       = ones(n_reg,1)*1.e-10;                 % the sum of weights
rf.sum_e_cv2   = zeros(n_reg,1);                       % weighted sum of cross.valid. err. per dim
rf.sum_e2      = 0;                                    % weighted sum of error (not CV)
rf.n_data      = ones(n_reg,1)*1.e-10;                 % discounted amount of data in RF
rf.trustworthy = 0;                                    % indicates statistical confidence
rf.lambda      = ones(n_reg,1)*lwprs(ID).init_lambda;  % forgetting rate
rf.mean_x      = zeros(n_in,1);                        % the weighted mean of inputs
rf.var_x       = zeros(n_in,1);                        % the weighted variance of inputs
rf.w           = 0;                                    % store the last computed weight
rf.s           = zeros(n_reg,1);                       % store the projection of inputs

%-----------------------------------------------------------------------------
function w=compute_weight(diag_only,kernel,c,D,x)
% compute the weight of input x for a receptive field
%
%   diag_only - if true, only the diagonal of D is used
%   kernel    - 'Gaussian' or 'BiSquare'
%   c         - center that is subtracted from x
%   D         - distance metric matrix
%   x         - input (column vector)
%
%   w         - resulting weight; for any other kernel name w is never
%               assigned and the call fails

% subtract the center
x = x-c;

% squared distance x'*D*x, using only the diagonal of D if requested
if diag_only
  d2 = x'*(diag(D).*x);
else
  d2 = x'*D*x;
end

switch kernel
  case 'Gaussian'
    w = exp(-0.5*d2);
  case 'BiSquare'
    % the kernel is cut off to zero once 0.5*d2 exceeds 1
    if (0.5*d2 > 1)
      w = 0;
    else
      w = (1-0.5*d2)^2;
    end
end

%-----------------------------------------------------------------------------
function [rf,xmz,ymz]=update_means(rf,x,y,w)
% update means and compute mean zero variables
%
%   rf  - receptive field; mean_x, var_x and b0 are updated
%   x,y - current input and output
%   w   - weight of the current data point
%
%   xmz - x minus the updated input mean
%   ymz - y minus the updated output mean
%
% only the first element of rf.sum_w and rf.lambda enters the update

rf.mean_x = (rf.sum_w(1)*rf.mean_x*rf.lambda(1) + w*x)/(rf.sum_w(1)*rf.lambda(1)+w);
% note: the variance uses the mean that was updated on the line above
rf.var_x  = (rf.sum_w(1)*rf.var_x*rf.lambda(1) + w*(x-rf.mean_x).^2)/(rf.sum_w(1)*rf.lambda(1)+w);
rf.b0     = (rf.sum_w(1)*rf.b0*rf.lambda(1) + w*y)/(rf.sum_w(1)*rf.lambda(1)+w);
xmz       = x - rf.mean_x;
ymz       = y - rf.b0;

%-----------------------------------------------------------------------------
function [rf,yp,e_cv,e] = update_regression(rf,x,y,w)
% update the linear regression parameters
%
%   rf   - receptive field; projections and sufficient statistics are updated
%   x,y  - input and output passed on to the projection and the regression
%   w    - weight of the current data point
%
%   yp   - prediction after the update, including the offset rf.b0
%   e_cv - residuals at every projection stage, computed before the update
%   e    - residual after the update (y minus the prediction without rf.b0)

[n_reg,n_in] = size(rf.W);
n_out        = length(y);

% compute the projection
[rf.s,xres] = compute_projection(x,rf.W,rf.U);

% compute all residual errors and targets at all projection stages
yres = rf.B .* (rf.s*ones(1,n_out));
for i=2:n_reg
  yres(i,:) = yres(i,:) + yres(i-1,:);
end
yres    = ones(n_reg,1)*y' - yres;
e_cv    = yres;
ytarget = [y';yres(1:n_reg-1,:)];

% update the projections (these use a slower forgetting rate)
lambda_slow  = 1-(1-rf.lambda)/10;
rf.SXresYres = rf.SXresYres .* (lambda_slow*ones(1,n_in)) + w * (sum(ytarget,2)*ones(1,n_in)).*xres;
rf.Wnorm     = sqrt(sum(rf.SXresYres.^2,2));
rf.W         = rf.SXresYres ./ (rf.Wnorm * ones(1,n_in));

% update sufficient statistics for regressions
rf.ss2    = rf.lambda .* rf.ss2 + rf.s.^2 * w;
rf.SSYres = (rf.lambda*ones(1,n_out)) .* rf.SSYres + w * ytarget .* ...
            (rf.s*ones(1,n_out));
rf.SSXres = (rf.lambda*ones(1,n_in)) .* rf.SSXres + w * (rf.s*ones(1,n_in)).* xres;

% update the regression and input reduction parameters
rf.B = rf.SSYres ./ (rf.ss2*ones(1,n_out));
rf.U = rf.SSXres ./ (rf.ss2*ones(1,n_in));

% the new predicted output after updating
[rf.s,xres] = compute_projection(x,rf.W,rf.U);
yp = rf.B' * rf.s;
e  = y  - yp;
yp = yp + rf.b0;

% is the RF trustworthy: a simple data count
% (rf.n_data is a vector, so the condition holds only if all entries pass)
if (rf.n_data > n_in*2)
  rf.trustworthy = 1;
end

%-----------------------------------------------------------------------------
function [rf,transient_multiplier] = update_distance_metric(ID,rf,x,y,w,e_cv,e,xn)
% update the distance metric of a receptive field by gradient descent
%
%   ID   - index into the global lwprs array
%   rf   - receptive field; M, D, alpha, h, b, H, r and the error sums are updated
%   x    - input, only its length is used (to normalize the penalty)
%   y    - output, only its length is used
%   w    - weight of the current data point
%   e_cv - residuals at every projection stage
%   e    - residual of the current prediction
%   xn   - input from which rf.c is subtracted for the derivatives
%
%   transient_multiplier - scaling applied to all updates; 0 if the first
%                          projection has not seen enough data and nothing
%                          was updated

global lwprs;

penalty   = lwprs(ID).penalty/length(x); % normalize penalty w.r.t. number of inputs
meta      = lwprs(ID).meta;
meta_rate = lwprs(ID).meta_rate;
kernel    = lwprs(ID).kernel;
diag_only = lwprs(ID).diag_only;

% an indicator vector in how far individual projections are trustworthy
% based on how much data the projection has been trained on
derivative_ok = rf.n_data > 0.1./(1.-rf.lambda);
if ~derivative_ok(1)
  transient_multiplier = 0;
  return;
end

% useful pre-computations: they need to come before the updates
s            = rf.s;
e_cv2        = sum(e_cv.^2,2);
e2           = e'*e;
rf.sum_e_cv2 = rf.sum_e_cv2.*rf.lambda + w*e_cv2;
rf.sum_e2    = rf.sum_e2*rf.lambda(1) + w*e2;
e_cv         = e_cv(end,:)';   % from here on only the last projection stage is used
e_cv2        = e_cv'*e_cv;
h            = w*sum(s.^2./rf.ss2.*derivative_ok);
W            = rf.sum_w(1)*rf.lambda(1) + w;
E            = rf.sum_e_cv2(end);
transient_multiplier = (rf.sum_e2/(rf.sum_e_cv2(end)+1.e-10))^4; % this is a numerical safety heuristic
n_out        = length(y);

% the derivative dJ1/dw
Ps    = s./rf.ss2.*derivative_ok;  % zero the terms with insufficient data support
Pse   = Ps*e';
dJ1dw = -E/W^2 + 1/W*(e_cv2 - sum(sum((2*Pse).*rf.H)) - sum((2*Ps.^2).*rf.r));

% the derivatives dw/dM and dJ2/dM
[dwdM,dJ2dM,dwwdMdM,dJ2J2dMdM] = dist_derivatives(w,rf,xn-rf.c,diag_only,kernel,penalty,meta);

% the final derivative becomes (note this is upper triangular)
dJdM = dwdM*dJ1dw/n_out + w/W*dJ2dM;

% the second derivative if meta learning is required, and meta learning update
if (meta)

  % second derivatives
  dJ1J1dwdw = -e_cv2/W^2 - 2/W*sum(sum((-Pse/W -2*Ps*(s'*Pse)).*rf.H)) + 2/W*e2*h/w - ...
              1/W^2*(e_cv2-2*sum(sum(Pse.*rf.H))) + E/W^3;
  dJJdMdM   = (dwwdMdM*dJ1dw + dwdM.^2*dJ1J1dwdw)/n_out + w/W*dJ2J2dMdM;

  % update the learning rates
  aux = meta_rate * transient_multiplier * (dJdM.*rf.h);

  % limit the update rate
  ind = find(abs(aux) > 0.1);
  if (~isempty(ind))
    aux(ind) = 0.1*sign(aux(ind));
  end
  rf.b = rf.b - aux;

  % prevent numerical overflow
  ind = find(abs(rf.b) > 10);
  if (~isempty(ind))
    rf.b(ind) = 10*sign(rf.b(ind));
  end
  rf.alpha = exp(rf.b);

  aux = 1 - (rf.alpha.*dJJdMdM) * transient_multiplier ;
  ind = find(aux < 0);
  if (~isempty(ind))
    aux(ind) = 0;
  end
  rf.h = rf.h.*aux - (rf.alpha.*dJdM) * transient_multiplier;

end

% update the distance metric, use some caution for too large gradients
maxM    = max(max(abs(rf.M)));
delta_M = rf.alpha.*dJdM*transient_multiplier;
ind     = find(delta_M > 0.1*maxM);
if (~isempty(ind))
  rf.alpha(ind) = rf.alpha(ind)/2;
  delta_M(ind)  = 0;
  disp(sprintf('Reduced learning rate'));
end
% apply the limited step: entries flagged above were zeroed in delta_M and
% must not be updated (recomputing the step from rf.alpha would still apply
% them, only with the halved learning rate)
rf.M = rf.M - delta_M;
rf.D = rf.M'*rf.M;

% update sufficient statistics: note this must come after the updates and
% is conditioned on that sufficient samples contributed to the derivative
H    = (rf.lambda*ones(1,n_out)).*rf.H + (w/(1-h))*s*e_cv'*transient_multiplier;
r    = rf.lambda.*rf.r + (w^2*e_cv2/(1-h))*(s.^2)*transient_multiplier;
rf.H = (derivative_ok*ones(1,n_out)).*H + (1-(derivative_ok*ones(1,n_out))).*rf.H;
rf.r = derivative_ok.*r + (1-derivative_ok).*rf.r;

%-----------------------------------------------------------------------------
function [dwdM,dJ2dM,dwwdMdM,dJ2J2dMdM] = dist_derivatives(w,rf,dx,diag_only,kernel,penalty,meta)
% compute derivatives of distance metric: note that these will be upper
% triangular matrices for efficiency
%
%   w         - weight of the current data point
%   rf        - receptive field; rf.M and rf.D are read
%   dx        - input with the center already subtracted
%   diag_only - if true, only the diagonal elements of M are differentiated
%   kernel    - 'Gaussian' or 'BiSquare'
%   penalty   - factor of the penalty term
%   meta      - if true, the second derivatives are computed as well
%
%   dwdM, dJ2dM        - first derivatives (n_in x n_in)
%   dwwdMdM, dJ2J2dMdM - second derivatives, mirrored to both triangles;
%                        they stay zero if meta is false

n_in      = length(dx);
dwdM      = zeros(n_in);
dJ2dM     = zeros(n_in);
dJ2J2dMdM = zeros(n_in);
dwwdMdM   = zeros(n_in);

for n=1:n_in
  for m=n:n_in

    sum_aux  = 0;
    sum_aux1 = 0;

    % take the derivative of D with respect to nm_th element of M

    if (diag_only & n==m)

      aux       = 2*rf.M(n,n);
      dwdM(n,n) = dx(n)^2 * aux;
      sum_aux   = rf.D(n,n)*aux;
      if (meta)
        sum_aux1 = sum_aux1 + aux^2;
      end

    elseif (~diag_only)

      for i=n:n_in

        % aux corresponds to the in_th (= ni_th) element of dDdm_nm
        % this is directly processed for dwdM and dJ2dM

        if (i == m)
          aux       = 2*rf.M(n,i);
          dwdM(n,m) = dwdM(n,m) + dx(i) * dx(m) * aux;
          sum_aux   = sum_aux + rf.D(i,m)*aux;
          if (meta)
            sum_aux1 = sum_aux1 + aux^2;
          end
        else
          aux       = rf.M(n,i);
          dwdM(n,m) = dwdM(n,m) + 2. * dx(i) * dx(m) * aux;
          sum_aux   = sum_aux + 2.*rf.D(i,m)*aux;
          if (meta)
            sum_aux1 = sum_aux1 + 2*aux^2;
          end
        end

      end

    end

    switch kernel
      case 'Gaussian'
        dwdM(n,m) = -0.5*w*dwdM(n,m);
      case 'BiSquare'
        dwdM(n,m) = -sqrt(w)*dwdM(n,m);
    end

    dJ2dM(n,m) = 2.*penalty*sum_aux;

    if (meta)
      dJ2J2dMdM(n,m) = 2.*penalty*(2*rf.D(m,m) + sum_aux1);
      dJ2J2dMdM(m,n) = dJ2J2dMdM(n,m);
      switch kernel
        case 'Gaussian'
          dwwdMdM(n,m) = dwdM(n,m)^2/w - w*dx(m)^2;
        case 'BiSquare'
          dwwdMdM(n,m) = dwdM(n,m)^2/w/2 - 2*sqrt(w)*dx(m)^2;
      end
      dwwdMdM(m,n) = dwwdMdM(n,m);
    end

  end
end

%-----------------------------------------------------------------------------
function [s,xres] = compute_projection(x,W,U)
% recursively compute the projected input
%
%   x    - input (column vector)
%   W    - projection directions, one per row
%   U    - input reduction vectors, one per row
%
%   s    - projection of the residual input onto each row of W
%   xres - residual input before each projection, one row per projection

[n_reg,n_in] = size(W);

s    = zeros(n_reg,1);
xres = zeros(n_reg,n_in);   % preallocated instead of growing inside the loop

for i=1:n_reg
  xres(i,:) = x';
  s(i)      = W(i,:)*x;
  x         = x - U(i,:)'*s(i);   % remove the part explained by this projection
end

%-----------------------------------------------------------------------------
function [rf] = check_add_projection(ID,rf)
% checks whether a new projection needs to be added to the rf
%
%   ID - index into the global lwprs array
%   rf - receptive field; if a projection is added, every per-projection
%        field is extended by one row

global lwprs;

[n_reg,n_in] = size(rf.W);
n_out        = size(rf.B,2);   % rf.B has one column per output

% never use more projections than input dimensions
if (n_reg >= n_in)
  return;
end

% here, the mean squared error of the current regression dimension
% is compared against the previous one. Only if there is a significant
% improvement in MSE, another dimension gets added. Some additional
% heuristics had to be added to ensure that the MSE decision is
% based on sufficient data

mse_n_reg   = rf.sum_e_cv2(n_reg)  / rf.sum_w(n_reg) + 1.e-10;
mse_n_reg_1 = rf.sum_e_cv2(n_reg-1)/ rf.sum_w(n_reg-1) + 1.e-10;

if (mse_n_reg/mse_n_reg_1 < lwprs(ID).add_threshold & ...
    rf.n_data(n_reg)/rf.n_data(1) > 0.99 & ...
    rf.n_data(n_reg)*(1.-rf.lambda(n_reg)) > 0.5)

  sprintf('add a dimension');   % output is suppressed by the semicolon

  rf.B         = [rf.B; zeros(1,n_out)];
  rf.SXresYres = [rf.SXresYres; zeros(1,n_in)];
  rf.ss2       = [rf.ss2;1/lwprs(ID).init_P];
  rf.SSYres    = [rf.SSYres; zeros(1,n_out)];
  rf.SSXres    = [rf.SSXres; zeros(1,n_in)];
  rf.W         = [rf.W; zeros(1,n_in)];
  rf.W(n_reg+1,n_reg+1) = 1;    % new projection starts along the next input axis
  rf.Wnorm     = [rf.Wnorm; 0];
  rf.U         = [rf.U; zeros(1,n_in)];
  rf.U(n_reg+1,n_reg+1) = 1;
  rf.H         = [rf.H; zeros(1,n_out)];
  rf.r         = [rf.r; 0];
  rf.sum_w     = [rf.sum_w; 1.e-10];
  rf.sum_e_cv2 = [rf.sum_e_cv2; 0];
  rf.n_data    = [rf.n_data; 0];
  rf.lambda    = [rf.lambda; lwprs(ID).init_lambda];
  rf.s         = [rf.s; 0];

end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -