trainwn.m
function [t,d,w,c,b, mseif, flopif, timeif] = ...
    trainwn(x, y, nbwavelon, max_epoch, initmode, min_nbw, levels, t,d,w,c,b)
%Trainwn: Train the wavelet net
%
% [t,d,w,c,b, mseif, flopif, timeif] = ...
%     trainwn(x, y, nbwavelon, max_epoch, initmode, min_nbw, levels)
%
% x: input patterns, each column is one input pattern
% y: output patterns, a row vector, each entry is one output pattern
% nbwavelon: number of wavelons used to construct the net
% max_epoch: maximum number of epochs of training
% initmode: initialization mode, may be 0, 1, 2, 3, default is 2
%   initmode=0: supply the initial values by input arguments
%   initmode=1: residual based selection
%   initmode=2: stepwise selection by orthogonalization
%   initmode=3: backward elimination
% min_nbw: minimum number of input patterns each wavelon should "cover"
% levels: number of scale levels scanned during initialization
%
% The arguments initmode, min_nbw and levels are optional.
%
% t: translation parameters
% d: dilation parameters
% w: linear weights
% c: linear combination coefficients (direct connections)
% b: bias
% mseif: [initial final] mean square errors (on the original data scale)
% flopif: [initialization training] flop counts
% timeif: [initialization training] times in seconds
%
% By Qinghua Zhang. March, 1994.

wnetver;

if (nargin < 2) | (nargin > 12)
  error('Wrong number of input arguments. See HELP TRAINWN.');
end

[xl, xc] = size(x);
[yl, yc] = size(y);
[nbvar, nbobs] = size(x);

if xc ~= yc
  error('x and y must have the same number of columns.');
end
if yl ~= 1
  error('y must be a row vector.');
end

% ---If arguments not defined then make empty---
if nargin < 3, nbwavelon = []; end;
if nargin < 4, max_epoch = []; end;
if nargin < 5, initmode  = []; end;
if nargin < 6, min_nbw   = []; end;
if nargin < 7, levels    = []; end;

%Default values
%===============
if isempty(max_epoch)
  max_epoch = 0;
end
if isempty(initmode)
  initmode = 2;
end

% If nbwavelon=0, do only a linear regression
if nbwavelon==0
  c = y / [x; ones(size(y))];
  b = c(nbvar+1);
  c = c(1:nbvar);
  t=[]; d=[]; w=[]; errors=[];
  disp('Trainwn: 0 wavelets in the network: linear regression.');
  return
end

% If nbwavelon=[] or < 0, automatic network order determination
if isempty(nbwavelon), nbwavelon=-1; end
if nbwavelon < 0
  if isempty(min_nbw), min_nbw = 2+nbvar; end;
  if isempty(levels),  levels  = 4; end;
  %nbwavelon = fix((nbobs-nbvar-2) / (nbvar+2));
end

disp_freq = 1;   % Display frequency

clf
hold off
axis('normal');

if nbvar == 1
  % Define input vector to graph the function that the network actually performs.
  x2 = -1:.01:1;
end

%Pre-processing of the data
%===========================
fprintf(' Data pre-processing...');
% Start the timer
time_on = clock; flops(0);
[x,y,xcent,xhalf,ycent,yhalf] = prepdata(x,y);
prepflops = flops;
prep_time = etime(clock, time_on);
fprintf('\n Data pre-processing terminated.\n\n');

% INITIALIZE NETWORK ARCHITECTURE
%================================
% Start the timer
time_on = clock; flops(0);
if initmode>=1 & initmode<=3
  [t,d,w,c,b] = initwnet(initmode,nbwavelon,x,y,min_nbw,levels);
elseif initmode==0
  %Mode 0, supplied by arguments
  if nargin ~= 12, error('Nb of input args must be 12 when InitMode=0.'); end
  [tl, tc] = size(t);
  [dl, dc] = size(d);
  [wl, wc] = size(w);
  [cl, cc] = size(c);
  [bl, bc] = size(b);
  % TEST consistency of the dimensions (for Mode 0 only)
  %======================================================
  if (tl~=dl)|(wl~=1)|(cl~=1)|(bl~=1)|(tc~=dc)| ...
     (tc~=wc)|(cc~=tl)|(bc~=1)
    error('Inconsistent dimensions of the initial parameters.');
  end
  % Rescaling (for Mode 0 only)
  t = (t - xcent * ones(1,tc)) ./ (xhalf * ones(1,tc));
  d = d .* (xhalf * ones(1,tc));
  w = w / yhalf;
  c = (c .* xhalf') / yhalf;
  b = (b - ycent + (yhalf * c ./ xhalf') * xcent) / yhalf;
else
  error('Invalid value of the input argument initmode.');
end
initflops = flops;
global Pause_Time;
init_time = etime(clock, time_on) - Pause_Time;
%Do not count the pause time for initialization
fprintf('\n Initialization terminated.\n\n');
fprintf(' Strike any key to CONTINUE.\n\n');
pause

% TRAIN THE NETWORK
%==================
if max_epoch >= 1
  fprintf(' Training the network for %g epochs.\n\n', max_epoch);
  fprintf(' Press ''T'' on the keyboard to terminate the training');
  fprintf(' after the current epoch.\n\n');
end

% First PRESENTATION PHASE
g = wavenet(x,t,d,w,c,b);
E = y-g;
SSE = sum(E .* E);   %Sum of Square Errors
MSE = SSE/nbobs;
mseif = [MSE*yhalf*yhalf 0];
stdy = std(y);       %Standard deviation of y
NSRMSE = sqrt(SSE / nbobs) / stdy;
%NSRMSE: Normalized Square Root of Mean Square Error

% TRAINING RECORD
errors = [NSRMSE];
epoch = 0;   % so that 'epoch' is defined for the initial plot and the post-training tests

% DISPLAY Initial errors
if ~(strcmp(computer,'PC') | strcmp(computer,'386'))
  fprintf([' epoch=%.0f' 9 'NSRMSE=%g ' 9 'MSE=%g\n'], ...
          0,NSRMSE,MSE*yhalf*yhalf)
end

hold off;
if nbvar == 1
  % PLOT INITIAL FUNCTION APPROXIMATION
  clf;
  if nbobs<100
    hddata = plot(x*xhalf+xcent, y*yhalf+ycent, 'o');
  else
    hddata = plot(x*xhalf+xcent, y*yhalf+ycent, '.');
  end
  hdapp = line('color','k','linestyle','-','erase','xor', ...
    'xdata',x2*xhalf+xcent,'ydata',wavenet(x2,t,d,w,c,b)*yhalf+ycent);
  xlabel('Input');
  ylabel('Output: -, Target: o');
  title(['epoch=' num2str(epoch) ' NSRMSE=' num2str(NSRMSE) ...
         ' MSE=' num2str(MSE*yhalf*yhalf)])
  pause(0)
elseif max_epoch >= 1
  % Initialize the graphics window to plot errors
  clf;
  plot([0 max_epoch], [0 NSRMSE],'.');
  drawnow;
  line('color','k', 'linestyle','o', 'xdata', 0, 'ydata', NSRMSE);
  hderr = line('color','k','linestyle','-','erase','none', ...
    'xdata',[0],'ydata',[NSRMSE]);
  %errscale = 10^(fix(-log10(NSRMSE)+1));
  %plot(0, NSRMSE, 'o');
  %axis([0 max_epoch 0 ceil(NSRMSE*errscale)/errscale]);
  title('Network Error');
  xlabel('Epoch');
  ylabel('NSRMSE');
  pause(0)
end

% Start the timer
time_on = clock; flops(0);

for epoch=1:max_epoch
  % CHECK PHASE
  %if NSRMSE < err_goal, epoch=epoch-1; break, end
  %Test 'T' to terminate the training
  readkeybuffer = readkbuf;
  if readkeybuffer=='t' | readkeybuffer=='T', epoch=epoch-1; break, end

  % LEARNING PHASE
  [t,d,w,c,b,SSE] = adaptlr(x,y,t,d,w,c,b, SSE);
  MSE = SSE/nbobs;

  % PRESENTATION PHASE
  NSRMSE = sqrt(SSE / nbobs) / stdy;

  % TRAINING RECORD
  errors = [errors NSRMSE];

  % DISPLAY PROGRESS
  temp = flops;
  if (rem(epoch,disp_freq) == 0)
    if ~(strcmp(computer,'PC') | strcmp(computer,'386'))
      fprintf([' epoch=%.0f' 9 'NSRMSE=%g ' 9 'MSE=%g\n'], ...
              epoch,NSRMSE,MSE*yhalf*yhalf)
    end
    if nbvar == 1
      % PLOT current FUNCTION APPROXIMATION
      set(hdapp, ...
        'xdata',x2*xhalf+xcent,'ydata',wavenet(x2,t,d,w,c,b)*yhalf+ycent);
      title(['epoch=' num2str(epoch) ' NSRMSE=' num2str(NSRMSE) ...
             ' MSE=' num2str(MSE*yhalf*yhalf)])
      pause(0)
    else
      lenerr = length(errors);
      set(hderr,'xdata',[lenerr-2 lenerr-1],'ydata',errors([lenerr-1 lenerr]));
      drawnow;
    end
  end
  flops(temp);
end
hold off;

trainflops = flops;
train_time = etime(clock, time_on);
if epoch > 0, disp([13 'Training terminated.']); end;

if nbvar == 1
  % PLOT FINAL APPROXIMATION
  set(hdapp, ...
    'xdata',x2*xhalf+xcent,'ydata',wavenet(x2,t,d,w,c,b)*yhalf+ycent);
  title(['epoch=' num2str(epoch) ' NSRMSE=' num2str(NSRMSE) ...
         ' MSE=' num2str(MSE*yhalf*yhalf)])
  pause(0)
  disp('What you see in the graphics window is the final result.');
  disp('Strike any key to see further information.');
  pause
end

if epoch > 0
  % PLOT ERROR CURVE
  %=================
  hold off
  plot(0,errors(1),'o', 0:length(errors)-1,errors);
  errscale = 10^(fix(-log10(max(errors))+1));
  axis([0 length(errors)-1 0 ceil(max(errors)*errscale)/errscale]);
  title('Network Error')
  xlabel('Epoch')
  ylabel('NSRMSE')
  disp('Strike any key to see further information.');
  pause
end

% SUMMARIZE RESULTS
%==================
E = y - wavenet(x,t,d,w,c,b);
SSE = sum(E .* E);
MSE = SSE/nbobs;
NSRMSE = sqrt(SSE / nbobs) / stdy;

if nbobs <= 100
  % PLOT Error BAR associated with each output vector
  if nbvar == 1
    [xsorted, xsind] = sort(x);
    E = E(xsind);   % Sort E to x ascending order
  end
  bar(E*yhalf);
  axis([1 nbobs min(E)*1.1*yhalf max(E)*1.1*yhalf]);
  title('Error for Each Sample');
  xlabel('Samples');
  ylabel('Errors');
  %axis('auto');
  disp('Strike any key to see further information.');
  pause
else
  % PLOT Error LINE associated with each output vector
  plot(E*yhalf);
  axis([1 nbobs min(E)*1.1*yhalf max(E)*1.1*yhalf]);
  title('Error for Each Sample');
  xlabel('Samples');
  ylabel('Errors');
  %axis('auto');
  hold off
  disp('Strike any key to see further information.');
  pause
end
hold off;

%Post processing of the wavenet
%===============================
[t,d,w,c,b] = pospwnet(t,d,w,c,b,xcent,xhalf,ycent,yhalf);

fprintf('\nData pre-processing took %g seconds, %g flops.\n', ...
        prep_time, prepflops);
fprintf('Initialization took %g seconds, %g flops.\n', ...
        init_time, initflops);
if epoch > 0
  fprintf('Trained for %.0f epochs, %g seconds.\n', epoch, train_time);
  fprintf('Training took %.0f flops, ', trainflops);
  fprintf('average of %.0f flops/epoch.\n', (trainflops/epoch));
else
  fprintf('Trained for 0 epochs.\n');
end
fprintf('Final Normalized Square Root of Mean Square Error is %g.\n',NSRMSE);
fprintf('Final Mean Square Error is %g.\n',MSE*yhalf*yhalf);

mseif(2) = MSE*yhalf*yhalf;
flopif = [initflops trainflops];
timeif = [init_time train_time];
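%--------------------------------------------------------------------------
% Usage sketch (added for illustration; not part of the original toolbox
% file). The trained network combines nbwavelon wavelons with a direct
% linear term, roughly g(x) = sum_i w(i)*psi(d(:,i).*(x - t(:,i))) + c*x + b
% for a column input x, where psi is the mother wavelet implemented in
% wavenet.m (see that file for the exact definition). The example assumes
% the companion routines (wnetver, prepdata, initwnet, adaptlr, wavenet,
% pospwnet, readkbuf) are on the MATLAB path; the data are hypothetical.
%
%   x = -1:0.01:1;                         % input patterns, one per column
%   y = sin(5*x) .* exp(-x.^2);            % target outputs, a row vector
%   [t,d,w,c,b] = trainwn(x, y, 8, 50);    % 8 wavelons, at most 50 epochs
%   yhat = wavenet(x, t, d, w, c, b);      % evaluate the trained network
%--------------------------------------------------------------------------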