⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 trainwn.m

📁 新的神经网络算法源程序
💻 M
字号:
function [t,d,w,c,b, mseif, flopif, timeif] = ...
         trainwn(x, y, nbwavelon, max_epoch, initmode, min_nbw, levels, t,d,w,c,b)
%Trainwn: Train the wavelet net
%
%   [t,d,w,c,b, mseif, flopif, timeif] =
%   trainwn(x, y, nbwavelon, max_epoch, initmode, min_nbw, levels)
%
%   x: input patterns, each column is one input pattern
%   y: output patterns, a row vector, each entry is one output pattern
%   nbwavelon: number of wavelons used to construct the net
%              0        : no wavelet, only a linear regression is computed
%              [] or <0 : automatic network order determination
%   max_epoch: maximum number of epochs of training (default 0)
%   initmode: Initialization mode, may be 0, 1, 2, 3, default is 2
%       initmode=0:  supply the initial values by input arguments
%                    (t,d,w,c,b must then be given: 12 input arguments)
%       initmode=1:  residual based selection
%       initmode=2:  stepwise by orthogonalization
%       initmode=3:  backward elimination
%   min_nbw: minimum number of input patterns each wavelon should "cover"
%            (defaults to 2+size(x,1) when nbwavelon is automatic)
%   levels: number of scale levels scanned during initialization
%           (defaults to 4 when nbwavelon is automatic)
%
%   The arguments nbwavelon, max_epoch, initmode, min_nbw and levels are
%   optional.
%
%   t: translation parameters
%   d: dilation parameters
%   w: linear weights
%   c: linear combination coefficients (direct connections)
%   b: bias
%   mseif:  [MSE after initialization, MSE after training], scaled back
%           to the units of the original y
%   flopif: [flops spent in initialization, flops spent in training]
%   timeif: [initialization time, training time] in seconds
%
%   Training runs on data normalized by PREPDATA; the returned parameters
%   are passed through POSPWNET with the same centering/scaling factors.
%
%   By Qinghua Zhang. March, 1994.

wnetver;

if (nargin < 2) | (nargin > 12)
  error('Number of input arguments error. See HELP TRAINWN');
end

[yl, yc] = size(y);
[nbvar, nbobs] = size(x);
if nbobs ~= yc
  error('x and y must have the same number of columns.');
end
if yl ~= 1
  error('y must be a row vector.');
end

% ---If arguments not defined then make empty---
if nargin < 3, nbwavelon = []; end
if nargin < 4, max_epoch = []; end
if nargin < 5, initmode = []; end
if nargin < 6, min_nbw = []; end
if nargin < 7, levels = []; end

% Default values
%===============
if isempty(max_epoch)
  max_epoch = 0;
end
if isempty(initmode)
  initmode = 2;
end

% If nbwavelon=0, do only a linear regression
% ([]==0 is empty, hence false, so the automatic case is not caught here.)
if nbwavelon==0
  c = y / [x; ones(size(y))];        % least-squares fit of y = c*x + b
  b = c(nbvar+1); c = c(1:nbvar); t=[]; d=[]; w=[];
  % Assign every declared output so that callers requesting them do not
  % hit an "output argument not assigned" error. No initialization or
  % training phase runs here, so flop and time counts are reported as 0.
  E = y - (c*x + b);
  MSE = sum(E .* E) / nbobs;
  mseif = [MSE MSE];
  flopif = [0 0];
  timeif = [0 0];
  disp('Trainwn: 0 wavelet in the network: Linear regression.');
  return
end

% If nbwavelon=[] or < 0, automatic network order determination
if isempty(nbwavelon), nbwavelon=-1; end
if nbwavelon < 0
  if isempty(min_nbw), min_nbw = 2+nbvar; end
  if isempty(levels), levels = 4; end
end

disp_freq = 1; % Display frequency (in epochs)

clf
hold off
axis('normal');
if nbvar == 1
  % Input points (normalized scale) at which the network output is plotted.
  x2 = -1:.01:1;
end

% Pre-processing of the data
%===========================
fprintf('   Data pre-processing...');
% Start the timer
time_on = clock; flops(0);
[x,y,xcent,xhalf,ycent,yhalf] = prepdata(x,y);
prepflops = flops;
prep_time = etime(clock, time_on);
fprintf('\n   Data pre-processing terminated.\n\n');

% INITIALIZE NETWORK ARCHITECTURE
%================================
% Start the timer
time_on = clock;
flops(0);
if initmode>=1 & initmode<=3
  [t,d,w,c,b]=initwnet(initmode,nbwavelon,x,y,min_nbw,levels);
