⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 dualsparsedoublefeaturestrain.m.svn-base

📁 a function inside machine learning
💻 SVN-BASE
字号:
function [subspaceInfo, trainInfo] = dualSparseDoubleFeaturesTrain(trainData, params)
%DUALSPARSEDOUBLEFEATURESTRAIN Train a CCA-like model with sparse dual directions.
%
%   [subspaceInfo, trainInfo] = dualSparseDoubleFeaturesTrain(trainData, params)
%
%   Iteratively builds sparse dual projection directions for two views (X
%   and Y) of the same examples. After every extracted direction, the
%   sampled kernel columns of each view are deflated against the
%   directions found so far, in a CCA-like manner.
%
%   At each iteration, params.numKernelCols kernel columns are sampled at
%   random for each view. These columns are deflated, and one X column and
%   one Y column are chosen from the absolute scores returned by
%   params.dualSparseMeasureFunction. Iteration stops early if the squared
%   norm of a new dual vector falls below tol.
%
%   Inputs:
%     trainData - data struct holding both views. The field names of the
%                 two views are obtained from getSpaceNames(params).
%     params    - struct with fields:
%                   dualSparseMeasureFunction - name of the scoring
%                       function, called as
%                       [measures, bXs, bYs] = f(KX, KXdeflated, KY,
%                                                KYdeflated, rowIndicesX,
%                                                rowIndicesY)
%                   iterations    - maximum number of directions to extract
%                                   (capped at the number of examples)
%                   X.kernel.name, X.kernel.params - kernel for view X
%                   Y.kernel.name, Y.kernel.params - kernel for view Y
%                   numKernelCols - kernel columns sampled per iteration
%                                   (capped at the number of examples)
%                   normalise     - optional, default 1. When 1, the stored
%                                   training features and Q are scaled by
%                                   1/norm of each dual vector tau.
%
%   Outputs:
%     subspaceInfo - struct with, for each view name:
%                      .b              sparse dual coefficient matrix
%                      .Q              T-by-T matrix computed from tau and
%                                      the kernel columns selected by b
%                      .exampleIndices / .nzElements   outputs of
%                                      findNonZeroElements(b)
%                    plus .numFeatures (number of directions actually
%                    extracted, T) and .trainTime (seconds).
%     trainInfo    - struct whose .data field is trainData with fields 'X'
%                    and 'Y' added, holding the (optionally normalised)
%                    dual vectors tau of each view.

if (nargin ~= 2)
    fprintf('%s\n', help(mfilename));
    error('Incorrect number of inputs - see above usage instructions.');
end

%Work out which fields of the data struct hold the two views
[nameX, nameY] = getSpaceNames(params);

%Fetch each view once; it is reused in every iteration and again at the end
dataX = getDataFieldValue(trainData, nameX);
dataY = getDataFieldValue(trainData, nameY);

useSparseX = issparse(dataX);
useSparseY = issparse(dataY);

if useSparseX || useSparseY
    disp('Using sparse representation');
end

[numTrainExamples, numXFeatures] = getDataFieldSize(trainData, nameX);
[numTrainExamplesY, numYFeatures] = getDataFieldSize(trainData, nameY);
if numTrainExamplesY ~= numTrainExamples
    error('Views %s and %s have different numbers of examples (%d and %d).', ...
        nameX, nameY, numTrainExamples, numTrainExamplesY);
end

alpha = 10^-5; %Ridge added to the diagonal so the Q system is non-singular
tol = 10^-4;   %Minimum squared norm of a new dual vector before stopping

%Store all the parameters
dualSparseMeasureFunction = char(params.dualSparseMeasureFunction);
T = min(params.iterations, numTrainExamples);
kernelFunctionX = char(params.X.kernel.name);
kernelParamsX = params.X.kernel.params;
kernelFunctionY = char(params.Y.kernel.name);
kernelParamsY = params.Y.kernel.params;
numKernelCols = min(params.numKernelCols, numTrainExamples);

if isfield(params, 'normalise')
    normaliseFeatures = params.normalise;
else
    normaliseFeatures = 1;
end

normSqTauX = zeros(T, 1);
normSqTauY = zeros(T, 1);

%Keep the dual vectors and coefficients sparse when the input view is sparse
if useSparseX
    initMatrixFunctionX = 'sparse';
else
    initMatrixFunctionX = 'zeros';
end

tauX = feval(initMatrixFunctionX, numTrainExamples, T);
KXTau = feval(initMatrixFunctionX, numKernelCols, T);
bX = feval(initMatrixFunctionX, numTrainExamples, T);

if useSparseY
    initMatrixFunctionY = 'sparse';
else
    initMatrixFunctionY = 'zeros';
end

tauY = feval(initMatrixFunctionY, numTrainExamples, T);
KYTau = feval(initMatrixFunctionY, numKernelCols, T);
bY = feval(initMatrixFunctionY, numTrainExamples, T);

fprintf('Using %d kernel matrix columns on %d examples\n', numKernelCols, numTrainExamples);
fprintf('Iterating ... \n');
tic;

