traincgp_snn.m

来自「神经网络的工具箱」· M 代码 · 共 213 行

M
213
字号
function [net, result] = traincgp_snn(net, dataLV, dataVV, dataTV)
%TRAINCGP_SNN Conjugate gradient training (Polak - Ribiere).
%
%  Syntax
%
%   [net, tr_info] = traincgp_snn(net, dataLV)
%   [net, tr_info] = traincgp_snn(net, dataLV, dataVV)
%   [net, tr_info] = traincgp_snn(net, dataLV, dataVV, dataTV)
%
%   trainfcn_struct = traincgp_snn('pdefaults')
%   param_names     = traincgp_snn('pnames')
%
%  Description
%
%   TRAINCGP_SNN is a network training function that updates weight and
%   bias values according to a conjugate gradient algorithm as
%   proposed by Polak and Ribiere.
%
%   TRAINCGP_SNN takes
%     net    - a net_struct containing the initial network.
%     dataLV - training data set.
%     dataVV - validation data set (optional).
%     dataTV - test data set (optional).
%   and returns
%     net     - the trained network.
%     tr_info - a structure containing information about the
%               training process, with the fields
%                 stop    - text describing why training stopped
%                 tr      - training record (perf, vperf, tperf per epoch)
%                 runtime - elapsed training time in seconds
%                 epoch   - index of the last epoch that was run
%               and, only when a validation set is supplied,
%                 best_X, vperf_min, num_fail, num_suc_fail.
%
%   The field 'trainFcn' of 'net' must contain a structure with the
%   training parameters. The default parameters can be set with
%
%     net.trainFcn = traincgp_snn('pdefaults')
%
%   or can be set manually. The parameters are:
%
%     net.trainFcn.name     - 'traincgp_snn'
%     net.trainFcn.epochs   - maximum number of training epochs
%     net.trainFcn.goal     - training goal for cost function
%     net.trainFcn.max_fail - maximum number of epochs in which cost
%                             on validation set increases.
%     net.trainFcn.max_suc_fail - maximum number of successive epochs
%                                 in which cost on validation set
%                                 increases.
%     net.trainFcn.tol      - tolerance parameter.
%     net.trainFcn.searchFcn - 'srchbre_snn' (or another line search function)
%     net.trainFcn.progressFcn - progressfcn_struct containing parameters
%                                for showing progress of training
%     net.trainFcn.time     - maximum time for training.
%
%  See also
%
%   NET_STRUCT_SNN, PROGRESSFCN_STRUCT_SNN
%

% FUNCTION INFO
% =============
% When the first argument is a string, the function only reports
% information about itself and returns without training.
if ischar(net)
  switch (net)
    case 'pnames',
      % Names of the fields set by 'pdefaults' below, in the same order.
      net = {'name'; 'epochs'; 'goal'; 'max_fail'; 'max_suc_fail'; ...
             'tol'; 'searchFcn'; 'progressFcn'; 'time'};
    case 'pdefaults',
      trainFcn.name = 'traincgp_snn';
      trainFcn.epochs = 1000;
      trainFcn.goal = 0;
      trainFcn.max_fail = 20;
      trainFcn.max_suc_fail = 10;
      trainFcn.tol = 1e-7;
      trainFcn.searchFcn = 'srchbre_snn';
      trainFcn.progressFcn = progress_snn('pdefaults');
      trainFcn.time = inf;
      net = trainFcn;
    otherwise,
      error('Unrecognized code.')
  end
  return
end

% CALCULATION
% ===========
% Validation / test monitoring is enabled purely by the number of
% arguments supplied.
if (nargin < 3)
  doValidation = 0;
else
  doValidation = 1;
end
if (nargin < 4)
  doTest = 0;
else
  doTest = 1;
end

% Constants
epochs = net.trainFcn.epochs;
goal = net.trainFcn.goal;
time = net.trainFcn.time;
max_fail = net.trainFcn.max_fail;
max_suc_fail = net.trainFcn.max_suc_fail;
tol = net.trainFcn.tol;

% Initialize
result.stop = '';
startTime = clock;
X = getx_snn(net);
result.tr = tr_struct_snn(epochs, 'perf', 'vperf', 'tperf');

% The pragmas below name the cost functions that may be reached
% through feval(net.costFcn.name, ...).
%#function se_snn
%#function relerr_snn
%#function loglikelihood_snn
%#function crosslogistic_snn
%#function crossentropy_snn
result.tr.perf(1) = feval(net.costFcn.name, net, dataLV);
fX = result.tr.perf(1);
if (doValidation)
  result.best_X = X;
  result.tr.vperf(1) = feval(net.costFcn.name, net, dataVV);
  result.vperf_min = result.tr.vperf(1);
  result.num_fail = 0;
  result.num_suc_fail = 0;