elseif initmode==0 % Mode 0, supplied by arguments
  if nargin ~= 12, error('Nb of input args must be 12 when InitMode=0.'); end
  [tl, tc] = size(t);
  [dl, dc] = size(d);
  [wl, wc] = size(w);
  [cl, cc] = size(c);
  [bl, bc] = size(b);
  % TEST consistency of the dimensions (for Mode 0 only).
  % t and d must have one row per input variable, otherwise the rescaling
  % below fails with an unrelated dimension error.
  %======================================================
  if (tl~=nbvar)|(tl~=dl)|(wl~=1)|(cl~=1)|(bl~=1)|(tc~=dc)| ...
     (tc~=wc)|(cc~=tl)|(bc~=1)
    error('Dimensions of the initial parameters error.');
  end
  % Rescaling into the normalized data space (for Mode 0 only)
  t = (t - xcent * ones(1,tc)) ./ (xhalf * ones(1,tc));
  d = d .* (xhalf * ones(1,tc));
  w = w / yhalf;
  c = (c .* xhalf') / yhalf;
  % c is already rescaled here, so yhalf*c./xhalf' recovers the original c.
  b = (b - ycent + (yhalf * c ./ xhalf') * xcent) / yhalf;
else
  error('The input argument initmode error.');
end
initflops = flops;

% Do not count the pause time for initialization. Pause_Time is a global
% presumably maintained by the initialization routines; it is empty if
% nothing has assigned it (e.g. initmode=0), which would make init_time empty.
% NOTE(review): with initmode=0 a stale value from a previous run may remain.
global Pause_Time;
pause_time = Pause_Time;
if isempty(pause_time), pause_time = 0; end
init_time = etime(clock, time_on) - pause_time;

fprintf('\n   Initialization terminated.\n\n');
fprintf('   Strike any key to CONTINUE.\n\n');
pause

% TRAIN THE NETWORK
%==================
if max_epoch >= 1
  fprintf('   Training the network for %g epochs.\n\n', max_epoch);
  fprintf('   Press ''T'' on the keyboard to terminate the training');
  fprintf(' after the current epoch.\n\n');
end

% First PRESENTATION PHASE
g = wavenet(x,t,d,w,c,b);
E = y-g;
SSE = sum(E .* E);   % Sum of Square Errors
MSE = SSE/nbobs;
mseif = [MSE*yhalf*yhalf 0];
stdy = std(y);       % Standard deviation of y
% NSRMSE: Normalized Square Root of Mean Square Error
NSRMSE = sqrt(SSE / nbobs) / stdy;

% TRAINING RECORD
errors = [NSRMSE];

% DISPLAY Initial errors
if ~(strcmp(computer,'PC') | strcmp(computer,'386'))
  fprintf(['   epoch=%.0f' 9 'NSRMSE=%g      ' 9 'MSE=%g\n'], ...
          0,NSRMSE,MSE*yhalf*yhalf)
end

hold off;
if nbvar == 1
  % PLOT INITIAL FUNCTION APPROXIMATION
  clf;
  if nbobs<100
    hddata=plot(x*xhalf+xcent, y*yhalf+ycent, 'o');
  else
    hddata=plot(x*xhalf+xcent, y*yhalf+ycent, '.');
  end
  hdapp=line('color','k','linestyle','-','erase','xor', ...
        'xdata',x2*xhalf+xcent,'ydata',wavenet(x2,t,d,w,c,b)*yhalf+ycent);
  xlabel('Input'); ylabel('Output: -,  Target: +');
  % No training epoch has run yet at this point.
  title(['epoch=' num2str(0) ' NSRMSE=' num2str(NSRMSE) ...
         ' MSE=' num2str(MSE*yhalf*yhalf)])
  pause(0)
elseif max_epoch >= 1
  % Initialize the graphics window to plot errors
  clf; plot([0 max_epoch], [0 NSRMSE],'.');
  drawnow; line('color','k', 'linestyle','o', 'xdata', 0, 'ydata', NSRMSE);
  hderr=line('color','k','linestyle','-','erase','none', ...
             'xdata',[0],'ydata',[NSRMSE]);
  title('Network Error');
  xlabel('Epoch');
  ylabel('NSRMSE');
  pause(0)
end

% Start the timer
time_on = clock; flops(0);
% Number of completed epochs. Kept separately from the loop variable
% because the loop variable is not a usable count when the loop body
% never executes (max_epoch = 0).
nb_done = 0;
for epoch=1:max_epoch
  % CHECK PHASE
  % Test 'T' to terminate the training (before starting this epoch)
  readkeybuffer = readkbuf;
  if readkeybuffer=='t' | readkeybuffer=='T', break, end

  % LEARNING PHASE
  [t,d,w,c,b,SSE] = adaptlr(x,y,t,d,w,c,b, SSE);
  nb_done = epoch;
  MSE = SSE/nbobs;

  % PRESENTATION PHASE
  NSRMSE = sqrt(SSE / nbobs) / stdy;

  % TRAINING RECORD
  errors = [errors NSRMSE];

  % DISPLAY PROGRESS (flops spent on display are excluded from the count)
  temp = flops;
  if (rem(epoch,disp_freq) == 0)
    if ~(strcmp(computer,'PC') | strcmp(computer,'386'))
      fprintf(['   epoch=%.0f' 9 'NSRMSE=%g      ' 9 'MSE=%g\n'], ...
              epoch,NSRMSE,MSE*yhalf*yhalf)
    end
    if nbvar == 1
      % PLOT current FUNCTION APPROXIMATION
      set(hdapp, ...
          'xdata',x2*xhalf+xcent,'ydata',wavenet(x2,t,d,w,c,b)*yhalf+ycent);
      title(['epoch=' num2str(epoch) ' NSRMSE=' num2str(NSRMSE) ...
             ' MSE=' num2str(MSE*yhalf*yhalf)])
      pause(0)
    else
      % Append the newest segment of the error curve
      lenerr = length(errors);
      set(hderr,'xdata',[lenerr-2 lenerr-1],'ydata',errors([lenerr-1 lenerr]));
      drawnow;
    end
  end
  flops(temp);
end
epoch = nb_done;
hold off;
trainflops = flops;
train_time = etime(clock, time_on);

if epoch > 0, disp([13 'Training terminated.']); end

if nbvar == 1
  % PLOT FINAL APPROXIMATION
  set(hdapp, ...
      'xdata',x2*xhalf+xcent,'ydata',wavenet(x2,t,d,w,c,b)*yhalf+ycent);
  title(['epoch=' num2str(epoch) ' NSRMSE=' num2str(NSRMSE) ...
         ' MSE=' num2str(MSE*yhalf*yhalf)])
  pause(0)
  disp('What you see in the graphics window is the final result.');
  disp('Strike any key to see further information.');
  pause
end

if epoch > 0
  % PLOT ERROR CURVE
  %=================
  hold off
  plot(0,errors(1),'o', 0:length(errors)-1,errors);
  % log10(0) would give non-finite axis limits, so only rescale for a
  % positive maximum error.
  maxerr = max(errors);
  if maxerr > 0
    errscale = 10^(fix(-log10(maxerr)+1));
    axis([0 length(errors)-1 0 ceil(maxerr*errscale)/errscale]);
  end
  title('Network Error')
  xlabel('Epoch')
  ylabel('NSRMSE')
  disp('Strike any key to see further information.');
  pause
end

% SUMMARIZE RESULTS
%==================
E = y-wavenet(x,t,d,w,c,b);
SSE = sum(E .* E);
MSE = SSE/nbobs;
NSRMSE = sqrt(SSE / nbobs) / stdy;

if nbobs <= 100
  % PLOT Error BAR associated with each output vector
  if nbvar == 1
    [xsorted, xsind] = sort(x);
    E = E(xsind);       % Sort E to x ascending order
  end
  bar(E*yhalf);
else
  % PLOT Error LINE associated with each output vector
  plot(E*yhalf);
end
% axis() rejects an empty range; skip it for a single sample or a
% constant error vector.
if (nbobs > 1) & (min(E) < max(E))
  axis([1 nbobs  min(E)*1.1*yhalf max(E)*1.1*yhalf]);
end
title('Error for Each Sample');
xlabel('Samples');
ylabel('Errors');
disp('Strike any key to see further information.');
pause
hold off;

% Post processing of the wavenet
%===============================
[t,d,w,c,b] = pospwnet(t,d,w,c,b,xcent,xhalf,ycent,yhalf);

fprintf('\nData pre-processing took  %g seconds, %g flops.\n',  ...
        prep_time, prepflops);
fprintf('Initialization      took  %g seconds, %g flops.\n', ...
        init_time, initflops);
if epoch > 0
  fprintf('Trained for %.0f epochs,  %g seconds.\n', epoch, train_time);
  fprintf('Training took %.0f flops, ', trainflops);
  fprintf('Average of %.0f flops/epoch.\n', (trainflops/epoch));
else
  fprintf('Trained for 0 epoch.\n');
end
fprintf('Final Normalized Square Root of Mean Square Error is %g.\n',NSRMSE);
fprintf('Final Mean Square Errors is %g.\n',MSE*yhalf*yhalf);

mseif(2) = MSE*yhalf*yhalf;
flopif = [initflops trainflops];
timeif = [init_time train_time];

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -