📄 dualsparsegeneralfeatureslm3train.m.svn-base
字号:
function [subspaceInfo, trainInfo] = dualSparseGeneralFeaturesLM3Train(trainData, params)
%A function to learn general features in the dual space without
%requiring the kernel matrix to be entirely in memory.
%
%Usage: [subspaceInfo, trainInfo] = dualSparseGeneralFeaturesLM3Train(trainData, params)
%
%At each of T = min(params.iterations, numExamples) iterations, cacheSize
%randomly chosen kernel matrix columns are computed and deflated against the
%dual directions found so far. A measure function scores the deflated
%columns; the best-scoring column (scaled by its coefficient) becomes the
%next dual direction tau, and the targets are deflated against it.
%Iteration stops early if the squared norm of a new direction falls below
%a fixed tolerance.
%
%Inputs:
%  trainData - data struct accessed via getDataFieldValue/getDataFieldSize;
%              the field names used come from getSpaceNames(params).
%  params    - struct with fields:
%                dualSparseMeasureFunction - name of the column-scoring function,
%                                            called as [measures, bs] =
%                                            f(K, Kj, Y, Yj, rowIndices)
%                iterations  - maximum number of features to extract
%                X.kernel.name   - name of the kernel function
%                X.kernel.params - parameters passed to the kernel function
%                cacheSize   - number of kernel columns sampled per iteration
%                normalise   - (optional, default 1) if 1, the output
%                              features are scaled to unit norm
%                verboseInfo - (optional) if 1, also return tau and A
%
%Outputs:
%  subspaceInfo - struct with field trainTime and, under the X space name
%                 returned by getSpaceNames, the fields b (sparse matrix of
%                 selected examples/coefficients), numFeatures,
%                 exampleIndices, nzElements and Q.
%  trainInfo    - struct whose data field is trainData with the extracted
%                 features added as field 'X'.
if (nargin ~= 2)
    fprintf('%s\n', help(sprintf('%s', mfilename)));
    error('Incorrect number of inputs - see above usage instructions.');
end

%First, figure out which variables to use in the data struct
[nameX, nameY] = getSpaceNames(params);
useSparse = issparse(getDataFieldValue(trainData, nameX));
%We really need the rank of the kernel matrix e.g. for RBF kernel, sigma = 100 can
%be nearly rank 1, but we don't compute the entire kernel matrix
if useSparse
    disp('Using sparse representation');
end

[numTrainExamples, numFeatures] = getDataFieldSize(trainData, nameX);
alpha = 10^-5; %Added to the diagonal of matrices to make them non singular
tol = 10^-4;   %Minimum squared norm of a new dual direction before stopping early

%Targets: Y is fixed and passed to the measure function on every iteration,
%so it is densified once here; Yj is deflated against each new direction.
Y = full(getDataFieldValue(trainData, nameY));
Yj = getDataFieldValue(trainData, nameY);

%Store all the parameters
dualSparseMeasureFunction = char(params.dualSparseMeasureFunction);
T = min(params.iterations, numTrainExamples);
kernelFunction = char(params.X.kernel.name);
kernelParams = params.X.kernel.params;
cacheSize = min(params.cacheSize, numTrainExamples);
if isfield(params, 'normalise')
    normaliseFeatures = params.normalise;
else
    normaliseFeatures = 1;
end

%Per-iteration state. tempK and Kj are reassigned on every iteration
%(numTrainExamples x cacheSize), so they are deliberately not preallocated.
normSqTau = zeros(T, 1);
tau = zeros(numTrainExamples, T);
KTau = zeros(cacheSize, T);
b = sparse(numTrainExamples, T);

fprintf('Using %d kernel matrix columns\n', cacheSize);
fprintf('Iterating ... \n');
tic;
%Now deflate and find new dual directions
for i=1:T
    fprintf('%d ', i);
    if mod(i, 30) == 0 || i == T
        fprintf('\n');
    end

    %First lets take a sample of the kernel matrix columns
    permutationVector = randperm(numTrainExamples);
    rowIndices = permutationVector(1:cacheSize);
    tempK = getDataFieldValue(trainData, nameX, rowIndices);
    tempK = full(feval(kernelFunction, getDataFieldValue(trainData, nameX), tempK, kernelParams));

    %Deflate the sampled columns against the directions found so far. The
    %columns of tau(:, 1:i-1) have unit norm here, so Kj = (I - tau*tau')*tempK
    KTau(:, 1:i-1) = tempK'*tau(:, 1:(i-1));
    Kj = tempK - tau(:, 1:i-1)*KTau(:, 1:i-1)';

    %Now select the best column of Kj according to some measure
    [measures, bs] = feval(dualSparseMeasureFunction, tempK, Kj, Y, Yj, rowIndices');
    [maxMeasure, k] = max(abs(measures));
    b(rowIndices(k), i) = bs(k);
    tau(:, i) = Kj(:, k)*bs(k);
    normSqTau(i) = tau(:, i)'*tau(:, i);

    %This can occur if the rank of the kernel matrix is very low
    if (normSqTau(i) < tol)
        fprintf('\nNorm of dual vector has dropped too low. Breaking out ...\n');
        T = i-1;
        b = b(:, 1:T);
        tau = tau(:, 1:T);
        normSqTau = normSqTau(1:T);
        break;
    end
    tau(:, i) = tau(:, i)/sqrt(normSqTau(i));
    %Deflate Yj
    Yj = Yj - tau(:, i)*(tau(:, i)'*Yj);
end

%Get the scalings back for tau (its columns were unit-normalised in the loop)
tau = tau * diag(sqrt(normSqTau));
trainTime = toc;
fprintf('Completed in %f seconds\n', trainTime);

%Clear some very large variables. PACK is not called: it may only be used
%from the MATLAB command line and raises an error inside a function.
clear tempK Kj Yj Y;

if normaliseFeatures == 1
    normMatrix = diag(1./sqrt(normSqTau));
else
    normMatrix = eye(T);
end

%NOTE(review): presumably the row indices and values of the nonzero
%entries of b - confirm against findNonZeroElements
[exampleIndices, nzElements] = findNonZeroElements(b);
diagBElements = diag(nzElements);
%trainKb: kernel between all training examples and the selected examples,
%with each column scaled by the corresponding entry of nzElements
newTrainX = feval(kernelFunction, getDataFieldValue(trainData, nameX), getDataFieldValue(trainData, nameX, exampleIndices), kernelParams);
newTrainX = newTrainX*diagBElements;
Q = (tau'*newTrainX + eye(T)*alpha);
Q = (tau'*tau)\Q;
%Normalise. Q\normMatrix equals inv(Q)*normMatrix without forming the
%explicit inverse, which is faster and numerically more accurate
Q = Q\normMatrix;

trainInfo = struct;
trainInfo.data = trainData;
trainInfo.data = addDataField(trainInfo.data, 'X', tau*normMatrix, 'examples');

subspaceInfo = struct;
subspaceInfo.trainTime = trainTime;
subspaceInfo.(nameX).b = sparse(b);
subspaceInfo.(nameX).numFeatures = T;
subspaceInfo.(nameX).exampleIndices = exampleIndices;
subspaceInfo.(nameX).nzElements = nzElements;
subspaceInfo.(nameX).Q = Q;
if isfield(params, 'verboseInfo') && params.verboseInfo == 1
    %NOTE(review): these fields are stored under the literal field X rather
    %than (nameX) like the fields above - confirm this is intended
    subspaceInfo.X.tau = tau;
    subspaceInfo.X.A = b*Q; %Projection matrix on K
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -