% File: wekaclassify.m
function [Y_compute, Y_prob] = WekaClassify(classifier, para, X_train, Y_train, X_test, Y_test, num_class)
% WEKACLASSIFY  Train and/or apply a Weka classifier by shelling out to Java.
%
%   [Y_compute, Y_prob] = WekaClassify(classifier, para, X_train, Y_train,
%                                      X_test, Y_test, num_class)
%
%   Writes the training data (if any) and the test data to temporary ARFF
%   files, runs the shell command built by WekaGenerateCMD, and reads back
%   the prediction file that command writes.
%
%   Inputs
%     classifier - classifier spec; split by ParseCmd(classifier, '--')
%                  into the classifier name and its own options.
%     para       - option string forwarded to WekaGenerateCMD.
%     X_train    - training features, one instance per row.  When empty, no
%                  training file is written and test_cmd (which loads the
%                  model from temp_model_file) is run instead of training.
%     Y_train    - training targets, one value per row of X_train; also
%                  passed to GetClassSet to obtain the class set.
%     X_test     - test features, one instance per row; its column count
%                  sets the number of attributes declared in both files.
%     Y_test     - test targets, one value per row of X_test;
%                  length(Y_test) is the number of prediction rows read.
%     num_class  - NOTE: overwritten by GetClassSet(Y_train) below, so the
%                  value passed by the caller is not used.
%
%   Outputs
%     Y_compute  - column 2 of the prediction file, cast to int16.
%     Y_prob     - column 3 of the prediction file.
%
%   Requires the globals temp_train_file, temp_test_file, temp_output_file
%   and temp_model_file to hold usable file paths.
global temp_train_file temp_test_file temp_output_file temp_model_file;
[class_set, num_class] = GetClassSet(Y_train);
[classifier, para_classifier, additional_classifier] = ParseCmd(classifier, '--');
% Build the shell commands for train+test and for test-only (saved model).
[train_test_cmd test_cmd] = WekaGenerateCMD(classifier, para, temp_train_file, temp_test_file, temp_model_file, temp_output_file, num_class, para_classifier);
num_feature = size(X_test, 2);
% One '%f,' per feature followed by the target value: one instance per line.
format_str = [repmat('%f,', 1, num_feature), '%d\n'];
% Save the training data only when there is some.
if (~isempty(X_train)),
    WekaWriteArff(temp_train_file, 'train', X_train, Y_train, num_feature, class_set, num_class, format_str);
end;
% Save the test data.
WekaWriteArff(temp_test_file, 'test', X_test, Y_test, num_feature, class_set, num_class, format_str);
% Run Weka.  The command strings begin with '!', so eval hands them to the
% operating-system shell.
if (~isempty(X_train)),
    eval(train_test_cmd);
else
    eval(test_cmd);
end;
% Read the first length(Y_test) rows and columns 1-4 (space separated) of
% the prediction file.  NOTE(review): this depends on the layout of Weka's
% "-p 0" output for the Weka version in use -- presumably column 2 is the
% predicted value and column 3 its confidence; confirm.
Ypred = dlmread(temp_output_file, ' ', [0 0 length(Y_test) - 1 3]);
Y_compute = int16(Ypred(:, 2));
Y_prob = Ypred(:, 3);

function WekaWriteArff(file_name, relation_name, X, Y, num_feature, class_set, num_class, format_str)
% WEKAWRITEARFF  Write features X and target column Y to an ARFF file.
%   Declares attributes a1..a<num_feature> as real.  The target attribute
%   a<num_feature+1> is nominal over class_set when num_class ~= 0 and real
%   otherwise.  Each row of [X, Y] is then written using format_str.
%   Raises an error if file_name cannot be opened for writing.
fid = fopen(file_name, 'w');
if (fid < 0),
    error('WekaClassify:cannotOpenFile', 'Cannot open "%s" for writing.', file_name);
end;
fprintf(fid, '@relation %s\n\n', relation_name);
for attr = 1:num_feature
    fprintf(fid, '@attribute a%d real\n', attr);
end
fprintf(fid, '@attribute a%d ', num_feature+1);
if (num_class ~= 0)
    % Join the labels with commas, without leaving a trailing comma.
    class_list = sprintf('%d,', class_set);
    fprintf(fid, '{%s}\n', class_list(1:end-1));
else
    fprintf(fid, 'real');
end;
fprintf(fid, '\n@data\n');
% fprintf walks the transposed matrix column by column, so each row of
% [X, Y] is emitted as one line using format_str.
fprintf(fid, format_str, [X, Y]');
fclose(fid);
function [train_test_cmd, test_cmd] = WekaGenerateCMD(classifier, para, temp_train_file, temp_test_file, temp_model_file, temp_output_file, num_class, para_classifier)
% WEKAGENERATECMD  Build the shell commands that run a Weka classifier.
%
%   train_test_cmd trains on temp_train_file, evaluates on temp_test_file,
%   passes temp_model_file with -d and redirects the "-p 0" output to
%   temp_output_file.  test_cmd passes temp_model_file with -l and
%   evaluates on temp_test_file only.  Both strings start with '!' so that
%   eval() runs them in the operating-system shell.
%
%   The -MultiClassWrapper value parsed from para (default '-1') selects
%   how the classifier is invoked:
%     negative - choose automatically: 0 when num_class == 2, otherwise 1;
%     0        - run weka.classifiers.<classifier> directly;
%     1        - wrap it in weka.classifiers.meta.MultiClassClassifier -W,
%                with the classifier's own options after a '--' separator.
%   Any other value raises an error.
%
%   para_classifier holds the classifier's own options (passed through
%   char).  Requires the global weka_dir to hold the Weka classpath.
global weka_dir;
p = str2num(char(ParseParameter(para, {'-MultiClassWrapper'}, {'-1'})));
if (isempty(p)),
    error('WekaGenerateCMD:badWrapper', '-MultiClassWrapper must be a number (-1, 0 or 1).');
end;
if (p(1) < 0),
    if (num_class == 2), p(1) = 0; else p(1) = 1; end;
    fprintf('Automatically select MultiClassWrapper to be %d\n', p(1));
end;
% Branch on the wrapper choice itself rather than searching the command
% text for 'meta': that text includes weka_dir, which may contain 'meta'.
if (p(1) == 0),
    dt = sprintf('java -classpath "%s" ', weka_dir);
    opt_sep = ' ';
elseif (p(1) == 1),
    dt = sprintf('java -classpath "%s" weka.classifiers.meta.MultiClassClassifier -W ', weka_dir);
    opt_sep = ' -- ';
else
    error('WekaGenerateCMD:badWrapper', '-MultiClassWrapper must be -1, 0 or 1 (got %g).', p(1));
end;
train_test_cmd = sprintf('!%s weka.classifiers.%s -t %s -T %s -d %s -p 0%s%s > %s', dt, classifier, temp_train_file, temp_test_file, temp_model_file, opt_sep, char(para_classifier), temp_output_file);
test_cmd = sprintf('!%s weka.classifiers.%s -T %s -l %s -p 0%s%s > %s', dt, classifier, temp_test_file, temp_model_file, opt_sep, char(para_classifier), temp_output_file);
% To inspect the generated command: fprintf('%s\n', train_test_cmd);
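% Keyboard shortcuts (from the code-viewer page this file was copied from):
%   Copy code:          Ctrl + C
%   Search code:        Ctrl + F
%   Full-screen mode:   F11
%   Toggle theme:       Ctrl + Shift + D
%   Show shortcuts:     ?
%   Increase font size: Ctrl + =
%   Decrease font size: Ctrl + -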