📄 updatewn.m
字号:
function [t,d,w,c,b, hess_t, hess_d, dt, dd] ...
    = updatewn(x,y,t,d,w,c,b,lr, hess_t, hess_d, dt, dd)
%UPDATEWN  Update the wavelet net by one step according to the data x, y.
%
%   [t,d,w,c,b] = updatewn(x,y,t,d,w,c,b,lr)
%   [t,d,w,c,b,hess_t,hess_d,dt,dd] = updatewn(x,y,t,d,w,c,b,lr)
%   [t,d,w,c,b,hess_t,hess_d,dt,dd] = ...
%       updatewn(x,y,t,d,w,c,b,lr,hess_t,hess_d,dt,dd)
%
%   The nonlinear parameters t and d are each moved by one damped Newton
%   step (t and d are solved separately), then w, c and b are re-fitted by
%   linear least squares on the updated wavelets.
%
%   Inputs
%     x, y    columns of data: x is n x K, y is 1 x K.
%     t, d    current wavelet parameters, both n x N (presumably
%             translations and dilations - TODO confirm).
%     w       1 x N coefficients of the N wavelet outputs.
%     c       1 x n coefficients of the rows of x (linear term).
%     b       scalar bias (coefficient of a constant row of ones).
%     lr      damping; lr*eye(.) is added to both Hessians before solving.
%             Must be 0 on the first call of an epoch (8 input arguments).
%     hess_t, hess_d, dt, dd
%             optional (all four or none): Hessian approximations and
%             gradients from a previous call. When given, they are used
%             as-is instead of being recomputed from x and y.
%
%   Outputs
%     t, d, w, c, b
%             updated parameters. On failure b is returned as [] and
%             t, d, w, c are returned unchanged. Failure means rcond of
%             either damped Hessian is below 1e-7, or the Newton step is
%             judged too large (see the step test in the code).
%     hess_t, hess_d  (n*N x n*N) and  dt, dd  (n*N x 1)
%             Hessian approximations and gradients, stored column-major
%             ("long" format). NOTE: the returned Hessians INCLUDE the
%             lr*eye(.) term added in this call.
%
%   By Qinghua Zhang. April 10, 1992.

if (nargin ~= 8) & (nargin ~= 12)
  error('Wrong number of input arguments.');
end
if (nargout ~= 5) & (nargout ~= 9)
  error('Wrong number of output arguments.');
end

[xl, xc] = size(x);
[yl, yc] = size(y);
[tl, tc] = size(t);
[dl, d_c] = size(d);
[wl, wc] = size(w);
[cl, cc] = size(c);

if (xl ~= tl) | (tl ~= dl)
  error('x, t and d should have the same number of lines.');
end
if (tc ~= d_c) | (d_c ~= wc)
  error('t d and w should have the same number of columns.');
end
if wl ~= 1
  error('w should have only one line.');
end
if xc ~= yc
  error('x and y should have the same number of columns.');
end
if (xl ~= cc) | (cl ~= 1)
  error('c must be a row vector and length(c) = linenb(x).');
end
% prod (not sum) of the size: sum(size(b)) == 2 also holds for the
% empty 0x2 and 2x0 arrays, which are not scalars.
if prod(size(b)) ~= 1
  error('b should be a scalar.');
end
% y is read as y(j) (linear indexing) and is the right-hand side of the
% row-vector least-squares fit at the end, so it must be a single row.
if yl ~= 1
  error('y should have only one line.');
end

if nargin == 8
  % First call of this epoch: accumulate the gradients and the
  % Hessian approximations over all K samples.
  if lr ~= 0
    error('The first call of updatewn for each epoch must have lr=0.');
  end

  % Pre-allocation; everything is kept in column-major long format.
  dt = zeros(tl*tc, 1);
  dd = zeros(tl*tc, 1);
  hess_t = zeros(tl*tc, tl*tc);
  hess_d = zeros(dl*d_c, dl*d_c);

  for j = 1:xc
    % wavenet/wavedef are defined elsewhere. From their use here, g is the
    % net output for sample j; xt and xtd2 feed the wavelet derivatives.
    [g, xt, xtd2] = wavenet(x(:,j), t, d, w, c, b);
    e = y(j) - g;                          % output error for sample j
    dphi_xtd2 = wavedef(xtd2, tl, 1);
    dphi_xtd2_mat = ones(tl,1) * dphi_xtd2;  % replicated to n rows

    % Derivatives of each wavelet w.r.t. t and d, reshaped to n*N x 1.
    dpsi_t = reshape(-2.0 * dphi_xtd2_mat .* d .* d .* xt, tl*tc, 1);
    dpsi_d = reshape( 2.0 * dphi_xtd2_mat .* d .* xt .* xt, tl*tc, 1);

    % Gradient and sum of outer products (Gauss-Newton form);
    % the weights w are applied after the loop.
    dt = dt - e * dpsi_t;
    dd = dd - e * dpsi_d;
    hess_t = hess_t + dpsi_t * dpsi_t';
    hess_d = hess_d + dpsi_d * dpsi_d';
  end

  % Weight each long-format entry by the w of its wavelet (column).
  w_long = reshape(ones(tl,1) * w, 1, tl*tc);
  w_mat = w_long' * w_long;
  dt = w_long' .* dt;
  dd = w_long' .* dd;
  hess_t = w_mat .* hess_t;
  hess_d = w_mat .* hess_d;
end

% Damping to improve the matrix condition (also kept in the outputs).
hess_t = hess_t + lr * eye(tl * tc);
hess_d = hess_d + lr * eye(dl * d_c);

% Reject badly conditioned systems.
if (rcond(hess_t) < 1e-7) | (rcond(hess_d) < 1e-7)
  b = [];    % signals bad matrix condition to the caller
  return;
end

% Newton directions, packed back to n x N.
delta_t = reshape(hess_t \ dt, tl, tc);
delta_d = reshape(hess_d \ dd, tl, tc);

% Reject excessively large steps: any |delta_t| above 0.3./d, or any
% |delta_d| above 0.3 times the updated d - delta_d (this also rejects
% steps that would make d non-positive). Assumes d > 0.
if sum(sum(abs(delta_t) > 0.3 * (ones(dl,d_c) ./ d))) | ...
   sum(sum(abs(delta_d) > 0.3 * (d - delta_d)))
  b = [];    % treated as bad matrix condition
  return;
end

t = t - delta_t;
d = d - delta_d;

% Linear part: least-squares fit (mrdivide) of y on the stacked rows
% [wavelet outputs; x; ones]. wavelon is defined elsewhere; its result
% presumably has N rows (one per wavelet) - TODO confirm.
wavelonvalue = wavelon(x, t, d);
alltolc = [wavelonvalue; x; ones(1,xc)];
w = y / alltolc;
c = w(tc+1:tc+tl);
b = w(tc+tl+1);
w = w(1:wc);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -