lwpr.m
function [varargout] = lwpr(action,varargin)
% lwpr implements the LWPR algorithm as suggested in
% Vijayakumar, S. & Schaal, S. (2003). Incremental Online Learning
% in High Dimensions. submitted.
%
% Depending on the keyword in the input argument "action", a certain
% number of input arguments will be parsed from "varargin". A variable
% number of arguments is returned according to the "action".
% See the Matlab file for explanations of how to use the different
% modalities of the program.
%
% Note: this implementation does not include ridge regression. Newer
% algorithms like LWPR and LWPPLS are much more suitable for
% high-dimensional data sets and data with ill-conditioned
% regression matrices.
%
% Copyright Sethu Vijayakumar and Stefan Schaal, September 2002

% --------------- Different Actions of the program ------------------------

% Initialize an LWPR model:
%
% FORMAT lwpr('Init',ID, n_in, n_out, diag_only, meta, meta_rate, ...
%             penalty, init_alpha, norm, norm_out, name)
% ID         : desired ID of model
% n_in       : number of input dimensions
% n_out      : number of output dimensions
% diag_only  : 1/0 to update only the diagonal distance metric
% meta       : 1/0 to allow the use of a meta learning parameter
% meta_rate  : the meta learning rate
% penalty    : a smoothness bias, usually a pretty small number (1.e-4)
% init_alpha : the initial learning rates
% norm       : the normalization of the inputs
% norm_out   : the normalization of the outputs
% name       : a name for the model
%
% alternatively, the function can be called as
%
% FORMAT lwpr('Init',ID,lwpr)
% lwpr       : a complete data structure of an LWPR model
%
% returns nothing

% Change a parameter of an LWPR model:
%
% FORMAT lwpr('Change',ID,pname,value)
% ID         : ID of the model
% pname      : name of parameter to be changed
% value      : new parameter value
%
% returns nothing

% Update an LWPR model with new data:
%
% FORMAT [yp,w] = lwpr('Update',ID,x,y)
% ID         : ID of the model
% x          : input data point
% y          : output data point
%
% Note: the following inputs are optional in order to use LWPR
%       in adaptive learning control with composite update laws
% e          : the tracking error of the control system
% alpha      : a strictly positive scalar to determine the
%              magnitude of the contribution to the update
%
% returns the prediction after the update, yp, and the weight of the
% maximally activated receptive field

% Predict an output for an LWPR model:
%
% FORMAT [yp,w] = lwpr('Predict',ID,x,cutoff)
% ID         : ID of the model
% x          : input data point
% cutoff     : minimal activation for prediction
%
% returns the prediction yp and the weight of the maximally activated
% receptive field

% Return the data structure of an LWPR model:
%
% FORMAT [lwpr] = lwpr('Structure',ID)
% ID         : ID of the model
%
% returns the complete data structure of an LWPR model, e.g., for saving or
% inspecting it

% Clear the data structure of an LWPR model:
%
% FORMAT lwpr('Clear',ID)
% ID         : ID of the model
%
% returns nothing
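% --------------- Example usage (illustrative) -----------------------------
%
% A minimal sketch of a typical call sequence. All numerical values
% (ID=1, init_alpha=250, penalty=1.e-4, the toy training function, etc.)
% are placeholder assumptions for illustration, not recommended defaults:
%
%   ID = 1;
%   lwpr('Init',ID,2,1,1,0,250,1.e-4,250,ones(2,1),1,'example_model');
%   lwpr('Change',ID,'init_D',eye(2)*25);          % optionally adjust a parameter
%   for i=1:1000,                                  % incremental training
%     x = rand(2,1)*2-1;
%     y = max(exp(-10*x(1)^2),exp(-50*x(2)^2)) + randn*0.01;
%     [yp,w] = lwpr('Update',ID,x,y);
%   end
%   [yp,w] = lwpr('Predict',ID,[0.5;0.5],0.001);   % predict with cutoff 0.001
%   model  = lwpr('Structure',ID);                 % inspect or save the model
%   lwpr('Clear',ID);                              % remove the model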
% the structure storing all LWPR models
global lwprs;

if nargin < 2,
  error('Incorrect call to lwpr');
end

switch action,

%..............................................................................
% Initialize a new LWPR model
case 'Init'

  % check whether a complete model was given or data for a new model
  if (nargin == 3)
    ID        = varargin{1};
    lwprs(ID) = varargin{2};
  else
    % copy from input arguments
    ID                      = varargin{1};
    lwprs(ID).n_in          = varargin{2};
    lwprs(ID).n_out         = varargin{3};
    lwprs(ID).diag_only     = varargin{4};
    lwprs(ID).meta          = varargin{5};
    lwprs(ID).meta_rate     = varargin{6};
    lwprs(ID).penalty       = varargin{7};
    lwprs(ID).init_alpha    = varargin{8};
    lwprs(ID).norm          = varargin{9};
    lwprs(ID).norm_out      = varargin{10};
    lwprs(ID).name          = varargin{11};

    % add additional convenient variables
    lwprs(ID).n_data        = 0;
    lwprs(ID).w_gen         = 0.1;
    lwprs(ID).w_prune       = 0.9;
    lwprs(ID).init_lambda   = 0.999;
    lwprs(ID).final_lambda  = 0.9999;
    lwprs(ID).tau_lambda    = 0.99999;
    lwprs(ID).init_P        = 1.e+10;
    lwprs(ID).n_pruned      = 0;
    lwprs(ID).add_threshold = 0.5;

    % other variables
    lwprs(ID).init_D        = eye(lwprs(ID).n_in)*25;
    lwprs(ID).init_M        = chol(lwprs(ID).init_D);
    lwprs(ID).init_alpha    = ones(lwprs(ID).n_in)*lwprs(ID).init_alpha;
    lwprs(ID).mean_x        = zeros(lwprs(ID).n_in,1);
    lwprs(ID).var_x         = zeros(lwprs(ID).n_in,1);
    lwprs(ID).rfs           = [];
    lwprs(ID).kernel        = 'Gaussian';
  end

%..............................................................................
case 'Change'

  ID = varargin{1};
  command = sprintf('lwprs(%d).%s = varargin{3};',ID,varargin{2});
  eval(command);
  % make sure some initializations remain correct
  lwprs(ID).init_M = chol(lwprs(ID).init_D);

%..............................................................................
case 'Update'

  ID = varargin{1};
  x  = varargin{2};
  y  = varargin{3};
  if (nargin > 4)
    composite_control = 1;
    e     = varargin{4};
    alpha = varargin{5};
  else
    composite_control = 0;
  end

  % update the global mean and variance of the training data for
  % information purposes
  lwprs(ID).mean_x = (lwprs(ID).mean_x*lwprs(ID).n_data + x)/(lwprs(ID).n_data+1);
  lwprs(ID).var_x  = (lwprs(ID).var_x*lwprs(ID).n_data + (x-lwprs(ID).mean_x).^2)/(lwprs(ID).n_data+1);
  lwprs(ID).n_data = lwprs(ID).n_data+1;

  % normalize the inputs
  xn = x./lwprs(ID).norm;

  % normalize the outputs
  yn = y./lwprs(ID).norm_out;

  % check all RFs for updating
  % wv is a vector of 3 weights, ordered [w; sec_w; max_w]
  % iv is the corresponding vector containing the RF indices
  wv    = zeros(3,1);
  iv    = zeros(3,1);
  yp    = zeros(size(y));
  sum_w = 0;
  tms   = zeros(length(lwprs(ID).rfs),1);

  for i=1:length(lwprs(ID).rfs),

    % compute the weight and keep the three largest weights sorted
    w = compute_weight(lwprs(ID).diag_only,lwprs(ID).kernel,lwprs(ID).rfs(i).c,lwprs(ID).rfs(i).D,xn);
    lwprs(ID).rfs(i).w = w;
    wv(1) = w;
    iv(1) = i;
    [wv,ind] = sort(wv);
    iv = iv(ind);

    % only update if activation is high enough
    if (w > 0.001),

      rf = lwprs(ID).rfs(i);

      % update weighted mean for xn and y, and create mean-zero variables
      [rf,xmz,ymz] = update_means(lwprs(ID).rfs(i),xn,yn,w);

      % update the regression
      [rf,yp_i,e_cv,e] = update_regression(rf,xmz,ymz,w);
      if (rf.trustworthy),
        yp    = w*yp_i + yp;
        sum_w = sum_w + w;
      end

      % update the distance metric
      [rf,tm] = update_distance_metric(ID,rf,xmz,ymz,w,e_cv,e,xn);
      tms(i)  = 1;

      % check whether a projection needs to be added
      check_add_projection(ID,rf);

      % update simple statistical variables
      rf.sum_w  = rf.sum_w.*rf.lambda + w;
      rf.n_data = rf.n_data.*rf.lambda + 1;
      rf.lambda = lwprs(ID).tau_lambda * rf.lambda + lwprs(ID).final_lambda*(1.-lwprs(ID).tau_lambda);

      % incorporate updates
      lwprs(ID).rfs(i) = rf;

    else

      lwprs(ID).rfs(i).w = 0;

    end % if (w > 0.001)

  end
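  % The composite-control branch below only runs when 'Update' was called
  % with the two optional arguments documented in the header, e.g.
  % (illustrative call; e_track and alpha_c are placeholder names for the
  % tracking error and the strictly positive update gain):
  %
  %   [yp,w] = lwpr('Update',ID,x,y,e_track,alpha_c);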
  % if LWPR is used for control, incorporate the tracking error
  if (composite_control),
    inds = find(tms > 0);
    if ~isempty(inds),
      for j=1:length(inds),
        i = inds(j);
        lwprs(ID).rfs(i).B  = lwprs(ID).rfs(i).B + alpha * tms(j) ./ lwprs(ID).rfs(i).ss2 * ...
            lwprs(ID).rfs(i).w/sum_w .* (xn-lwprs(ID).rfs(i).c) * e;
        lwprs(ID).rfs(i).b0 = lwprs(ID).rfs(i).b0 + alpha * tms(j) / lwprs(ID).rfs(i).sum_w(1) * ...
            lwprs(ID).rfs(i).w/sum_w * e;
      end
    end
  end

  % do we need to add a new RF?
  if (wv(3) <= lwprs(ID).w_gen),
    if (wv(3) > 0.1*lwprs(ID).w_gen & lwprs(ID).rfs(iv(3)).trustworthy),
      lwprs(ID).rfs(length(lwprs(ID).rfs)+1) = init_rf(ID,lwprs(ID).rfs(iv(3)),xn,yn);
    else
      if (length(lwprs(ID).rfs)==0),
        lwprs(ID).rfs = init_rf(ID,[],xn,yn);
      else
        lwprs(ID).rfs(length(lwprs(ID).rfs)+1) = init_rf(ID,[],xn,yn);
      end
    end
  end

  % do we need to prune an RF? Prune the one with the larger D
  % (i.e., the narrower receptive field)
  if (wv(2:3) > lwprs(ID).w_prune),
    if (sum(sum(lwprs(ID).rfs(iv(2)).D)) > sum(sum(lwprs(ID).rfs(iv(3)).D)))
      lwprs(ID).rfs(iv(2)) = [];
      disp(sprintf('%d: Pruned #RF=%d',ID,iv(2)));
    else
      lwprs(ID).rfs(iv(3)) = [];
      disp(sprintf('%d: Pruned #RF=%d',ID,iv(3)));
    end
    lwprs(ID).n_pruned = lwprs(ID).n_pruned + 1;
  end

  % the final prediction
  if (sum_w > 0),
    yp = yp.*lwprs(ID).norm_out/sum_w;
  end

  varargout(1) = {yp};
  varargout(2) = {wv(3)};

%..............................................................................
case 'Predict'

  ID     = varargin{1};
  x      = varargin{2};
  cutoff = varargin{3};

  % normalize the inputs
  xn = x./lwprs(ID).norm;

  % maintain the maximal activation
  max_w = 0;
  yp    = zeros(lwprs(ID).n_out,1);
  sum_w = 0;

  for i=1:length(lwprs(ID).rfs),

    % compute the weight
    w = compute_weight(lwprs(ID).diag_only,lwprs(ID).kernel,lwprs(ID).rfs(i).c,lwprs(ID).rfs(i).D,xn);
    lwprs(ID).rfs(i).w = w;
    max_w = max([max_w,w]);

    % only predict if activation is high enough
    if (w > cutoff & lwprs(ID).rfs(i).trustworthy),
      % the mean-zero input
      xmz = xn - lwprs(ID).rfs(i).mean_x;
      % compute the projected inputs
      s = compute_projection(xmz,lwprs(ID).rfs(i).W,lwprs(ID).rfs(i).U);
      % the prediction
      yp = yp + (lwprs(ID).rfs(i).B'*s + lwprs(ID).rfs(i).b0) * w;
      sum_w = sum_w + w;
    end % if (w > cutoff)

  end

  % the final prediction
  if (sum_w > 0),
    yp = yp.*lwprs(ID).norm_out/sum_w;
  end

  varargout(1) = {yp};
  varargout(2) = {max_w};

%..............................................................................
case 'Structure'

  ID = varargin{1};
  varargout(1) = {lwprs(ID)};

%..............................................................................
case 'Clear'

  ID = varargin{1};
  lwprs(ID) = [];

end

%-----------------------------------------------------------------------------
function rf=init_rf(ID,template_rf,c,y)
% initialize a local model

global lwprs;

if ~isempty(template_rf),