
📄 glmtrain.m

📁 Main collection of pattern recognition tools
💻 M
function [net, options] = glmtrain(net, options, x, t)
%GLMTRAIN Specialised training of generalized linear model
%
%	Description
%	NET = GLMTRAIN(NET, OPTIONS, X, T) uses the iterative reweighted
%	least squares (IRLS) algorithm to set the weights in the generalized
%	linear model structure NET.  This is a more efficient alternative to
%	using GLMERR and GLMGRAD and a non-linear optimisation routine
%	through NETOPT. Note that for linear outputs, a single pass through
%	the algorithm is all that is required, since the error function is
%	quadratic in the weights.  The algorithm also handles scalar ALPHA
%	and BETA terms.  If you want to use more complicated priors, you
%	should use general-purpose non-linear optimisation algorithms.
%
%	For logistic and softmax outputs, general priors can be handled,
%	although this requires the pseudo-inverse of the Hessian, giving up
%	the better conditioning and some of the speed advantage of the normal
%	form equations.
%
%	The error function value at the final set of weights is returned in
%	OPTIONS(8). Each row of X corresponds to one input vector and each
%	row of T corresponds to one target vector.
%
%	The optional parameters have the following interpretations.
%
%	OPTIONS(1) is set to 1 to display error values during training. If
%	OPTIONS(1) is set to 0, then only warning messages are displayed.  If
%	OPTIONS(1) is -1, then nothing is displayed.
%
%	OPTIONS(2) is a measure of the precision required for the value of
%	the weights W at the solution.
%
%	OPTIONS(3) is a measure of the precision required of the objective
%	function at the solution.  Both this and the previous condition must
%	be satisfied for termination.
%
%	OPTIONS(5) is set to 1 if an approximation to the Hessian (which
%	assumes that all outputs are independent) is used for softmax
%	outputs. With the default value of 0 the exact Hessian (which is more
%	expensive to compute) is used.
%
%	OPTIONS(14) is the maximum number of iterations for the IRLS
%	algorithm;  default 100.
%
%	See also
%	GLM, GLMERR, GLMGRAD
%

%	Copyright (c) Ian T Nabney (1996-2001)

% Check arguments for consistency
errstring = consist(net, 'glm', x, t);
if ~isempty(errstring)
  error(errstring);
end

if (~options(14))
  options(14) = 100;
end

display = options(1);
% Do we need to test for termination?
test = (options(2) | options(3));

ndata = size(x, 1);
% Add a column of ones for the bias
inputs = [x ones(ndata, 1)];

% Linear outputs are a special case as they can be found in one step
if strcmp(net.outfn, 'linear')
  if ~isfield(net, 'alpha')
    % Solve for the weights and biases using left matrix divide
    temp = inputs\t;
  elseif isequal(size(net.alpha), [1 1])
    if isfield(net, 'beta')
      beta = net.beta;
    else
      beta = 1.0;
    end
    % Use normal form equation
    hessian = beta*(inputs'*inputs) + net.alpha*eye(net.nin+1);
    temp = pinv(hessian)*(beta*(inputs'*t));
  else
    error('Only scalar alpha allowed');
  end
  net.w1 = temp(1:net.nin, :);
  net.b1 = temp(net.nin+1, :);
  % Store error value in options vector
  options(8) = glmerr(net, x, t);
  return;
end

% Otherwise need to use iterative reweighted least squares
e = ones(1, net.nin+1);
for n = 1:options(14)

  switch net.outfn
    case 'logistic'
      if n == 1
        % Initialise model
        p = (t+0.5)/2;
        act = log(p./(1-p));
        wold = glmpak(net);
      end
      link_deriv = p.*(1-p);
      weights = sqrt(link_deriv); % sqrt of weights
      if (min(min(weights)) < eps)
        warning('ill-conditioned weights in glmtrain')
        return
      end
      z = act + (t-p)./link_deriv;
      if ~isfield(net, 'alpha')
        % Treat each output independently with relevant set of weights
        for j = 1:net.nout
          indep = inputs.*(weights(:,j)*e);
          dep = z(:,j).*weights(:,j);
          temp = indep\dep;
          net.w1(:,j) = temp(1:net.nin);
          net.b1(j) = temp(net.nin+1);
        end
      else
        % A prior is present: take a Newton step using the full Hessian
        gradient = glmgrad(net, x, t);
        Hessian = glmhess(net, x, t);
        deltaw = -gradient*pinv(Hessian);
        w = wold + deltaw;
        net = glmunpak(net, w);
      end
      [err, edata, eprior, p, act] = glmerr(net, x, t);
      if n == 1
        errold = err;
        wold = netpak(net);
      else
        w = netpak(net);
      end

    case 'softmax'
      if n == 1
        % Initialise model: ensure that row sum of p is one no matter
        % how many classes there are
        p = (t + (1/size(t, 2)))/2;
        act = log(p./(1-p));
      end
      if options(5) == 1 || n == 1
        link_deriv = p.*(1-p);
        weights = sqrt(link_deriv); % sqrt of weights
        if (min(min(weights)) < eps)
          warning('ill-conditioned weights in glmtrain')
          return
        end
        z = act + (t-p)./link_deriv;
        % Treat each output independently with relevant set of weights
        for j = 1:net.nout
          indep = inputs.*(weights(:,j)*e);
          dep = z(:,j).*weights(:,j);
          temp = indep\dep;
          net.w1(:,j) = temp(1:net.nin);
          net.b1(j) = temp(net.nin+1);
        end
        [err, edata, eprior, p, act] = glmerr(net, x, t);
        if n == 1
          errold = err;
          wold = netpak(net);
        else
          w = netpak(net);
        end
      else
        % Exact method of calculation after w first initialised
        % Start by working out Hessian
        Hessian = glmhess(net, x, t);
        gradient = glmgrad(net, x, t);
        % Now compute modification to weights
        deltaw = -gradient*pinv(Hessian);
        w = wold + deltaw;
        net = glmunpak(net, w);
        [err, edata, eprior, p] = glmerr(net, x, t);
      end

    otherwise
      error(['Unknown activation function ', net.outfn]);
  end

  % Display error values only when OPTIONS(1) is 1, as documented above
  if display > 0
    fprintf(1, 'Cycle %4d Error %11.6f\n', n, err)
  end

  % Test for termination
  % Terminate if error increases
  if err > errold
    errold = err;
    w = wold;
    options(8) = err;
    fprintf(1, 'Error has increased: terminating\n')
    return;
  end
  if test && n > 1
    if (max(abs(w - wold)) < options(2) && abs(err-errold) < options(3))
      options(8) = err;
      return;
    else
      errold = err;
      wold = w;
    end
  end
end

options(8) = err;
if (options(1) >= 0)
  disp(maxitmess);
end
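
The IRLS update used in the logistic branch can be hard to follow inside the full routine, so here is a minimal standalone sketch of the same idea for a single logistic output with no prior. It does not call any Netlab function. The toy data, the variable names (`w`, `r`, `sw`, `grad`) and the zero initial weights are assumptions made for illustration; `glmtrain` itself initialises from `p = (t+0.5)/2` instead.

```matlab
% Minimal IRLS sketch for one logistic output, no prior (illustrative only).
x = [-2.0; -1.5; -1.0; -0.5; 0.5; 1.0; 1.5; 2.0];  % toy inputs (assumed)
t = [0; 0; 1; 0; 1; 0; 1; 1];                      % toy targets, not separable
inputs = [x ones(size(x, 1), 1)];                  % add a column of ones for the bias
ncols = size(inputs, 2);
w = zeros(ncols, 1);                               % [weight; bias], assumed start

for n = 1:50
  act = inputs * w;                                % linear activations
  p = 1 ./ (1 + exp(-act));                        % logistic outputs
  r = p .* (1 - p);                                % derivative of the link
  sw = sqrt(r);                                    % sqrt of the IRLS weights
  z = act + (t - p) ./ r;                          % working (adjusted) response
  % Weighted least squares solve, written like the listing's indep\dep
  w_new = (inputs .* (sw * ones(1, ncols))) \ (z .* sw);
  step = max(abs(w_new - w));
  w = w_new;
  if step < 1e-10                                  % weights have stopped moving
    break;
  end
end

p = 1 ./ (1 + exp(-inputs * w));                   % outputs at the final weights
err = -sum(t .* log(p) + (1 - t) .* log(1 - p));   % cross-entropy error
grad = inputs' * (p - t);                          % gradient of the error
fprintf(1, 'Cycles %d  Error %11.6f\n', n, err);
```

Each pass solves a weighted least squares problem whose weights and targets depend on the current outputs, which is why the linear case needs only one pass while the logistic case has to iterate.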
