📄 lshssvmtrain.m
字号:
function [alpha,bias, svi, nsv] = lshsvmtrain(samplesX,samplesY,kernel,kernelparam,C,threshold,percent)
%LSHSVMTRAIN Train a least-squares hidden-space SVM with greedy sample pruning.
%
%   [alpha,bias,svi,nsv] = lshsvmtrain(samplesX,samplesY,kernel,kernelparam,C,threshold,percent)
%
%   Builds the n-by-n kernel matrix K, where column i holds the kernel values
%   between every sample and sample i (computed by bsvkernel). It then solves
%   the regularised least-squares problem
%       minimise ||K*alpha + bias - samplesY||^2 + ||alpha||^2/(2*C)
%   (bias fixed to 0 when threshold is false). Finally it greedily removes
%   samples, cheapest first, while the accumulated removal cost stays within
%   percent * (rhs'*inv(H)*rhs) of the full solution.
%
%   Inputs
%     samplesX    - training samples, one per row (n = size(samplesX,1))
%     samplesY    - targets, an n-by-1 column vector
%     kernel      - kernel selector, passed through to bsvkernel
%     kernelparam - kernel parameter(s), passed through to bsvkernel
%     C           - regularisation constant (larger C = weaker ridge term)
%     threshold   - true to also fit a bias term, false for bias = 0
%     percent     - pruning budget as a fraction of the initial objective value
%
%   Outputs
%     alpha - n-by-1 coefficients; entries of pruned samples are 0
%     bias  - fitted bias (0 when threshold is false)
%     svi   - indices of the samples that survived pruning
%     nsv   - number of surviving samples, length(svi)

if nargin ~= 7 % check correct number of arguments
    % mfilename (rather than a hard-coded name) shows the help of the file
    % actually being run, even if the file name differs from the function name.
    help(mfilename);
    error('lshsvmtrain:nargin', ...
        'lshsvmtrain expects exactly 7 input arguments, got %d.', nargin);
end

fprintf('Least Square Hidden Space Support Vector Machines\n')
fprintf('__________________________________________________\n')
n = size(samplesX,1);

% Construct the Kernel matrix: column i = kernel(all samples, sample i).
% NOTE(review): assumes bsvkernel evaluates the kernel row by row on two
% equally sized sample matrices - confirm against bsvkernel.
fprintf('Constructing ...\n');
K = zeros(n,n);
st = cputime;
for i = 1:n
    K(:,i) = bsvkernel(samplesX,samplesX(i*ones(n,1),:),kernel,kernelparam);
end

% Solve the Optimisation Problem (normal equations H*x = rhs).
fprintf('Optimising ...\n');
ridge = eye(n)/(2*C);
if threshold
    % Unknowns are [alpha; bias]. Using K' in the bias column and in rhs
    % keeps H symmetric even when K is not, which the pruning step relies on.
    H = [K'*K + ridge, K'*ones(n,1); ones(1,n)*K, n];
    rhs = [K'*samplesY; samplesY'*ones(n,1)];
else
    H = K'*K + ridge;
    rhs = K'*samplesY;
end
% The explicit inverse is required: its diagonal ranks samples for removal
% and it is downdated (removeone) after every removal.
R = inv(H);
alpha = R*rhs;
budget = percent*(rhs'*alpha);
[alpha, svi] = prune_samples(alpha, R, budget, n, threshold);
nsv = length(svi);
if threshold
    bias = alpha(end);
    alpha = alpha(1:n);
else
    bias = 0;
end

% NOTE(review): disabled model-selection experiment (regularisation sweep
% scored by cv/gcv/aic/bic/sms/vm). It is not runnable as written: several
% branches index loovec(k,i) with a k that is never defined in this function.
% KK = K(:,svi);
% [U,S] = svd(KK);
% nn = 13;
% lamdavec = 2.^linspace(0,-12,nn);
% for i = 1:nn
% lamda = lamdavec(i);
% diagD = diag(S.^2)./diag(S.^2 + lamda);
% diagD = [diagD; zeros(n-nsv,1)];
% switch 'cv'
% case 'cv'
% Tr = 1 - sum(U.*(ones(n,1)*diagD').*U,2);
% Error = samplesY - U*(diagD.*(U'*samplesY));
% loovec(i) = mean((Error./Tr).^2);
% case 'gcv'
% Error = samplesY - U*(diagD.*(U'*samplesY));
% p = sum(diagD)/n;
% loovec(k,i) = mean(Error.^2)/(1-p).^2;
% case 'aic'
% Error = samplesY - U*(diagD.*(U'*samplesY));
% p = sum(diagD)/n;
% loovec(k,i) = mean(Error.^2)*(1+p)/(1-p);
% case 'bic'
% Error = samplesY - U*(diagD.*(U'*samplesY));
% p = sum(diagD)/n;
% loovec(k,i) = mean(Error.^2)*(1+log(n)/2*p/(1-p));
% case 'sms'
% Error = samplesY - U*(diagD.*(U'*samplesY));
% p = sum(diagD)/n;
% loovec(k,i) = mean(Error.^2)*(1+2*p);
% case 'vm'
% Error = samplesY - U*(diagD.*(U'*samplesY));
% p = (sum(diagD))/n;
% factor = 1;
% loovec(k,i) = mean(Error.^2)/(1-factor*sqrt(p-p*log(p)+log(n)/2/n));
% end
% end
% fprintf('Execution time : %4.1f seconds\n',cputime - st);
% [looe,index1] = min(loovec);
% lamda = lamdavec(index1);
% alpha(svi) = (KK'*KK + lamda*eye(nsv,nsv))\(KK'*samplesY);
fprintf('Execution time: %4.1f seconds\n',cputime - st);
end

function [alpha, svi] = prune_samples(alpha, R, budget, n, hasbias)
%PRUNE_SAMPLES Greedily remove samples while the accumulated cost <= budget.
%   alpha   - full-problem solution: n sample coefficients, followed by one
%             bias entry when hasbias is true
%   R       - inverse of the system matrix whose rows/columns match alpha
%   budget  - maximum accumulated removal cost
%   n       - number of samples
%   hasbias - true when alpha/R carry a trailing bias entry (never pruned)
%   Returns alpha with pruned entries set to exactly 0 and the remaining
%   coefficients updated, plus the indices svi of the kept samples.
%   NOTE(review): assumes removeone(R,k) returns the inverse of the system
%   with row/column k removed - confirm against removeone.
svi = 1:n;
if hasbias
    tail = n+1;   % bias row/column: updated on each removal, never removed
else
    tail = [];
end
loss = 0;
while ~isempty(svi)
    d = diag(R);
    d = d(1:numel(svi));   % drop the bias entry, if present
    % Cost of removing sample j (and re-solving) is alpha_j^2 / R_jj.
    [cost, k] = min(alpha(svi).^2 ./ d);
    loss = loss + cost;
    if loss > budget
        break;
    end
    active = [svi, tail];
    alpha(active) = alpha(active) - alpha(svi(k))*R(:,k)/R(k,k);
    alpha(svi(k)) = 0;     % exact zero instead of a rounding residue
    svi(k) = [];
    R = removeone(R,k);
end
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -