function [net, result] = trainlm_snn(net, dataLV, dataVV, dataTV)
%TRAINLM_SNN Levenberg-Marquardt backpropagation.
%
% Syntax
%
%   [net, tr_info] = trainlm_snn(net, dataLV)
%   [net, tr_info] = trainlm_snn(net, dataLV, dataVV)
%   [net, tr_info] = trainlm_snn(net, dataLV, dataVV, dataTV)
%
%   trainfcn_struct = trainlm_snn('pdefaults')
%
% Description
%
%   TRAINLM_SNN is a network training function that updates weight and
%   bias values according to Levenberg-Marquardt optimization.
%
%   TRAINLM_SNN takes
%     net    - a net_struct containing the initial network.
%     dataLV - training data set.
%     dataVV - validation data set (optional).
%     dataTV - test data set (optional).
%   and returns
%     net     - the trained network.
%     tr_info - a structure containing information about the
%               training process.
%
%   The field 'trainFcn' of 'net' must contain a structure with the
%   training parameters. The default parameters can be set with
%
%     net.trainFcn = trainlm_snn('pdefaults')
%
%   or can be set manually. The parameters are:
%
%     net.trainFcn.name         - 'trainlm_snn'
%     net.trainFcn.epochs       - maximum number of training epochs
%     net.trainFcn.goal         - training goal for cost function
%     net.trainFcn.max_fail     - maximum number of epochs in which cost
%                                 on validation set increases.
%     net.trainFcn.max_suc_fail - maximum number of succesive epochs
%                                 in which cost on validation set increases.
%     net.trainFcn.mu           - initial mu
%     net.trainFcn.mu_dec       - factor for decreasing mu
%     net.trainFcn.mu_inc       - factor for increasing mu
%     net.trainFcn.mu_max       - maximum for mu
%     net.trainFcn.progressFcn  - progressfcn_struct containing parameters
%                                 for showing progress of training
%     net.trainFcn.time         - maximum time for training.
%
% See also
%
%   NET_STRUCT_SNN, PROGRESSFCN_STRUCT_SNN

% FUNCTION INFO
% =============

% When called with a string instead of a network, return metadata:
% 'pnames' -> parameter names, 'pdefaults' -> default parameter struct.
% NOTE: 'isstr' is deprecated; 'ischar' is the supported replacement.
if ischar(net)
    switch (net)
        case 'pnames'
            net = fieldnames(trainlm_snn('pdefaults'));
        case 'pdefaults'
            trainFcn.name = 'trainlm_snn';
            trainFcn.epochs = 1000;
            trainFcn.goal = 0;
            trainFcn.max_fail = 20;
            trainFcn.max_suc_fail = 10;
            trainFcn.mem_reduc = 1;
            trainFcn.mu = 0.001;
            trainFcn.mu_dec = 0.1;
            trainFcn.mu_inc = 10;
            trainFcn.mu_max = 1e10;
            trainFcn.progressFcn = progress_snn('pdefaults');
            trainFcn.time = inf;
            net = trainFcn;
        otherwise
            error('Unrecognized code.')
    end
    return
end

% CALCULATION
% ===========

% 'data' accumulates all supplied data sets; it is only used for the
% progress display below.
data = dataLV;
if (nargin < 3)
    doValidation = 0;
else
    doValidation = 1;
    data = [data dataVV];
end
if (nargin < 4)
    doTest = 0;
else
    doTest = 1;
    data = [data dataTV];
end

% Constants
epochs = net.trainFcn.epochs;
goal = net.trainFcn.goal;
max_fail = net.trainFcn.max_fail;
max_suc_fail = net.trainFcn.max_suc_fail;
%mem_reduc = net.trainFcn.mem_reduc;
mu = net.trainFcn.mu;
mu_inc = net.trainFcn.mu_inc;
mu_dec = net.trainFcn.mu_dec;
mu_max = net.trainFcn.mu_max;
time = net.trainFcn.time;

% Initialize
result.stop = '';
startTime = clock;
X = getx_snn(net);                 % current weight/bias vector
numParameters = length(X);
% Sparse identity used for the LM damping term (jj + mu*I).
ii = sparse(1:numParameters, 1:numParameters, ones(1, numParameters));
result.tr = tr_struct_snn(epochs, 'perf', 'vperf', 'tperf', 'mu');
%#function wcf_snn
result.tr.perf(1) = feval(net.costFcn.name, net, dataLV);
result.tr.mu(1) = mu;
if (doValidation)
    result.best_X = X;             % best-so-far weights by validation cost
    result.tr.vperf(1) = feval(net.costFcn.name, net, dataVV);
    result.vperf_min = result.tr.vperf(1);
    result.num_fail = 0;           % total validation failures
    result.num_suc_fail = 0;       % successive validation failures
end
if (doTest)
    result.tr.tperf(1) = feval(net.costFcn.name, net, dataTV);
end

% Train
for epoch = 1:epochs
    % Train with Levenberg Marquardt: increase mu until a step reduces
    % the training cost, then accept the step and decrease mu.
    [jj, je] = fisherff_snn(X, net, dataLV);
    while (mu <= mu_max)
        dX = -(jj + ii*mu) \ je;
        X2 = X + dX;
        net2 = setx_snn(net, X2);
        perf2 = feval(net.costFcn.name, net2, dataLV);
        if (perf2 < result.tr.perf(epoch))
            X = X2;
            perf = perf2;
            mu = mu * mu_dec;
            break
        end
        mu = mu * mu_inc;
    end
    if (mu > mu_max)
        % No acceptable step was found; keep the previous performance.
        result.stop = 'Maximum MU reached, performance goal was not met.';
        perf = result.tr.perf(epoch);
    end
    net = setx_snn(net, X);

    % Save performance
    epochPlus1 = epoch + 1;
    result.tr.perf(epochPlus1) = perf;
    result.tr.mu(epochPlus1) = mu;

    % Validation
    if (doValidation)
        result.tr.vperf(epochPlus1) = feval(net.costFcn.name, net, dataVV);
        if (result.tr.vperf(epochPlus1) < result.vperf_min)
            % New validation minimum: remember these weights.
            result.vperf_min = result.tr.vperf(epochPlus1);
            result.best_X = X;
            result.num_fail = 0;
        elseif (result.tr.vperf(epochPlus1) > result.vperf_min)
            result.num_fail = result.num_fail + 1;
        end
        if (result.tr.vperf(epochPlus1) > result.tr.vperf(epoch))
            result.num_suc_fail = result.num_suc_fail + 1;
        else
            result.num_suc_fail = 0;
        end
    end
    if (doTest)
        result.tr.tperf(epochPlus1) = feval(net.costFcn.name, net, dataTV);
    end

    % Stopping criteria
    % (short-circuit '&&' replaces the original non-short-circuit '&';
    % all operands are scalar logicals, so behavior is unchanged)
    result.runtime = etime(clock, startTime);
    if (result.tr.perf(epochPlus1) <= goal)
        result.stop = 'Performance goal met.';
    elseif (epoch == epochs)
        result.stop = 'Maximum epoch reached, performance goal was not met.';
    elseif (result.runtime > time)
        result.stop = 'Maximum time elapsed, performance goal was not met.';
    elseif (doValidation) && (result.num_fail > max_fail)
        result.stop = 'Validation stop; maximum total failures';
        net = setx_snn(net, result.best_X);
    elseif (doValidation) && (result.num_suc_fail > max_suc_fail)
        result.stop = 'Validation stop; maximum successive failures';
        net = setx_snn(net, result.best_X);
    elseif (doValidation) && (mu > mu_max)
        % mu overflow with validation: restore the best weights seen.
        net = setx_snn(net, result.best_X);
    end
    result.epoch = epoch;
    if ~isempty(result.stop)
        %#function progress_snn
        feval(net.trainFcn.progressFcn.name, net, data, result);
        stdout_snn('%s, %s\n', upper(net.trainFcn.name), result.stop);
        break;
    end

    % Show progress
    %#function progress_snn
    feval(net.trainFcn.progressFcn.name, net, data, result);
end

% Finish
% Trim the preallocated training record to the epochs actually run.
result.tr = cliptr_snn(result.tr, epoch);