libsvm.m
来自「一款数据挖掘的软件」· M 代码 · 共 106 行
M
106 行
% libSVM: wrapper for libSVM
%
% Parameters:
% para: parameters
% 1. Kernel: kernel type, 0: linear, 1: poly, 2: RBF, default: 0 (types 3 and 4 are also accepted and passed through to svmtrain as -t)
% 2. KernelParam: kernel parameter (passed as -d for poly, -g for RBF, -r for type 3; unused for linear), default: 0.05
% 3. CostFactor: weighting between positive data and negative data (passed to svmtrain as -c), default: 1
% 4. Threshold: decision threshold for posterior probability, default: 0.5
% X_train: training examples
% Y_train: training labels
% X_test: testing examples
% Y_test: testing labels
% num_class: number of classes
% class_set: set of class labels such as [1,-1], the first one is the
% positive label
%
% Output parameters:
% Y_compute: the predicted labels
% Y_prob: the prediction confidence in [0,1]
%
% Require functions:
% ParseParameter, GetModelFilename
function [Y_compute, Y_prob] = libSVM(para, X_train, Y_train, X_test, Y_test, num_class, class_set)
% Train (optionally) and apply an SVM through the external svmtrain and
% svmpredict executables found in libSVM_dir.
%
% The training and testing examples are written to temp_train_file and
% temp_test_file in the sparse 'label index:value ...' text format, the
% executables are run through the shell, and the predictions are read back
% from temp_output_file.
%
% Inputs:
%   para      - option string parsed by ParseParameter for -Kernel,
%               -KernelParam, -CostFactor and -Threshold
%   X_train   - training examples, one row per example; when empty the
%               training step is skipped and the existing model file is used
%   Y_train   - training labels, one per row of X_train
%   X_test    - testing examples, one row per example
%   Y_test    - testing labels, written as the first entry of each test line
%   num_class - not used by this function
%   class_set - class labels; class_set(1) is treated as the positive label
%
% Outputs:
%   Y_compute - predicted labels
%   Y_prob    - for two classes, the probability column of the positive
%               label; otherwise the largest probability in each row
global temp_train_file temp_test_file temp_output_file libSVM_dir;
global preprocess;
verbosity = preprocess.Verbosity;

% p = [Kernel; KernelParam; CostFactor; Threshold]
p = str2num(char(ParseParameter(para, {'-Kernel';'-KernelParam'; '-CostFactor'; '-Threshold'}, {'0';'0.05';'1';'0.5'})));

% Translate the kernel type into the kernel-specific command-line option.
switch p(1)
    case 0
        kernel_opt = '';
    case 1
        kernel_opt = sprintf('-d %.10g -g 1', p(2));
    case 2
        kernel_opt = sprintf('-g %.10g', p(2));
    case 3
        kernel_opt = sprintf('-r %.10g', p(2));
    case 4
        % NOTE(review): p(2) is numeric here but is formatted with %s;
        % kept as in the original -- confirm what '-u' expects.
        kernel_opt = sprintf('-u "%s"', p(2));
    otherwise
        % Without this branch the option string would be undefined and the
        % failure would surface later as an unrelated error.
        error('libSVM:unknownKernel', 'Unsupported kernel type: %g', p(1));
end

% set up the commands
temp_model_file = GetModelFilename;
train_cmd = sprintf('!%s/svmtrain -b 1 -s 0 -t %d %s -c %f %s %s > log', libSVM_dir, p(1), kernel_opt, p(3), temp_train_file, temp_model_file);
test_cmd = sprintf('!%s/svmpredict -b 1 %s %s %s >> log', libSVM_dir, temp_test_file, temp_model_file, temp_output_file);

num_test_data = size(X_test, 1);

% Parameter Estimation
if (~isempty(X_train)),
    libSVM_WriteSparseFile(temp_train_file, X_train, Y_train);
    % Train the SVM
    if (verbosity >= 1), fprintf('Running: %s........\n', train_cmd); end;
    eval(train_cmd);
end;

% Prediction
libSVM_WriteSparseFile(temp_test_file, X_test, Y_test);
if (verbosity >= 1), fprintf('Running: %s........\n', test_cmd); end;
eval(test_cmd);

% Read the result file
fid = fopen(temp_output_file, 'r');
if (fid < 0),
    error('libSVM:fileOpen', 'Cannot open prediction file %s', temp_output_file);
end;
header = fgets(fid);
if (~ischar(header)),
    % fgets returns -1 on an empty file, e.g. when svmpredict failed.
    fclose(fid);
    error('libSVM:emptyOutput', 'Prediction file %s is empty', temp_output_file);
end;
% The first line lists the class labels after a 7-character prefix.
labels = sscanf(header(8:length(header)), '%d');
Y = fscanf(fid, '%f');
fclose(fid);

% One row per test example: predicted label followed by one probability
% per entry of labels.
Y = (reshape(Y, [], num_test_data))';
Y_compute = int16(Y(:, 1));

% convert the labels into probability
threshold = p(4);
if (length(labels) == 2),
    % binary class: pick the probability column of the positive label.
    % class_set(1) is the positive label; fall back to the label 1 when
    % class_set is unavailable or does not appear in the header.
    pos_col = [];
    if (nargin >= 7 && ~isempty(class_set)),
        pos_col = find(labels == class_set(1));
    end;
    if (isempty(pos_col)),
        pos_col = find(labels == 1);
    end;
    Y_prob = Y(:, pos_col + 1);
    if (threshold ~= 0.5),
        Y_compute = class_set(1) * (Y_prob >= threshold) + class_set(2) * (Y_prob < threshold);
    end;
else
    % multiple classes
    Y_prob = max(Y(:, 2:size(Y,2)), [], 2);
end;

function libSVM_WriteSparseFile(filename, X, Y)
% Write the rows of X with labels Y to filename as 'label index:value ...' lines.
fid = fopen(filename, 'w');
if (fid < 0),
    error('libSVM:fileOpen', 'Cannot open %s for writing', filename);
end;
try
    for i = 1:size(X, 1),
        Xi = X(i, :);
        % Write label as the first entry
        s = sprintf('%d ', Y(i));
        % Then follow 'feature:value' pairs for the non-zero entries
        ind = find(Xi);
        if (~isempty(ind)),
            % 2-by-k matrix so sprintf consumes index, value, index, value, ...
            s = [s sprintf('%i:%.10g ', [ind; full(Xi(ind))])];
        end;
        % An all-zero row gets the label only: formatting an empty matrix
        % would still emit the literal ': ' and corrupt the line.
        fprintf(fid, '%s\n', s);
    end
catch
    % Do not leak the file handle when a row cannot be written.
    fclose(fid);
    rethrow(lasterror);
end
fclose(fid);
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?