function [net,tr,Y,E,Pf,Af]=train(net,P,T,Pi,Ai,VV,TV)
%TRAIN Train a neural network.
%
%  Syntax
%
%    [net,tr,Y,E,Pf,Af] = train(NET,P,T,Pi,Ai,VV,TV)
%
%  Description
%
%    TRAIN trains a network NET according to NET.trainFcn and
%    NET.trainParam.
%
%    TRAIN(NET,P,T,Pi,Ai,VV,TV) takes,
%      NET - Network.
%      P   - Network inputs.
%      T   - Network targets, default = zeros.
%      Pi  - Initial input delay conditions, default = zeros.
%      Ai  - Initial layer delay conditions, default = zeros.
%      VV  - Structure of validation vectors, default = [].
%      TV  - Structure of test vectors, default = [].
%    and returns,
%      NET - New network.
%      TR  - Training record (epoch and perf).
%      Y   - Network outputs.
%      E   - Network errors.
%      Pf  - Final input delay conditions.
%      Af  - Final layer delay conditions.
%
%    Note that T is optional and need only be used for networks
%    that require targets.  Pi and Ai are also optional and need
%    only be used for networks that have input or layer delays.
%    Optional arguments VV and TV are described below.
%
%    TRAIN's signal arguments can have two formats: cell array or matrix.
%    
%    The cell array format is easiest to describe.  It is most
%    convenient for networks with multiple inputs and outputs,
%    and allows sequences of inputs to be presented:
%      P  - NixTS cell array, each element P{i,ts} is an RixQ matrix.
%      T  - NtxTS cell array, each element T{i,ts} is a VixQ matrix.
%      Pi - NixID cell array, each element Pi{i,k} is an RixQ matrix.
%      Ai - NlxLD cell array, each element Ai{i,k} is an SixQ matrix.
%      Y  - NoxTS cell array, each element Y{i,ts} is a UixQ matrix.
%      E  - NtxTS cell array, each element E{i,ts} is a VixQ matrix.
%      Pf - NixID cell array, each element Pf{i,k} is an RixQ matrix.
%      Af - NlxLD cell array, each element Af{i,k} is an SixQ matrix.
%    Where:
%      Ni = net.numInputs
%      Nl = net.numLayers
%      No = net.numOutputs
%      Nt = net.numTargets
%      ID = net.numInputDelays
%      LD = net.numLayerDelays
%      TS = number of time steps
%      Q  = batch size
%      Ri = net.inputs{i}.size
%      Si = net.layers{i}.size
%      Ui = net.outputs{i}.size
%      Vi = net.targets{i}.size
%
%    The columns of Pi, Pf, Ai, and Af are ordered from the oldest delay
%    condition to the most recent:
%      Pi{i,k} = input i at time ts=k-ID.
%      Pf{i,k} = input i at time ts=TS+k-ID.
%      Ai{i,k} = layer output i at time ts=k-LD.
%      Af{i,k} = layer output i at time ts=TS+k-LD.
%
%    The matrix format can be used if only one time step is to be
%    simulated (TS = 1).  It is convenient for networks with
%    only one input and output, but can be used with networks that
%    have more.
%
%    Each matrix argument is found by storing the elements of
%    the corresponding cell array argument into a single matrix:
%      P  - (sum of Ri)xQ matrix.
%      T  - (sum of Vi)xQ matrix.
%      Pi - (sum of Ri)x(ID*Q) matrix.
%      Ai - (sum of Si)x(LD*Q) matrix.
%      Y  - (sum of Ui)xQ matrix.
%      E  - (sum of Vi)xQ matrix.
%      Pf - (sum of Ri)x(ID*Q) matrix.
%      Af - (sum of Si)x(LD*Q) matrix.
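%
%    As an illustrative sketch (hypothetical values, not part of the
%    original help text), the same data for a network with a single
%    input of size R1 = 2 could be laid out in either format:
%
%      Pmat  = [1 2 3; 4 5 6];        % matrix format: 2x3, Q = 3, TS = 1
%      Pcell = {[1 2 3; 4 5 6]};      % cell format: 1x1 cell, Q = 3, TS = 1
%      Pseq  = {[1;4] [2;5] [3;6]};   % cell format only: TS = 3, Q = 1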
%
%    If VV and TV are supplied they should be an empty matrix [] or
%    a structure with the following fields:
%      VV.P,  TV.P  - Validation/test inputs.
%      VV.T,  TV.T  - Validation/test targets, default = zeros.
%      VV.Pi, TV.Pi - Validation/test initial input delay conditions,
%                     default = zeros.
%      VV.Ai, TV.Ai - Validation/test initial layer delay conditions,
%                     default = zeros.
%    The validation vectors are used to stop training early if further
%    training on the primary vectors will hurt generalization to the
%    validation vectors.  Test vector performance can be used to measure
%    how well the network generalizes beyond primary and validation vectors.
%    If VV.T, VV.Pi, or VV.Ai are set to an empty matrix or cell array,
%    default values will be used. The same is true for TV.T, TV.Pi, TV.Ai.
%
%    Not all training functions support validation and test vectors.
%    Those that do not support them ignore the VV and TV arguments.
%
%  Examples
%
%    Here input P and targets T define a simple function which
%    we can plot:
%
%      p = [0 1 2 3 4 5 6 7 8];
%      t = [0 0.84 0.91 0.14 -0.77 -0.96 -0.28 0.66 0.99];
%      plot(p,t,'o')
%
%    Here NEWFF is used to create a two-layer feed-forward network.
%    The network will have an input (ranging from 0 to 8), followed
%    by a layer of 10 TANSIG neurons, followed by a layer with 1
%    PURELIN neuron.  TRAINLM backpropagation is used.  The network
%    is also simulated.
%
%      net = newff([0 8],[10 1],{'tansig' 'purelin'},'trainlm');
%      y1 = sim(net,p)
%      plot(p,t,'o',p,y1,'x')
%
%    Here the network is trained for up to 50 epochs to an error goal of
%    0.01, and then resimulated.
%
%      net.trainParam.epochs = 50;
%      net.trainParam.goal = 0.01;
%      net = train(net,p,t);
%      y2 = sim(net,p)
%      plot(p,t,'o',p,y1,'x',p,y2,'*')
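%
%    The call below is a minimal sketch (not part of the original help)
%    of passing validation vectors as described above.  The validation
%    points and the assumption that the targets follow sin(p) are
%    hypothetical and chosen only for illustration.
%
%      VV.P = [0.5 2.5 4.5 6.5];        % hypothetical validation inputs
%      VV.T = sin(VV.P);                % assumed underlying function
%      net = train(net,p,t,[],[],VV);   % Pi and Ai left empty, TV omitted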
%      
%  Algorithm
%
%    TRAIN calls the function indicated by NET.trainFcn, using the
%    training parameter values indicated by NET.trainParam.
%
%    Typically one epoch of training is defined as a single presentation
%    of all input vectors to the network.  The network is then updated
%    according to the results of all those presentations.
%
%    Training continues until the maximum number of epochs is reached,
%    the performance goal is met, or any other stopping condition of the
%    function NET.trainFcn occurs.
%
%    Some training functions depart from this norm by presenting only
%    one input vector (or sequence) each epoch. An input vector (or sequence)
%    is chosen randomly each epoch from concurrent input vectors (or sequences).
%    NEWC and NEWSOM return networks that use TRAINR, a training function
%    that does this.
%
%  See also INIT, REVERT, SIM, ADAPT

%  Mark Beale, 11-31-97
%  Copyright 1992-2004 The MathWorks, Inc.
%  $Revision: 1.11.4.2 $ $Date: 2004/08/17 21:42:13 $

% CHECK AND FORMAT ARGUMENTS
% --------------------------

if nargin < 2
  error('Not enough input arguments.');
end
if ~isa(net,'network')
  error('First argument is not a network.');
end
if net.hint.zeroDelay
  error('Network contains a zero-delay loop.')
end
switch nargin
  case 2
    [err,P,T,Pi,Ai,Q,TS,matrixForm] = trainargs(net,P);
    VV = [];
    TV = [];
  case 3
    [err,P,T,Pi,Ai,Q,TS,matrixForm] = trainargs(net,P,T);
    VV = [];
    TV = [];
  case 4
    [err,P,T,Pi,Ai,Q,TS,matrixForm] = trainargs(net,P,T,Pi);
    VV = [];
    TV = [];
  case 5
    [err,P,T,Pi,Ai,Q,TS,matrixForm] = trainargs(net,P,T,Pi,Ai);
    VV = [];
    TV = [];
  case 6
    [err,P,T,Pi,Ai,Q,TS,matrixForm] = trainargs(net,P,T,Pi,Ai);
    % Report primary-vector errors now, before err is reused below
    if length(err), error(err), end
    if isempty(VV)
      VV = [];
    else
      if ~hasfield(VV,'P'), error('VV.P must be defined or VV must be [].'), end
      if ~hasfield(VV,'T'), VV.T = []; end
      if ~hasfield(VV,'Pi'), VV.Pi = []; end
      if ~hasfield(VV,'Ai'), VV.Ai = []; end
      % matrixForm is not requested here, so the output format keeps
      % following the primary input P
      [err,VV.P,VV.T,VV.Pi,VV.Ai,VV.Q,VV.TS] = ...
        trainargs(net,VV.P,VV.T,VV.Pi,VV.Ai);
      if length(err)
        disp('??? Error with validation vectors using ==> train')
        error(err)
      end
    end
    TV = [];
  case 7
    [err,P,T,Pi,Ai,Q,TS,matrixForm] = trainargs(net,P,T,Pi,Ai);
    % Report primary-vector errors now, before err is reused below
    if length(err), error(err), end
    if isempty(VV)
      VV = [];
    else
      if ~hasfield(VV,'P'), error('VV.P must be defined or VV must be [].'), end
      if ~hasfield(VV,'T'), VV.T = []; end
      if ~hasfield(VV,'Pi'), VV.Pi = []; end
      if ~hasfield(VV,'Ai'), VV.Ai = []; end
      % matrixForm is not requested here, so the output format keeps
      % following the primary input P
      [err,VV.P,VV.T,VV.Pi,VV.Ai,VV.Q,VV.TS] = ...
        trainargs(net,VV.P,VV.T,VV.Pi,VV.Ai);
      if length(err)
        disp('??? Error with validation vectors using ==> train')
        error(err)
      end
    end
    if isempty(TV)
      TV = [];
    else
      if ~hasfield(TV,'P'), error('TV.P must be defined or TV must be [].'), end
      if ~hasfield(TV,'T'), TV.T = []; end
      if ~hasfield(TV,'Pi'), TV.Pi = []; end
      if ~hasfield(TV,'Ai'), TV.Ai = []; end
      [err,TV.P,TV.T,TV.Pi,TV.Ai,TV.Q,TV.TS] = ...
        trainargs(net,TV.P,TV.T,TV.Pi,TV.Ai);
      if length(err)
        disp('??? Error with test vectors using ==> train')
        error(err)
      end
    end
end
if length(err), error(err), end

% TRAIN NETWORK
% -------------

% Training function
trainFcn = net.trainFcn;
if ~length(trainFcn)
  error('Network "trainFcn" is undefined.')
end

% Delayed inputs, layer targets
Pc = [Pi P];
Pd = calcpd(net,TS,Q,Pc);
Tl = expandrows(T,net.hint.targetInd,net.numLayers);

% Validation and Test vectors
doValidation = ~isempty(VV);
doTest = ~isempty(TV);
if (doValidation)
  VV.Pd = calcpd(net,VV.TS,VV.Q,[VV.Pi VV.P]);
  VV.Tl = expandrows(VV.T,net.hint.targetInd,net.numLayers);
end
if (doTest)
  TV.Pd = calcpd(net,TV.TS,TV.Q,[TV.Pi TV.P]);
  TV.Tl = expandrows(TV.T,net.hint.targetInd,net.numLayers);
end

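% Train network
net = struct(net);
% Without layer delays the time steps are independent of one another, so a
% sequence of TS steps with Q vectors each can be trained as a single batch
% of Q*TS concurrent vectors.  TRAINS is excluded because it presents the
% sequence one time step at a time.
flatten_time = (net.numLayerDelays == 0) & (TS > 1) & ...
  (~strcmp(trainFcn,'trains'));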
if flatten_time
  Pd = seq2con(Pd);
  Tl = seq2con(Tl);
  if (doValidation)
    VV.Pd = seq2con(VV.Pd);
    VV.Tl = seq2con(VV.Tl);
    VV.Q = VV.Q*VV.TS;
    VV.TS = 1;
  end
  if (doTest)
    TV.Pd = seq2con(TV.Pd);
    TV.Tl = seq2con(TV.Tl);
    TV.Q = TV.Q*TV.TS;
    TV.TS = 1;
  end
  [net,tr,Ac,El] = feval(trainFcn,net,Pd,Tl,Ai,Q*TS,1,VV,TV);
  Ac = con2seq(Ac,TS);
  El = con2seq(El,TS);
else
  [net,tr,Ac,El] = feval(trainFcn,net,Pd,Tl,Ai,Q,TS,VV,TV);
end
net = class(net,'network');

% Network outputs, errors, final inputs
Y = Ac(net.hint.outputInd,net.numLayerDelays+[1:TS]);
E = El(net.hint.targetInd,:);
Pf = Pc(:,TS+[1:net.numInputDelays]);
Af = Ac(:,TS+[1:net.numLayerDelays]);

% FORMAT OUTPUT ARGUMENTS
% -----------------------

if (matrixForm)
  Y = cell2mat(Y);
  E = cell2mat(E);
  Pf = cell2mat(Pf);
  Af = cell2mat(Af);
end

% CLEAN UP TEMPORARY FILES
% ------------------------
nntempdir=fullfile(tempdir,'matlab_nnet');
if exist(nntempdir,'dir')
    rmpath(nntempdir);
    tf=dir(nntempdir);
    for i=1:length(tf)
        if (~tf(i).isdir)
          delete(fullfile(nntempdir,tf(i).name));
        end
    end
    rmdir(nntempdir)
end

% ============================================================
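function [s2] = expandrows(s,ind,rows)
% Copies the rows of cell array S into rows IND of a new cell array
% with ROWS rows.  All other rows are left as empty cells.

s2 = cell(rows,size(s,2));
s2(ind,:) = s;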
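% Illustrative sketch of the call (hypothetical values, not from the
% original file), for a 3-layer network whose only target is on layer 2:
%   s  = {[1 2]};             % 1x1 cell: targets for the single target layer
%   s2 = expandrows(s,2,3);   % 3x1 cell: s2{2} is [1 2], s2{1} and s2{3} are []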

% ============================================================
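function [err,P,T,Pi,Ai,Q,TS,matrixForm] = trainargs(net,P,T,Pi,Ai)
% Checks the signal arguments and converts them to cell array form.
% ERR is an error message, or '' if the arguments are valid.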

err = '';

% Check signals: all matrices or all cell arrays
% Change empty matrices/arrays to proper form
switch class(P)
  case 'cell', matrixForm = 0; name = 'cell array'; default = {};
  case {'double','logical'}, matrixForm = 1; name = 'matrix'; default = [];
  otherwise, err = 'P must be a matrix or cell array.'; return
end
if (nargin < 3)
  T = default;
elseif ((isa(T,'double')|isa(T,'logical')) ~= matrixForm)
  if isempty(T)
    T = default;
  else
    err = ['P is a ' name ', so T must be a ' name ' too.'];
    return
  end
end
if (nargin < 4)
  Pi = default;
elseif ((isa(Pi,'double')|isa(Pi,'logical')) ~= matrixForm)
  if isempty(Pi)
    Pi = default;
  else
    err = ['P is a ' name ', so Pi must be a ' name ' too.'];
    return
  end
end
if (nargin < 5)
  Ai = default;
elseif ((isa(Ai,'double')|isa(Ai,'logical')) ~= matrixForm)
  if isempty(Ai)
    Ai = default;
  else
    err = ['P is a ' name ', so Ai must be a ' name ' too.'];
    return
  end
end

% Check Matrices, Matrices -> Cell Arrays
if (matrixForm)
  [R,Q] = size(P);
  TS = 1;
  [err,P] = formatp(net,P,Q); if length(err), return, end
  [err,T] = formatt(net,T,Q,TS); if length(err), return, end
  [err,Pi] = formatpi(net,Pi,Q); if length(err), return, end
  [err,Ai] = formatai(net,Ai,Q); if length(err), return, end
  
% Check Cell Arrays
else
  [R,TS] = size(P);
  [R1,Q] = size(P{1,1});
  [err] = checkp(net,P,Q,TS); if length(err), return, end
  [err,T] = checkt(net,T,Q,TS); if length(err), return, end
  [err,Pi] = checkpi(net,Pi,Q); if length(err), return, end
  [err,Ai] = checkai(net,Ai,Q); if length(err), return, end
end

% ============================================================
