traincgp_snn.m
来自「神经网络的工具箱」· M 代码 · 共 213 行
M
213 行
function [net, result] = traincgp_snn(net, dataLV, dataVV, dataTV)
%TRAINCGP_SNN Conjugate gradient training (Polak - Ribiere).
%
% Syntax
%
%   [net, tr_info] = traincgp_snn(net, dataLV)
%   [net, tr_info] = traincgp_snn(net, dataLV, dataVV)
%   [net, tr_info] = traincgp_snn(net, dataLV, dataVV, dataTV)
%
%   trainfcn_struct = traincgp_snn('pdefaults')
%
% Description
%
%   TRAINCGP_SNN is a network training function that updates weight and
%   bias values according to a conjugate gradient algorithm as
%   proposed by Polak and Ribiere.
%
%   TRAINCGP_SNN takes
%     net    - a net_struct containing the initial network.
%     dataLV - training data set.
%     dataVV - validation data set (optional).
%     dataTV - test data set (optional).
%   and returns
%     net     - the trained network.
%     tr_info - a structure containing information about the
%               training process.
%
%   The field 'trainFcn' of 'net' must contain a structure with the
%   training parameters. The default parameters can be set with
%
%     net.trainFcn = traincgp_snn('pdefaults')
%
%   or can be set manually. The parameters are:
%
%     net.trainFcn.name         - 'traincgp_snn'
%     net.trainFcn.epochs       - maximum number of training epochs
%     net.trainFcn.goal         - training goal for cost function
%     net.trainFcn.max_fail     - maximum number of epochs in which cost
%                                 on validation set increases.
%     net.trainFcn.max_suc_fail - maximum number of succesive epochs
%                                 in which cost on validation set
%                                 increases.
%     net.trainFcn.tol          - tolerance parameter.
%     net.trainFcn.searchFcn    - 'srchbre_snn' (or another line search function)
%     net.trainFcn.progressFcn  - progressfcn_struct containing parameters
%                                 for showing progress of training
%     net.trainFcn.time         - maximum time for training.
%
% See also
%
%   NET_STRUCT_SNN, PROGRESSFCN_STRUCT_SNN

% FUNCTION INFO
% =============

if ischar(net)
  switch (net)
    case 'pnames',
      % Parameter names; kept in sync with the fields that the
      % 'pdefaults' case below actually creates.
      net = {'name'; 'epochs'; 'goal'; 'max_fail'; 'max_suc_fail'; ...
             'tol'; 'searchFcn'; 'progressFcn'; 'time'};
    case 'pdefaults',
      trainFcn.name = 'traincgp_snn';
      trainFcn.epochs = 1000;
      trainFcn.goal = 0;
      trainFcn.max_fail = 20;
      trainFcn.max_suc_fail = 10;
      trainFcn.tol = 1e-7;
      trainFcn.searchFcn = 'srchbre_snn';
      trainFcn.progressFcn = progress_snn('pdefaults');
      trainFcn.time = inf;
      net = trainFcn;
    otherwise,
      error('Unrecognized code.')
  end
  return
end

% CALCULATION
% ===========

% Optional data sets: validation requires 3 args, test requires 4.
if (nargin < 3)
  doValidation = 0;
else
  doValidation = 1;
end
if (nargin < 4)
  doTest = 0;
else
  doTest = 1;
end

% Constants
epochs = net.trainFcn.epochs;
goal = net.trainFcn.goal;
time = net.trainFcn.time;
max_fail = net.trainFcn.max_fail;
max_suc_fail = net.trainFcn.max_suc_fail;
tol = net.trainFcn.tol;

% Initialize
result.stop = '';
startTime = clock;
X = getx_snn(net);                                 % flattened weight/bias vector
result.tr = tr_struct_snn(epochs, 'perf', 'vperf', 'tperf');
%#function se_snn
%#function relerr_snn
%#function loglikelihood_snn
%#function crosslogistic_snn
%#function crossentropy_snn
result.tr.perf(1) = feval(net.costFcn.name, net, dataLV);
fX = result.tr.perf(1);
if (doValidation)
  result.best_X = X;
  result.tr.vperf(1) = feval(net.costFcn.name, net, dataVV);
  result.vperf_min = result.tr.vperf(1);
  result.num_fail = 0;      % total epochs where validation cost rose
  result.num_suc_fail = 0;  % successive epochs where validation cost rose
end
if (doTest)
  result.tr.tperf(1) = feval(net.costFcn.name, net, dataTV);
end

% Initial gradient and search direction (steepest descent for epoch 1).
xi = gradff_snn(net, dataLV);
g = -xi;
xi = g;
h = g;
gg = 0;

% Train
for epoch = 1:epochs

  % Train with Conjugate Gradient (Polak Ribiere)
  %#function srchbre_snn
  [X, fmin] = feval(net.trainFcn.searchFcn, ...
                    net.costFcn.name, X, xi, net, dataLV);
  % Relative tolerance test on the cost decrease (Numerical Recipes style).
  if (2*abs(fmin - fX) <= tol*(abs(fmin)+abs(fX)) + 1e-10)
    result.stop = 'Stop; minimum reached for costFcn, performance goal was not met.';
  else
    fX = fmin;
    % NOTE(review): this call passes X while the pre-loop call above uses
    % gradff_snn(net, dataLV) — confirm which signature gradff_snn expects.
    xi = gradff_snn(X, net, dataLV);
    gg = g'*g;
    dgg = (xi+g)'*xi;   % Polak-Ribiere numerator: (grad_new - grad_old)'*grad_new
    if (~gg)
      result.stop = 'Stop; gradient for costFcn = 0, performance goal was not met.';
    else
      gam = dgg/gg;
      g = -xi;
      xi = g + gam*h;   % new conjugate search direction
      h = xi;
    end
  end
  net = setx_snn(net, X);

  % Save performance
  epochPlus1 = epoch + 1;
  result.tr.perf(epochPlus1) = fmin;

  % Validation
  if (doValidation)
    %#function wcf_snn
    result.tr.vperf(epochPlus1) = feval(net.costFcn.name, net, dataVV);
    if (result.tr.vperf(epochPlus1) < result.vperf_min)
      % New best validation cost: remember the weights for early stopping.
      result.vperf_min = result.tr.vperf(epochPlus1);
      result.best_X = X;
      result.num_fail = 0;
    elseif (result.tr.vperf(epochPlus1) > result.vperf_min)
      result.num_fail = result.num_fail + 1;
    end
    if (result.tr.vperf(epochPlus1) > result.tr.vperf(epoch))
      result.num_suc_fail = result.num_suc_fail + 1;
    else
      result.num_suc_fail = 0;
    end
  end
  if (doTest)
    %#function wcf_snn
    result.tr.tperf(epochPlus1) = feval(net.costFcn.name, net, dataTV);
  end

  % Stopping criteria
  result.runtime = etime(clock, startTime);
  if (result.tr.perf(epochPlus1) <= goal)
    result.stop = 'Performance goal met.';
  elseif (epoch == epochs)
    result.stop = 'Maximum epoch reached, performance goal was not met.';
  elseif (result.runtime > time)
    result.stop = 'Maximum time elapsed, performance goal was not met.';
  elseif (doValidation) && (result.num_fail > max_fail)
    result.stop = 'Validation stop; maximum total failures';
    net = setx_snn(net, result.best_X);
  elseif (doValidation) && (result.num_suc_fail > max_suc_fail)
    result.stop = 'Validation stop; maximum successive failures';
    net = setx_snn(net, result.best_X);
  end
  result.epoch = epoch;
  if ~isempty(result.stop)
    %#function progress_snn
    feval(net.trainFcn.progressFcn.name, net, [], result);
    stdout_snn('%s, %s\n', upper(net.trainFcn.name), result.stop);
    break;
  end

  % Show progress
  %#function progress_snn
  feval(net.trainFcn.progressFcn.name, net, [], result);
end

% Finish
result.tr = cliptr_snn(result.tr, epoch);
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?