traingd_snn.m
来自「神经网络的工具箱」· M 代码 · 共 185 行
M
185 行
function [net, result] = traingd_snn(net, dataLV, dataVV, dataTV)
%TRAINGD_SNN Gradient Descent training.
%
%  Syntax
%
%   [net, result] = traingd_snn(net, dataLV)
%   [net, result] = traingd_snn(net, dataLV, dataVV)
%   [net, result] = traingd_snn(net, dataLV, dataVV, dataTV)
%
%   trainfcn_struct = traingd_snn('pdefaults')
%   param_names     = traingd_snn('pnames')
%
%  Description
%
%   TRAINGD_SNN is a network training function that adjusts weight and
%   bias values in the steepest descent direction: every epoch the
%   parameter vector X (from GETX_SNN) is replaced by X - lr*gX, where gX
%   is the gradient returned by GRADFF_SNN on the training set, and the
%   result is written back with SETX_SNN.
%
%   TRAINGD_SNN takes
%     net    - a net_struct containing the initial network.
%     dataLV - training data set.
%     dataVV - validation data set (optional; [] means no validation).
%     dataTV - test data set (optional).
%   and returns
%     net    - the trained network.
%     result - a structure describing the training run:
%       stop    - string giving the reason training stopped.
%       tr      - training record created by TR_STRUCT_SNN and passed through
%                 CLIPTR_SNN with the last epoch. tr.perf, tr.vperf and
%                 tr.tperf hold the training/validation/test cost; index 1
%                 is the cost before training, index k+1 the cost after
%                 epoch k.
%       epoch   - number of the last epoch performed (0 if none ran).
%       runtime - elapsed time in seconds (ETIME) at the last epoch.
%     When validation data is used, result also contains best_X (parameter
%     vector with the lowest validation cost), vperf_min, num_fail and
%     num_suc_fail.
%
%   The field 'trainFcn' of 'net' must contain a structure with the
%   training parameters. The default parameters can be set with
%
%     net.trainFcn = traingd_snn('pdefaults')
%
%   or can be set manually. The parameters are:
%
%     net.trainFcn.name         - 'traingd_snn'
%     net.trainFcn.epochs       - maximum number of training epochs
%     net.trainFcn.goal         - training stops once the training cost is
%                                 <= goal
%     net.trainFcn.lr           - learning rate (gradient step size)
%     net.trainFcn.max_fail     - validation stop when more than max_fail
%                                 epochs since the last new validation
%                                 minimum had a validation cost above that
%                                 minimum
%     net.trainFcn.max_suc_fail - validation stop when the validation cost
%                                 rose over the previous epoch for more than
%                                 max_suc_fail successive epochs
%     net.trainFcn.progressFcn  - progressfcn_struct; its .name is called as
%                                 feval(name, net, data, result) after each
%                                 epoch
%     net.trainFcn.time         - maximum training time in seconds
%
%   Stopping criteria are checked after every epoch, in this order:
%   goal met, maximum epochs, maximum time, validation total failures,
%   validation successive failures. Only the two validation stops restore
%   the parameters with the lowest validation cost (result.best_X); the
%   other stops keep the parameters of the last epoch.
%   NOTE(review): at least one epoch is always trained, even if the initial
%   training cost already meets the goal.
%
%  See also
%
%   NET_STRUCT_SNN, PROGRESSFCN_STRUCT_SNN

% FUNCTION INFO
% =============
if ischar(net)
  switch (net)
    case 'pnames',
      % Names of the fields produced by 'pdefaults'.
      net = {'name';'epochs';'goal';'lr';'max_fail';'max_suc_fail';...
             'progressFcn';'time'};
    case 'pdefaults',
      trainFcn.name = 'traingd_snn';
      trainFcn.epochs = 10000;
      trainFcn.goal = 0.0;
      trainFcn.lr = 0.001;
      trainFcn.max_fail = 20;
      trainFcn.max_suc_fail = 10;
      trainFcn.progressFcn = progress_snn('pdefaults');
      trainFcn.time = inf;
      net = trainFcn;
    otherwise,
      error('Unrecognized code.')
  end
  return