end
if (doTest)
  result.tr.tperf(1) = feval(net.costFcn.name, net, dataTV);
end

% Initial search direction: the negated output of gradff_snn
% (presumably the gradient of the cost w.r.t. the weights - confirm).
xi = gradff_snn(net, dataLV);
g = -xi;
xi = g;
h = g;
gg = 0;

% Train
for epoch = 1:epochs

  % Train with Conjugate Gradient (Polak Ribiere)
  % Line search from X along direction xi; returns the new point and
  % the cost found there.
  %#function srchbre_snn
  [X, fmin] = feval(net.trainFcn.searchFcn, ...
                    net.costFcn.name, X, xi, net, dataLV);

  if (2*abs(fmin - fX) <= tol*(abs(fmin)+abs(fX)) + 1e-10)
    % Relative change of the cost is below the tolerance.
    result.stop = 'Stop; minimum reached for costFcn, performance goal was not met.';
  else
    fX = fmin;
    % NOTE(review): this call passes X as an extra leading argument,
    % unlike the initial gradff_snn(net, dataLV) call above; 'net' has
    % not yet been updated with X at this point. Confirm gradff_snn
    % supports both signatures.
    xi = gradff_snn(X, net, dataLV);
    gg = g'*g;
    dgg = (xi+g)'*xi;       % Polak-Ribiere numerator (g holds the negated previous xi)
    if (~gg)
      result.stop = 'Stop; gradient for costFcn = 0, performance goal was not met.';
    else
      gam = dgg/gg;
      g = -xi;
      xi = g + gam*h;       % new conjugate search direction
      h = xi;
    end
  end
  net = setx_snn(net, X);

  % Save performance
  epochPlus1 = epoch + 1;   % slot 1 of the record holds the initial values
  result.tr.perf(epochPlus1) = fmin;

  % Validation
  if (doValidation)
    %#function wcf_snn
    result.tr.vperf(epochPlus1) = feval(net.costFcn.name, net, dataVV);
    % Total failures: count epochs that are worse than the best so far;
    % the counter is reset whenever a new best is found.
    if (result.tr.vperf(epochPlus1) < result.vperf_min)
      result.vperf_min = result.tr.vperf(epochPlus1);
      result.best_X = X;
      result.num_fail = 0;
    elseif (result.tr.vperf(epochPlus1) > result.vperf_min)
      result.num_fail = result.num_fail + 1;
    end
    % Successive failures: count consecutive epochs that are worse than
    % the immediately preceding epoch.
    if (result.tr.vperf(epochPlus1) > result.tr.vperf(epoch))
      result.num_suc_fail = result.num_suc_fail + 1;
    else
      result.num_suc_fail = 0;
    end
  end
  if (doTest)
    %#function wcf_snn
    result.tr.tperf(epochPlus1) = feval(net.costFcn.name, net, dataTV);
  end

  % Stopping criteria
  % A stop message set by the conjugate gradient step above is kept
  % unless one of the criteria below replaces it.
  result.runtime = etime(clock, startTime);
  if (result.tr.perf(epochPlus1) <= goal)
    result.stop = 'Performance goal met.';
  elseif (epoch == epochs)
    result.stop = 'Maximum epoch reached, performance goal was not met.';
  elseif (result.runtime > time)
    result.stop = 'Maximum time elapsed, performance goal was not met.';
  % '&&' guarantees that the validation counters, which exist only when
  % a validation set was supplied, are never read otherwise.
  elseif (doValidation) && (result.num_fail > max_fail)
    result.stop = 'Validation stop; maximum total failures';
    net = setx_snn(net, result.best_X);
  elseif (doValidation) && (result.num_suc_fail > max_suc_fail)
    result.stop = 'Validation stop; maximum successive failures';
    net = setx_snn(net, result.best_X);
  end

  result.epoch = epoch;
  if ~isempty(result.stop)
    % Report the final state once, print the stop reason and leave the loop.
    %#function progress_snn
    feval(net.trainFcn.progressFcn.name, net, [], result);
    stdout_snn('%s, %s\n', upper(net.trainFcn.name), result.stop);
    break;
  end

  % Show progress
  %#function progress_snn
  feval(net.trainFcn.progressFcn.name, net, [], result);
end

% Finish
% Clip the training record to the last epoch that was run.
result.tr = cliptr_snn(result.tr, epoch);

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?