% traingd_snn.m
function [net, result] = traingd_snn(net, dataLV, dataVV, dataTV)
%TRAINGD_SNN Gradient Descent training.
%
%  Syntax
%
%   [net, tr_info] = traingd_snn(net, dataLV)
%   [net, tr_info] = traingd_snn(net, dataLV, dataVV)
%   [net, tr_info] = traingd_snn(net, dataLV, dataVV, dataTV)
%
%   trainfcn_struct = traingd_snn('pdefaults')
%
%  Description
%
%   TRAINGD_SNN is a network training function that adjusts weight and
%   bias values in the steepest descent direction.
%
%   TRAINGD_SNN takes
%     net    - a net_struct containing the initial network.
%     dataLV - training data set.
%     dataVV - validation data set (optional).
%     dataTV - test data set (optional).
%   and returns
%     net     - the trained network.
%     tr_info - a structure containing information about the
%               training process.
%
%   The field 'trainFcn' of 'net' must contain a structure with the
%   training parameters. The default parameters can be set with
%
%     net.trainFcn = traingd_snn('pdefaults')
%
%   or can be set manually. The parameters are:
%
%     net.trainFcn.name     - 'traingd_snn'
%     net.trainFcn.epochs   - maximum number of training epochs
%     net.trainFcn.goal     - training goal for cost function
%     net.trainFcn.lr       - learning rate
%     net.trainFcn.max_fail - maximum number of epochs in which cost
%                             on validation set increases.
%     net.trainFcn.max_suc_fail - maximum number of successive epochs
%                                 in which cost on validation set
%                                 increases.
%     net.trainFcn.progressFcn - progressfcn_struct containing parameters
%                                for showing progress of training
%     net.trainFcn.time     - maximum time for training.
%
%  See also
%
%   NET_STRUCT_SNN, PROGRESSFCN_STRUCT_SNN

% FUNCTION INFO
% =============
% First argument may be a mode string ('pnames'/'pdefaults') instead of a
% network struct.  ischar replaces the deprecated isstr.
if ischar(net)
  switch (net)
    case 'pnames',
          net = {'name';'epochs';'goal';'lr';'max_fail';'max_suc_fail';...
                 'show';'time'};
    case 'pdefaults',
          trainFcn.name = 'traingd_snn';
          trainFcn.epochs = 10000;
          trainFcn.goal = 0.0;
          trainFcn.lr = 0.001;
          trainFcn.max_fail = 20;
          trainFcn.max_suc_fail = 10;
          trainFcn.progressFcn = progress_snn('pdefaults');
          trainFcn.time = inf;
          net = trainFcn;
    otherwise,
          error('Unrecognized code.')
  end
  return
end

% CALCULATION
% ===========
% 'data' accumulates all supplied data sets; it is only passed on to the
% progress function.
data = dataLV;
if (nargin < 3)
   doValidation = 0;
else
   doValidation = 1;
   data = [data dataVV];
end
if (nargin < 4)
   doTest = 0;
else
   doTest = 1;
   data = [data dataTV];
end

% Constants
epochs = net.trainFcn.epochs;
goal = net.trainFcn.goal;
lr = net.trainFcn.lr;
max_fail = net.trainFcn.max_fail;
max_suc_fail = net.trainFcn.max_suc_fail;
time = net.trainFcn.time;

% Initialize
result.stop = '';
startTime = clock;
result.tr = tr_struct_snn(epochs, 'perf', 'vperf', 'tperf');
%#function wcf_snn
result.tr.perf(1) = feval(net.costFcn.name, net, dataLV);
X = getx_snn(net);
if (doValidation)
   % Early-stopping bookkeeping: best weights so far, lowest validation
   % cost, total and successive failure counters.
   result.best_X = X;
   result.tr.vperf(1) = feval(net.costFcn.name, net, dataVV);
   result.vperf_min = result.tr.vperf(1);
   result.num_fail = 0;
   result.num_suc_fail = 0;
end
if (doTest)
   result.tr.tperf(1) = feval(net.costFcn.name, net, dataTV);
end

% Train
for epoch = 1:epochs

    % Train with Gradient Descent
    gX = gradff_snn(net, dataLV);
    X = X - lr*gX;
    net = setx_snn(net, X);

    % Save performance (tr vectors are 1-based: entry 1 holds the
    % pre-training cost, entry epoch+1 the cost after epoch updates).
    epochPlus1 = epoch + 1;
    result.tr.perf(epochPlus1) = feval(net.costFcn.name, net, dataLV);

    % Validation
    if (doValidation)
       result.tr.vperf(epochPlus1) = feval(net.costFcn.name, net, dataVV);
       if (result.tr.vperf(epochPlus1) < result.vperf_min)
          % New best validation cost: remember these weights.
          result.vperf_min = result.tr.vperf(epochPlus1);
          result.best_X = X;
          result.num_fail = 0;
       elseif (result.tr.vperf(epochPlus1) > result.vperf_min)
          result.num_fail = result.num_fail + 1;
       end
       if (result.tr.vperf(epochPlus1) > result.tr.vperf(epoch))
          result.num_suc_fail = result.num_suc_fail + 1;
       else
          result.num_suc_fail = 0;
       end
    end
    if (doTest)
       result.tr.tperf(epochPlus1) = feval(net.costFcn.name, net, dataTV);
    end

    % Stopping criteria.
    % NOTE: the validation branches must use the short-circuit && — with
    % the element-wise & both operands are evaluated, so the reference to
    % result.num_fail errors when no validation set was supplied.
    result.runtime = etime(clock, startTime);
    if (result.tr.perf(epochPlus1) <= goal)
       result.stop = 'Performance goal met.';
    elseif (epoch == epochs)
       result.stop = 'Maximum epoch reached, performance goal was not met.';
    elseif (result.runtime > time)
       result.stop = 'Maximum time elapsed, performance goal was not met.';
    elseif doValidation && (result.num_fail > max_fail)
       result.stop = 'Validation stop; maximum total failures';
       net = setx_snn(net, result.best_X);
    elseif doValidation && (result.num_suc_fail > max_suc_fail)
       result.stop = 'Validation stop; maximum successive failures';
       net = setx_snn(net, result.best_X);
    end
    result.epoch = epoch;
    if ~isempty(result.stop)
       %#function progress_snn
       feval(net.trainFcn.progressFcn.name, net, data, result);
       stdout_snn('%s, %s\n', upper(net.trainFcn.name), result.stop);
       break;
    end

    % Show progress
    %#function progress_snn
    feval(net.trainFcn.progressFcn.name, net, data, result);
end

% Finish
result.tr = cliptr_snn(result.tr, epoch);