%Now deflate and find new dual directions
for i = 1:T
    fprintf('%d ', i);
    if mod(i, 30) == 0 || i == T
        fprintf('\n');
    end

    %Sample a random subset of kernel matrix columns for each view
    permutationVectorX = randperm(numTrainExamples);
    rowIndicesX = permutationVectorX(1:numKernelCols);

    permutationVectorY = randperm(numTrainExamples);
    rowIndicesY = permutationVectorY(1:numKernelCols);

    tempKX = feval(kernelFunctionX, dataX, getDataFieldValue(trainData, nameX, rowIndicesX), kernelParamsX);
    tempKY = feval(kernelFunctionY, dataY, getDataFieldValue(trainData, nameY, rowIndicesY), kernelParamsY);

    %Deflate the sampled columns against the (unit-norm) directions found so far:
    %Kj = K - tau*(tau'*K)
    KXTau(:, 1:i-1) = tempKX'*tauX(:, 1:i-1);
    KXj = tempKX - tauX(:, 1:i-1)*KXTau(:, 1:i-1)';

    KYTau(:, 1:i-1) = tempKY'*tauY(:, 1:i-1);
    KYj = tempKY - tauY(:, 1:i-1)*KYTau(:, 1:i-1)';

    %Score the candidate columns; maxInMatrix presumably returns the row/column
    %of the largest entry - NOTE(review): confirm against maxInMatrix
    [measures, bXs, bYs] = feval(dualSparseMeasureFunction, tempKX, KXj, tempKY, KYj, rowIndicesX, rowIndicesY);
    [indexX, indexY] = maxInMatrix(abs(measures));

    bX(rowIndicesX(indexX), i) = bXs(indexX);
    bY(rowIndicesY(indexY), i) = bYs(indexY);

    tauX(:, i) = KXj(:, indexX)*bXs(indexX);
    tauY(:, i) = KYj(:, indexY)*bYs(indexY);
    normSqTauX(i) = tauX(:, i)'*tauX(:, i);
    normSqTauY(i) = tauY(:, i)'*tauY(:, i);

    %This can occur if the rank of the kernel matrix is very low
    if (normSqTauX(i) < tol) || (normSqTauY(i) < tol)
        fprintf('\nNorm of dual vector has dropped too low. Breaking out ...\n');
        T = i-1;
        bX = bX(:, 1:T);
        bY = bY(:, 1:T);
        tauX = tauX(:, 1:T);
        tauY = tauY(:, 1:T);
        normSqTauX = normSqTauX(1:T);
        normSqTauY = normSqTauY(1:T);
        break;
    end

    %Store unit-norm directions so the deflation above stays a projection
    tauX(:, i) = tauX(:, i)/sqrt(normSqTauX(i));
    tauY(:, i) = tauY(:, i)/sqrt(normSqTauY(i));
end

%Restore the original scaling of tau
tauX = tauX * diag(sqrt(normSqTauX));
tauY = tauY * diag(sqrt(normSqTauY));
trainTime = toc;

fprintf('Completed in %f seconds\n', trainTime);

%Free the per-iteration buffers (pack is command-line only, so it is not called)
clear KXTau KYTau tempKX tempKY KXj KYj;

if normaliseFeatures == 1
    normMatrixX = diag(1./sqrt(normSqTauX));
    normMatrixY = diag(1./sqrt(normSqTauY));
else
    normMatrixX = eye(T);
    normMatrixY = eye(T);
end

[exampleXIndices, nzXElements] = findNonZeroElements(bX);
[exampleYIndices, nzYElements] = findNonZeroElements(bY);
diagBXElements = diag(nzXElements);
diagBYElements = diag(nzYElements);

%Q = inv((tau'*tau) \ (tau'*K*b + alpha*I)) * N, computed with a single
%linear solve instead of an explicit inverse
trainKXb = feval(kernelFunctionX, dataX, getDataFieldValue(trainData, nameX, exampleXIndices), kernelParamsX)*diagBXElements;
QX = ((tauX'*trainKXb + alpha*eye(T)) \ full(tauX'*tauX)) * normMatrixX;

trainKYb = feval(kernelFunctionY, dataY, getDataFieldValue(trainData, nameY, exampleYIndices), kernelParamsY)*diagBYElements;
QY = ((tauY'*trainKYb + alpha*eye(T)) \ full(tauY'*tauY)) * normMatrixY;

clear trainKXb trainKYb;

%Training features: the input data plus the (optionally normalised) tau of each view
trainInfo = struct;
trainInfo.data = trainData;
trainInfo.data = addDataField(trainInfo.data, 'X', tauX*normMatrixX, 'examples');
trainInfo.data = addDataField(trainInfo.data, 'Y', tauY*normMatrixY, 'examples');

%Store some information
subspaceInfo = struct;
subspaceInfo.(nameX).b = sparse(bX);
subspaceInfo.(nameY).b = sparse(bY);
subspaceInfo.(nameX).Q = QX;
subspaceInfo.(nameY).Q = QY;
subspaceInfo.(nameX).exampleIndices = exampleXIndices;
subspaceInfo.(nameY).exampleIndices = exampleYIndices;
subspaceInfo.(nameX).nzElements = nzXElements;
subspaceInfo.(nameY).nzElements = nzYElements;
subspaceInfo.numFeatures = T;
subspaceInfo.trainTime = trainTime;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -