📄 test1.m
字号:
% Demo script: curve fitting of a sine wave with a neural network, used to
% compare several training algorithms.
close all
clc
%---------- Build the training and test sets ----------
PHI = 0:0.25:6;            % 25 sample points, spacing 0.25
Y   = sin(PHI);            % target curve evaluated at the sample points
% Odd-numbered samples (1,3,...,25) form the training set ...
PHI1 = PHI(1:2:end);
Y1   = Y(1:2:end);
% ... and even-numbered samples (2,4,...,24) form the test set.
PHI2 = PHI(2:2:end);
Y2   = Y(2:2:end);
% Show the training points
figure(1);
plot(PHI1,Y1,'+')
title('Data set for training')
drawnow
% ----------- Training the network ----------
% Repeatedly ask the user for a training algorithm, train a freshly
% initialized network on the training set (PHI1,Y1), evaluate it on the
% test set (PHI2,Y2) and plot the criterion history returned by the
% training routine. The loop ends when 'quit' is selected or the menu
% window is closed.
k=5;
while k~=7,
  k = menu('Choose one of the training algorithms below:',...
           'Back Propagation (Batch)',...
           'Back Propagation (Recursive)',...
           'Recursive Prediction Error Method (forgetting factor)',...
           'Recursive Prediction Error Method (Constant Trace)',...
           'Recursive Prediction Error Method (EFRA)',...
           'Levenberg-Marquardt method',...
           'quit');

  % menu returns 0 when the window is closed without a selection. Treat
  % that as 'quit'; otherwise no training branch would run and the
  % validation/plot section below would use undefined or stale results.
  if k==0,
    k = 7;
  end

  %---------- Define network structure and initialize weights ----------
  % The generator is re-seeded on every pass so that each algorithm starts
  % from the same initial weights.
  rand('seed',0);
  W1 = rand(4,2);          % Weights to hidden layer
  W2 = rand(1,5);          % Weights to output
  NetDef = ['HHHH'
            'L---'];       % The top row represents the structure of
                           % the hidden layer. The bottom row represents
                           % the output structure.

  switch k
    % ----- Back propagation (Batch) -----
    case 1
      maxiter   = 1000;
      stop_crit = 1e-12;
      eta       = 0.01;
      trparms   = [maxiter stop_crit eta];
      [W1,W2,PI_vector,iter]=batbp(NetDef,W1,W2,PHI1,Y1,trparms);

    % ----- Back propagation (Recursive) -----
    case 2
      maxiter   = 1000;
      stop_crit = 1e-12;
      eta       = 0.01;
      trparms   = [maxiter stop_crit eta];
      % Use the recursive routine here; the batch routine (batbp) belongs
      % to menu entry 1 only.
      % NOTE(review): assumes the toolbox provides incbp with the same
      % argument list as batbp -- confirm against the installed version.
      [W1,W2,PI_vector,iter]=incbp(NetDef,W1,W2,PHI1,Y1,trparms);

    % ----- RPE algorithm (Forgetting factor) -----
    case 3
      maxiter   = 200;
      stop_crit = 1e-12;
      p0        = 200;
      lambda    = 0.98;
      trparms   = [maxiter stop_crit p0 lambda];
      [W1,W2,PI_vector,iter]=rpe(NetDef,W1,W2,PHI1,Y1,trparms,'ff');

    % ----- RPE algorithm (Constant Trace) -----
    case 4
      maxiter   = 200;
      stop_crit = 1e-12;
      alpha_max = 200;
      alpha_min = 0.001;
      trparms   = [maxiter stop_crit alpha_max alpha_min];
      [W1,W2,PI_vector,iter]=rpe(NetDef,W1,W2,PHI1,Y1,trparms,'ct');

    % ----- RPE algorithm (EFRA) -----
    case 5
      maxiter   = 200;
      stop_crit = 1e-12;
      alpha     = 1;
      beta      = 0.001;
      delta     = 0.001;
      lambda    = 0.98;
      trparms   = [maxiter stop_crit alpha beta delta lambda];
      [W1,W2,PI_vector,iter]=rpe(NetDef,W1,W2,PHI1,Y1,trparms,'efra');

    % ----- Marquardt algorithm -----
    case 6
      maxiter   = 200;
      stop_crit = 1e-12;
      lambda    = 1;
      D         = 0;
      trparms   = [maxiter stop_crit lambda D];
      [W1,W2,PI_vector,iter,lambda]=marq(NetDef,W1,W2,PHI1,Y1,trparms);
  end

  if k~=7,
    % ----------- Validate Network -----------
    % Evaluate the trained network on the held-out test set.
    [Y_sim,E,PI] = nneval(NetDef,W1,W2,PHI2,Y2);

    % ----------- Plot Cost function -----------
    % A new figure is opened for every run so earlier plots stay visible.
    figure
    semilogy(PI_vector)
    title('Criterion evaluated after each iteration')
    xlabel('Iteration (epoch)')
    ylabel('Criterion')
    grid
  end
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -