trainlm_snn.m

来自「神经网络的工具箱」· M 代码 · 共 217 行

M
217
字号
function [net, result] = trainlm_snn(net, dataLV, dataVV, dataTV)
%TRAINLM_SNN Levenberg-Marquardt backpropagation.
%
%  Syntax
%
%   [net, result] = trainlm_snn(net, dataLV)
%   [net, result] = trainlm_snn(net, dataLV, dataVV)
%   [net, result] = trainlm_snn(net, dataLV, dataVV, dataTV)
%
%   trainfcn_struct = trainlm_snn('pdefaults')
%   param_names     = trainlm_snn('pnames')
%
%  Description
%
%   TRAINLM_SNN is a network training function that updates weight and
%   bias values according to Levenberg-Marquardt optimization.
%
%   TRAINLM_SNN takes
%     net    - a net_struct containing the initial network.
%     dataLV - training data set.
%     dataVV - validation data set (optional).
%     dataTV - test data set (optional).
%   and returns
%     net    - the trained network.
%     result - a structure containing information about the training
%              process:
%                result.tr      - per-epoch record with fields perf, vperf,
%                                 tperf and mu; index 1 holds the values
%                                 before the first epoch. It is passed
%                                 through CLIPTR_SNN at the end (presumably
%                                 trimming entries for epochs not run).
%                result.stop    - message describing why training stopped.
%                result.epoch   - last epoch that was run.
%                result.runtime - elapsed training time in seconds.
%              When validation data is given it also contains best_X,
%              vperf_min, num_fail and num_suc_fail.
%
%   TRAINLM_SNN('pdefaults') returns the default parameter structure and
%   TRAINLM_SNN('pnames') returns its field names.
%
%   The field 'trainFcn' of 'net' must contain a structure with the
%   training parameters. The default parameters can be set with
%
%     net.trainFcn = trainlm_snn('pdefaults')
%
%   or can be set manually. The parameters are (defaults in brackets):
%
%     net.trainFcn.name     - 'trainlm_snn'
%     net.trainFcn.epochs   - maximum number of training epochs [1000]
%     net.trainFcn.goal     - training goal for the cost function [0]
%     net.trainFcn.max_fail - training stops when the number of epochs
%                             whose validation cost is above the best
%                             value so far (the count resets whenever a
%                             new best is found) exceeds max_fail [20]
%     net.trainFcn.max_suc_fail - training stops when the number of
%                             successive epochs in which the validation
%                             cost increased exceeds max_suc_fail [10]
%     net.trainFcn.mem_reduc - present in the defaults but not used by
%                             this function [1]
%     net.trainFcn.mu       - initial mu [0.001]
%     net.trainFcn.mu_dec   - factor for decreasing mu [0.1]
%     net.trainFcn.mu_inc   - factor for increasing mu [10]
%     net.trainFcn.mu_max   - maximum for mu; training stops when it is
%                             exceeded [1e10]
%     net.trainFcn.progressFcn - progressfcn_struct containing parameters
%                                for showing progress of training
%     net.trainFcn.time     - maximum training time in seconds [inf]
%
%   The cost is computed as feval(net.costFcn.name, net, data). On a
%   validation stop (or when mu exceeds mu_max while validation data is
%   given) the weights with the lowest validation cost are restored.
%
%  See also
%
%   NET_STRUCT_SNN, PROGRESSFCN_STRUCT_SNN

% FUNCTION INFO
% =============
if ischar(net)
  switch (net)
    case 'pnames'
      net = fieldnames(trainlm_snn('pdefaults'));
    case 'pdefaults'
      trainFcn.name = 'trainlm_snn';
      trainFcn.epochs = 1000;
      trainFcn.goal = 0;
      trainFcn.max_fail = 20;
      trainFcn.max_suc_fail = 10;
      trainFcn.mem_reduc = 1;   % not read anywhere in this function
      trainFcn.mu = 0.001;
      trainFcn.mu_dec = 0.1;
      trainFcn.mu_inc = 10;
      trainFcn.mu_max = 1e10;
      trainFcn.progressFcn = progress_snn('pdefaults');
      trainFcn.time = inf;
      net = trainFcn;
    otherwise
      error('Unrecognized code.')
  end
  return
end

% CALCULATION
% ===========
% Optional data sets are selected by argument count; 'data' gathers every
% supplied set so the progress function can show all of them.
doValidation = (nargin >= 3);
doTest = (nargin >= 4);
data = dataLV;
if doValidation
   data = [data dataVV];
end
if doTest
   data = [data dataTV];
end

% Constants (net.trainFcn.mem_reduc is intentionally not read)
epochs = net.trainFcn.epochs;
goal = net.trainFcn.goal;
max_fail = net.trainFcn.max_fail;
max_suc_fail = net.trainFcn.max_suc_fail;
mu = net.trainFcn.mu;
mu_inc = net.trainFcn.mu_inc;
mu_dec = net.trainFcn.mu_dec;
mu_max = net.trainFcn.mu_max;
time = net.trainFcn.time;

% Initialize
result.stop = '';
startTime = clock;
X = getx_snn(net);
numParameters = length(X);
ii = speye(numParameters);   % sparse identity for the damping term mu*I
result.tr = tr_struct_snn(epochs, 'perf', 'vperf', 'tperf', 'mu');
%#function wcf_snn
result.tr.perf(1) = feval(net.costFcn.name, net, dataLV);
result.tr.mu(1) = mu;
if doValidation
   result.best_X = X;
   result.tr.vperf(1) = feval(net.costFcn.name, net, dataVV);
   result.vperf_min = result.tr.vperf(1);
   result.num_fail = 0;
   result.num_suc_fail = 0;
end
if doTest
   result.tr.tperf(1) = feval(net.costFcn.name, net, dataTV);
end

% Train
for epoch = 1:epochs

    % Levenberg-Marquardt step. NOTE(review): jj/je are presumably J'*J
    % and J'*e as returned by fisherff_snn -- confirm.
    [jj, je] = fisherff_snn(X, net, dataLV);
    % Raise mu until a step lowers the training cost or mu exceeds mu_max.
    while (mu <= mu_max)
        dX = -(jj + ii*mu) \ je;
        X2 = X + dX;
        net2 = setx_snn(net, X2);
        perf2 = feval(net.costFcn.name, net2, dataLV);
        if (perf2 < result.tr.perf(epoch))
            % Accepted step: keep it and relax the damping.
            X = X2;
            perf = perf2;
            mu = mu * mu_dec;
            break
        end
        mu = mu * mu_inc;
    end
    if (mu > mu_max)
       % No acceptable step was found; weights stay unchanged.
       result.stop = 'Maximum MU reached, performance goal was not met.';
       perf = result.tr.perf(epoch);
    end
    net = setx_snn(net, X);

    % Save performance (index epoch+1 because index 1 is the initial state)
    epochPlus1 = epoch + 1;
    result.tr.perf(epochPlus1) = perf;
    result.tr.mu(epochPlus1) = mu;

    % Validation
    if doValidation
       vperf = feval(net.costFcn.name, net, dataVV);
       result.tr.vperf(epochPlus1) = vperf;
       if (vperf < result.vperf_min)
          % New best validation cost: remember these weights.
          result.vperf_min = vperf;
          result.best_X = X;
          result.num_fail = 0;
       elseif (vperf > result.vperf_min)
          result.num_fail = result.num_fail + 1;
       end
       if (vperf > result.tr.vperf(epoch))
          result.num_suc_fail = result.num_suc_fail + 1;
       else
          result.num_suc_fail = 0;
       end
    end
    if doTest
       result.tr.tperf(epochPlus1) = feval(net.costFcn.name, net, dataTV);
    end

    % Stopping criteria (earlier branches take precedence)
    result.runtime = etime(clock, startTime);
    if (result.tr.perf(epochPlus1) <= goal)
       result.stop = 'Performance goal met.';
    elseif (epoch == epochs)
       result.stop = 'Maximum epoch reached, performance goal was not met.';
    elseif (result.runtime > time)
       result.stop = 'Maximum time elapsed, performance goal was not met.';
    elseif doValidation && (result.num_fail > max_fail)
       result.stop = 'Validation stop; maximum total failures';
       net = setx_snn(net, result.best_X);
    elseif doValidation && (result.num_suc_fail > max_suc_fail)
       result.stop = 'Validation stop; maximum successive failures';
       net = setx_snn(net, result.best_X);
    elseif doValidation && (mu > mu_max)
       % result.stop was already set by the mu check above.
       net = setx_snn(net, result.best_X);
    end
    result.epoch = epoch;

    % Show progress (also on the final epoch, before the stop message)
    %#function progress_snn
    feval(net.trainFcn.progressFcn.name, net, data, result);
    if ~isempty(result.stop)
       stdout_snn('%s, %s\n', upper(net.trainFcn.name), result.stop);
       break;
    end
end

% Finish
result.tr = cliptr_snn(result.tr, epoch);

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?