wekaclassify.m
% WekaClassify: classification with Weka called from MATLAB
%
% Parameters:
% classifier: Weka classifier name relative to weka.classifiers
%             (e.g. 'functions.SMO'); anything after '--' is passed to the
%             classifier as its own options
% para: parameters
%   1. MultiClassWrapper: wrap the classifier in
%      weka.classifiers.meta.MultiClassClassifier (1) or not (0),
%      default: -1 (selected automatically from num_class)
% X_train: training examples (one example per row)
% Y_train: training labels
% X_test: testing examples
% Y_test: testing labels
% num_class: number of classes (0 for a numeric target)
% class_set: set of class labels such as [1,-1]; the first one is the
%            positive label
%
% Output parameters:
% Y_compute: the predicted labels
% Y_prob: the prediction confidence in [0,1]
%
% Required functions:
% ParseParameter, GetModelFilename, ParseCmd
% Required globals:
% temp_train_file, temp_test_file, temp_output_file, weka_dir
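%
% A minimal usage sketch (the file names, classifier, and data below are only
% illustrative; the globals must point to writable paths and the Weka jar):
%
%   global temp_train_file temp_test_file temp_output_file weka_dir;
%   weka_dir         = 'weka.jar';              % Weka classpath / jar
%   temp_train_file  = 'temp_train.arff';
%   temp_test_file   = 'temp_test.arff';
%   temp_output_file = 'temp_weka_output.txt';
%   % empty para: MultiClassWrapper is selected automatically (off for 2 classes)
%   [Y_compute, Y_prob] = WekaClassify('functions.SMO', '', ...
%       X_train, Y_train, X_test, Y_test, 2, [1, -1]);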
function [Y_compute, Y_prob] = WekaClassify(classifier, para, X_train, Y_train, X_test, Y_test, num_class, class_set)
global temp_train_file temp_test_file temp_output_file;
temp_model_file = GetModelFilename;
% split the classifier string at '--' into the classifier name and its extra options
[classifier, para_classifier, additional_classifier] = ParseCmd(classifier, '--');
% set up the commands
[train_test_cmd test_cmd] = WekaGenerateCMD(classifier, para, temp_train_file, temp_test_file, temp_model_file, temp_output_file, num_class, para_classifier);
num_feature = size(X_test, 2);
% make the format string for one data row: num_feature floats followed by the label
format_str = '';
for i = 1:num_feature
    format_str = strcat(format_str, '%f,');
end
format_str = strcat(format_str, '%d\n');
% save the train data in ARFF format
if (~isempty(X_train)),
    fid = fopen(temp_train_file, 'w');
    fprintf(fid, '@relation train\n\n');
    for attr = 1:num_feature
        fprintf(fid, '@attribute a%d real\n', attr);
    end
    % the last attribute is the class: nominal for classification, real for regression
    fprintf(fid, '@attribute a%d ', num_feature+1);
    if (num_class ~= 0)
        fprintf(fid, '{');
        fprintf(fid, '%d,', class_set);
        fprintf(fid, '}\n');
    else
        fprintf(fid, 'real');
    end;
    fprintf(fid, '\n@data\n');
    fprintf(fid, format_str, full([X_train, Y_train]'));
    fclose(fid);
end;
% save the test data in the same ARFF format
fid = fopen(temp_test_file, 'w');
fprintf(fid, '@relation test\n\n');
for attr = 1:num_feature
    fprintf(fid, '@attribute a%d real\n', attr);
end
fprintf(fid, '@attribute a%d ', num_feature+1);
if (num_class ~= 0)
    fprintf(fid, '{');
    fprintf(fid, '%d,', class_set);
    fprintf(fid, '}\n');
else
    fprintf(fid, 'real');
end;
fprintf(fid, '\n@data\n');
fprintf(fid, format_str, full([X_test, Y_test]'));
fclose(fid);
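% For reference, the file written above follows the standard ARFF layout; with
% two features and class_set = [1,-1] the loops above produce roughly
% (data values are illustrative, and the class-label loop leaves a trailing comma):
%   @relation test
%   @attribute a1 real
%   @attribute a2 real
%   @attribute a3 {1,-1,}
%   @data
%   0.500000,1.200000,1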
% train and test (or, without training data, test only with an existing model)
if (~isempty(X_train)),
    eval(train_test_cmd);
else
    eval(test_cmd);
end;
% read Weka's "-p 0" prediction dump: column 2 is taken as the predicted label,
% column 3 as the confidence of that prediction
Ypred = dlmread(temp_output_file, ' ', [0 0 length(Y_test) - 1 3]);
Y_compute = int16(Ypred(:, 2));
Y_prob = Ypred(:, 3);
if (num_class == 2),
    % for binary problems, convert the confidence of the predicted label into
    % the confidence of the positive label class_set(1)
    Y_prob = Y_prob .* (Y_compute == class_set(1)) + (1 - Y_prob) .* (Y_compute ~= class_set(1));
end;
function [train_test_cmd, test_cmd] = WekaGenerateCMD(classifier, para, temp_train_file, temp_test_file, temp_model_file, temp_output_file, num_class, para_classifier)
global weka_dir;
p = str2num(char(ParseParameter(para, {'-MultiClassWrapper'}, {'-1'}, 1)));
if (p(1) < 0),
    % auto-select: no wrapper for binary problems, MultiClassClassifier otherwise
    if (num_class == 2), p(1) = 0; else p(1) = 1; end;
    fprintf('Automatically select MultiClassWrapper to be %d\n', p(1));
end;
if (p(1) == 0),
    dt = sprintf('java -classpath "%s" ', weka_dir);
elseif (p(1) == 1),
    dt = sprintf('java -classpath "%s" weka.classifiers.meta.MultiClassClassifier -W ', weka_dir);
end;
% with the meta wrapper, the base classifier's options must follow '--'
if (isempty(findstr(dt, 'meta'))),
    train_test_cmd = sprintf('!%s weka.classifiers.%s -t %s -T %s -d %s -p 0 %s > %s', dt, classifier, temp_train_file, temp_test_file, temp_model_file, char(para_classifier), temp_output_file);
else
    train_test_cmd = sprintf('!%s weka.classifiers.%s -t %s -T %s -d %s -p 0 -- %s > %s', dt, classifier, temp_train_file, temp_test_file, temp_model_file, char(para_classifier), temp_output_file);
end;
if (isempty(findstr(dt, 'meta'))),
    test_cmd = sprintf('!%s weka.classifiers.%s -T %s -l %s -p 0 %s > %s', dt, classifier, temp_test_file, temp_model_file, char(para_classifier), temp_output_file);
else
    test_cmd = sprintf('!%s weka.classifiers.%s -T %s -l %s -p 0 -- %s > %s', dt, classifier, temp_test_file, temp_model_file, char(para_classifier), temp_output_file);
end;
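% For illustration only (file names stand in for the temporary-file globals):
% with MultiClassWrapper off and classifier = 'functions.SMO', train_test_cmd is
% roughly
%   !java -classpath "weka.jar" weka.classifiers.functions.SMO -t temp_train.arff -T temp_test.arff -d temp_model -p 0 > temp_output.txt
% while test_cmd skips training and loads the saved model with "-l temp_model".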