end

% CALCULATION
% ===========

% An omitted or empty optional data set disables that evaluation.
doValidation = (nargin >= 3) && ~isempty(dataVV);
doTest = (nargin >= 4) && ~isempty(dataTV);

% All supplied data sets, concatenated, are handed to the progress function.
data = dataLV;
if doValidation
  data = [data dataVV];
end
if doTest
  data = [data dataTV];
end

% Constants
epochs = net.trainFcn.epochs;
goal = net.trainFcn.goal;
lr = net.trainFcn.lr;
max_fail = net.trainFcn.max_fail;
max_suc_fail = net.trainFcn.max_suc_fail;
time = net.trainFcn.time;

% Initialize
result.stop = '';
% Stays 0 if epochs < 1 and the training loop never runs.
result.epoch = 0;
startTime = clock;
result.tr = tr_struct_snn(epochs, 'perf', 'vperf', 'tperf');
%#function wcf_snn
result.tr.perf(1) = feval(net.costFcn.name, net, dataLV);
X = getx_snn(net);
if doValidation
  result.best_X = X;
  result.tr.vperf(1) = feval(net.costFcn.name, net, dataVV);
  result.vperf_min = result.tr.vperf(1);
  result.num_fail = 0;
  result.num_suc_fail = 0;
end
if doTest
  result.tr.tperf(1) = feval(net.costFcn.name, net, dataTV);
end

% Train
for epoch = 1:epochs
  % Train with Gradient Descent
  gX = gradff_snn(net, dataLV);
  X = X - lr*gX;
  net = setx_snn(net, X);

  % Save performance (index 1 holds the pre-training cost)
  epochPlus1 = epoch + 1;
  result.tr.perf(epochPlus1) = feval(net.costFcn.name, net, dataLV);

  % Validation
  if doValidation
    vperf = feval(net.costFcn.name, net, dataVV);
    result.tr.vperf(epochPlus1) = vperf;
    if (vperf < result.vperf_min)
      % New best validation cost: remember these parameters.
      result.vperf_min = vperf;
      result.best_X = X;
      result.num_fail = 0;
    elseif (vperf > result.vperf_min)
      result.num_fail = result.num_fail + 1;
    end
    % Count consecutive epochs where validation cost went up.
    if (vperf > result.tr.vperf(epoch))
      result.num_suc_fail = result.num_suc_fail + 1;
    else
      result.num_suc_fail = 0;
    end
  end
  if doTest
    result.tr.tperf(epochPlus1) = feval(net.costFcn.name, net, dataTV);
  end

  % Stopping criteria
  result.runtime = etime(clock, startTime);
  if (result.tr.perf(epochPlus1) <= goal)
    result.stop = 'Performance goal met.';
  elseif (epoch == epochs)
    result.stop = 'Maximum epoch reached, performance goal was not met.';
  elseif (result.runtime > time)
    result.stop = 'Maximum time elapsed, performance goal was not met.';
  elseif doValidation && (result.num_fail > max_fail)
    result.stop = 'Validation stop; maximum total failures';
    net = setx_snn(net, result.best_X);
  elseif doValidation && (result.num_suc_fail > max_suc_fail)
    result.stop = 'Validation stop; maximum successive failures';
    net = setx_snn(net, result.best_X);
  end
  result.epoch = epoch;

  % Show progress (also for the final epoch, before reporting the stop reason)
  %#function progress_snn
  feval(net.trainFcn.progressFcn.name, net, data, result);
  if ~isempty(result.stop)
    stdout_snn('%s, %s\n', upper(net.trainFcn.name), result.stop);
    break;
  end
end

% Finish
% result.epoch equals the loop's last epoch, and is 0 (not empty) when the
% loop never ran.
result.tr = cliptr_snn(result.tr, result.epoch);
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?