function lssvm_demo
%
% LSSVM_DEMO - demonstrate l-o-o model selection for the LS-SVM
%
% LSSVM_DEMO demonstrates the use of leave-one-out cross-validation to
% select the hyper-parameters (i.e. the regularisation and kernel
% parameters) for a least-squares support vector machine.
%
%
% File : lssvm_demo.m
%
% Date : Saturday 6th January 2007.
%
% Author : Dr Gavin C. Cawley
%
% Description : Simple demonstration of model selection for least-squares
% support vector machines [1] by minimisation of the leave-one-out
% cross-validation error, i.e. Allen's PRESS statistic [2,3].
% The PRESS statistic is minimised using a simple Nelder-Mead
% simplex optimiser [4].
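%
% Note that the leave-one-out residuals of an LS-SVM can be
% computed in closed form from a single analysis of the full
% training set (see [3] and the train subfunction below), so
% PRESS is cheap enough to evaluate at every step of the search.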
%
% References : [1] Suykens, J. A. K., Van Gestel, T., De Brabanter, J.,
% De Moor, B. and Vandewalle, J., "Least Squares Support
% Vector Machines", World Scientific Publishing, 2002.
%
% [2] Allen, D. M., "The relationship between variable selection
% and prediction", Technometrics, vol. 16, pp. 125-127, 1974.
%
% [3] Cawley, G. C., "Leave-one-out cross-validation based model
% selection criteria for weighted LS-SVMs", In Proceedings
% of the International Joint Conference on Neural Networks
% (IJCNN-2006), pp. 2970-2977, Vancouver, BC, Canada,
% July 16-21 2006.
%
% [4] J. A. Nelder and R. Mead, "A simplex method for function
% minimization", Computer Journal, 7:308-313, 1965.
%
% History : 06/01/2007 - v1.00
%
% Copyright : (c) Dr Gavin C. Cawley, January 2007.
%
% This program is free software; you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation; either version 2 of the License, or
% (at your option) any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
%
% generate synthetic training and test data
[x_train,t_train] = synthetic(128);
[x_test, t_test] = synthetic(8192);
ntp = length(t_train);
% perform model selection using PRESS and Nelder-Mead simplex
fprintf(1, 'performing model selection...\n');
opts = simplex;
opts.TolFun = 1e-6;
opts.TolX = 1e-6;
theta = [-4;4]; % default parameters (result in over-fitting)
theta = simplex(@press, theta, opts, x_train, t_train);
lambda = 2^theta(1); % regularisation parameter
eta = 2^theta(2); % kernel parameter
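% n.b. theta holds log2(lambda) and log2(eta), so the simplex search is
% unconstrained while both hyper-parameters remain strictly positive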
% train final model
fprintf(1, 'training final model...\n');
[L,alpha,b] = press(theta, x_train, t_train);
% draw a pretty picture
fprintf(1, 'plotting decision boundary...\n');
figure(1);
clf;
set(axes, 'FontSize', 12);
h1 = plot(x_train(t_train == +1, 1), x_train(t_train == +1, 2), 'r+');
hold on;
h2 = plot(x_train(t_train == -1, 1), x_train(t_train == -1, 2), 'go');
a = axis;
[X,Y] = meshgrid(a(1):0.02:a(2),a(3):0.02:a(4));
y = rbf(eta, [X(:) Y(:)], x_train)*alpha + b;
y = reshape(y, size(X));
hold on
[c,h3] = contour(X, Y, y, [+1.0 +1.0], 'r--');
[c,h4] = contour(X, Y, y, [+0.0 +0.0], 'b-');
[c,h5] = contour(X, Y, y, [-1.0 -1.0], 'g-.');
hold off
handles = [h1 ; h2 ; h3(1) ; h4 ; h5];
legend(handles, 'class 1', 'class 2', 'y = +1', 'y = 0', 'y = -1', 'Location', 'NorthWest');
drawnow;
% evaluate performance on test and training data
y_train = rbf(eta, x_train, x_train)*alpha + b;
fprintf(1, 'training error = %6.2f%%\n', 100*mean((y_train>0)~=(t_train>0)));
y_test = rbf(eta, x_test, x_train)*alpha + b;
fprintf(1, 'test error = %6.2f%%\n', 100*mean((y_test>0)~=(t_test>0)));
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% %
% SUBFUNCTIONS %
% %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [x,t] = synthetic(ntp)
%
% SYNTHETIC - generate synthetic benchmark data
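%
% The data are drawn from four spherical Gaussian clusters of variance
% 0.03, two per class, so the classes overlap and the Bayes-optimal
% decision boundary is non-linear.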
x = [sqrt(0.03)*randn(ceil(ntp/4),2)+repmat([+0.4 +0.7],ceil(ntp/4),1);...
     sqrt(0.03)*randn(ceil(ntp/4),2)+repmat([-0.3 +0.7],ceil(ntp/4),1);...
     sqrt(0.03)*randn(ceil(ntp/4),2)+repmat([-0.7 +0.3],ceil(ntp/4),1);...
     sqrt(0.03)*randn(ceil(ntp/4),2)+repmat([+0.3 +0.3],ceil(ntp/4),1)];
t = [+ones(ceil(ntp/4),1);...
     +ones(ceil(ntp/4),1);...
     -ones(ceil(ntp/4),1);...
     -ones(ceil(ntp/4),1)];
% randomise order of training patterns
idx = randperm(length(t));
x = x(idx,:);
t = t(idx);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function K = rbf(eta, x1, x2)
%
% RBF - evaluate radial basis function (RBF) kernel
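%
% K(i,j) = exp(-eta*||x1(i,:) - x2(j,:)||^2), evaluated in vectorised
% form using the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2*a'*b.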
ones1 = ones(size(x1, 1), 1);
ones2 = ones(size(x2, 1), 1);
K = exp(-eta*(sum(x1.^2,2)*ones2' + ones1*sum(x2.^2,2)' - 2*x1*x2'));
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [L,alpha,b] = press(theta, x, t)
%
% PRESS - evaluate hyper-parameters using Allen's PRESS statistic
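%
% The criterion L is the mean squared leave-one-out residual (Allen's
% PRESS statistic [2] divided by the number of patterns); the residuals
% are obtained in closed form by train, without retraining the model on
% each leave-one-out partition.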
[alpha,b,r] = train(theta, x, t);
L = mean(r.^2);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [alpha,b,r] = train(theta, x, t)
%
% TRAIN - train a least-squares support vector machine
% re-parameterise strictly positive hyper-parameters
lambda = 2^theta(1); % regularisation parameter
eta = 2^theta(2); % kernel parameter
% evaluate the kernel matrix
K = rbf(eta, x, x);
% train least-squares support vector machine
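% the solution is given by the bordered linear system
%
%    [ K + lambda*I  1 ] [ alpha ]   [ t ]
%    [       1'      0 ] [   b   ] = [ 0 ]
%
% solved via a Cholesky factorisation of K + lambda*I, with the bias b
% recovered from the equality constraint sum(alpha) = 0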
ntp = size(x,1);
R = chol(K + lambda*eye(ntp));
xi = R\(R'\[t ones(ntp,1)]);
zeta = xi(:,1);
xi = xi(:,2);
oneoversumxi = 1/sum(xi);
b = oneoversumxi*sum(zeta);
alpha = zeta - xi*b;
% evaluate the model selection criterion
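% the i-th leave-one-out residual is given in closed form by
%
%    r(i) = alpha(i)/C(i,i)
%
% where C(i,i) is the i-th diagonal element of the inverse of the
% bordered system matrix, computed here from the Cholesky factor of
% K + lambda*I and a Schur complement correction for the bias (see [3])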
if nargout > 2
   Ri = inv(R);
   Cii = sum(Ri.^2,2) - oneoversumxi*xi.^2;
   r = alpha./Cii;
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [x,y,X,Y] = simplex(func, x, opts, varargin)
%
% SIMPLEX - multidimensional unconstrained non-linear optimisation
%
% X = SIMPLEX(FUNC,X) finds a local minimum of a function, via a function
% handle FUNC, starting from an initial point X. The local minimum is
% located via the Nelder-Mead simplex algorithm [1], which does not require
% any gradient information.
%
% [X,Y] = SIMPLEX(FUNC,X) also returns the value of the function, Y, at
% the local minimum, X.
%
% X = SIMPLEX(FUNC,X,OPTS) allows the optimisation parameters to be
% specified via a structure, OPTS, with members
%
% opts.Chi - Parameter governing expansion steps
% opts.Delta - Parameter governing size of initial simplex.
% opts.Gamma - Parameter governing contraction steps.
% opts.Rho - Parameter governing reflection steps.
% opts.Sigma - Parameter governing shrinkage steps.
% opts.MaxIter - Maximum number of optimisation steps.
% opts.MaxFunEvals - Maximum number of function evaluations.
% opts.TolFun - Stopping criterion based on the relative change in
% value of the function in each step.
% opts.TolX - Stopping criterion based on the change in the
% minimiser in each step.
%
% OPTS = SIMPLEX() returns a structure containing the default optimisation
% parameters, with the following values:
%
% opts.Chi = 2
% opts.Delta = 0.01
% opts.Gamma = 0.5
% opts.Rho = 1
% opts.Sigma = 0.5
% opts.MaxIter = 200
% opts.MaxFunEvals = 1000
% opts.TolFun = 1e-3
% opts.TolX = 1e-3
%
% X = SIMPLEX(FUNC,X,OPTS, P1, P2, ...) allows additional parameters to be
% passed to the function to be minimised.
%
% [X,Y,XX,YY] = SIMPLEX(FUNC, X) also returns in XX all of the values of
% X evaluated during the optimisation process and in YY the corresponding
% values of the function.
%
% References:
%
% [1] J. A. Nelder and R. Mead, "A simplex method for function
% minimization", Computer Journal, 7:308-313, 1965.
if nargin < 3
   opts.Chi = 2;
   opts.Delta = 0.01;
   opts.Gamma = 0.5;
   opts.Rho = 1;
   opts.Sigma = 0.5;
   opts.MaxIter = 200;
   opts.MaxFunEvals = 1000;
   opts.TolFun = 1e-3;
   opts.TolX = 1e-3;
end
if nargin == 0
   x = opts;
   return
end
% get initial parameters
n = length(x);
x = repmat(x', n+1, 1);
y = zeros(n+1, 1);
% form initial simplex
for i=1:n
   x(i,i) = x(i,i) + opts.Delta;
   y(i) = func(x(i,:), varargin{:});
end
y(n+1) = func(x(n+1,:), varargin{:});
X = x;
Y = y;
count = n+1;
format = ' % 4d % 4d % 12f %s\n';
fprintf(1, '\n Iteration Func-count min f(x) Procedure\n\n');
fprintf(1, format, 1, count, min(y), 'initial');
% iterative improvement
for i=2:opts.MaxIter
   % order the vertices by function value
   [y,idx] = sort(y);
   x = x(idx,:);
   % reflect
   centroid = mean(x(1:end-1,:));
   x_r = centroid + opts.Rho*(centroid - x(end,:));
   y_r = func(x_r, varargin{:});
   count = count + 1;
   X = [X ; x_r];
   Y = [Y ; y_r];
   if y_r >= y(1) && y_r < y(end-1)
      % accept reflection point
      x(end,:) = x_r;
      y(end) = y_r;
      fprintf(1, format, i, count, min(y), 'reflect');
   else
      if y_r < y(1)
         % expand
         x_e = centroid + opts.Chi*(x_r - centroid);
         y_e = func(x_e, varargin{:});
         count = count + 1;
         X = [X ; x_e];
         Y = [Y ; y_e];
         if y_e < y_r
            % accept expansion point
            x(end,:) = x_e;
            y(end) = y_e;
            fprintf(1, format, i, count, min(y), 'expand');
         else
            % accept reflection point
            x(end,:) = x_r;
            y(end) = y_r;
            fprintf(1, format, i, count, min(y), 'reflect');
         end
      else
         % contract
         shrink = 0;
         if y(end-1) <= y_r && y_r < y(end)
            % contract outside
            x_c = centroid + opts.Gamma*(x_r - centroid);
            y_c = func(x_c, varargin{:});
            count = count + 1;
            X = [X ; x_c];
            Y = [Y ; y_c];
            if y_c <= y_r
               % accept contraction point
               x(end,:) = x_c;
               y(end) = y_c;
               fprintf(1, format, i, count, min(y), 'contract outside');
            else
               shrink = 1;
            end
         else
            % contract inside
            x_c = centroid + opts.Gamma*(centroid - x(end,:));
            y_c = func(x_c, varargin{:});
            count = count + 1;
            X = [X ; x_c];
            Y = [Y ; y_c];
            if y_c <= y(end)
               % accept contraction point
               x(end,:) = x_c;
               y(end) = y_c;
               fprintf(1, format, i, count, min(y), 'contract inside');
            else
               shrink = 1;
            end
         end
         if shrink
            % shrink the simplex towards the best vertex
            for j=2:n+1
               x(j,:) = x(1,:) + opts.Sigma*(x(j,:) - x(1,:));
               y(j) = func(x(j,:), varargin{:});
               count = count + 1;
               X = [X ; x(j,:)];
               Y = [Y ; y(j)];
            end
            fprintf(1, format, i, count, min(y), 'shrink');
         end
      end
   end
   % evaluate stopping criteria
   if max(abs(min(x) - max(x))) < opts.TolX
      fprintf(1, 'optimisation terminated successfully (TolX criterion)\n');
      break;
   end
   if abs(max(y) - min(y))/max(abs(y)) < opts.TolFun
      fprintf(1, 'optimisation terminated successfully (TolFun criterion)\n');
      break;
   end
end
if i == opts.MaxIter
   fprintf(1, 'Warning : maximum number of iterations exceeded\n');
end
% update model structure
[y, idx] = min(y);
x = x(idx,:);
% bye bye...