📄 nu_rbfsvc_train.m
字号:
function [output]=nu_RbfSVC_Train(data, savefile, C, gamma, nu, options)
%NU_RBFSVC_TRAIN  Train an RBF-kernel nu-SVC model through the CUSVM library.
%   OUTPUT = NU_RBFSVC_TRAIN(DATA, SAVEFILE)
%   OUTPUT = NU_RBFSVC_TRAIN(DATA, SAVEFILE, C)
%   OUTPUT = NU_RBFSVC_TRAIN(DATA, SAVEFILE, C, GAMMA)
%   OUTPUT = NU_RBFSVC_TRAIN(DATA, SAVEFILE, C, GAMMA, NU)
%   OUTPUT = NU_RBFSVC_TRAIN(DATA, SAVEFILE, C, GAMMA, NU, OPTIONS)
%
%   Inputs
%     DATA     struct with fields X (sample matrix) and y (labels). X must
%              be non-empty and have as many columns as y. (Presumably one
%              column per sample and one row per feature, since the default
%              gamma is 1/size(X,1) -- TODO confirm.)
%     SAVEFILE 0: the model file name is 'CU_SVModel.dat';
%              otherwise: a dialog asks for the model file name.
%     C        cost parameter (default 1).
%     GAMMA    RBF kernel parameter (default 1/size(DATA.X,1)).
%     NU       nu parameter (default 0.5).
%     OPTIONS  struct; every field is optional:
%       crossvalidation  0 if cross-validation is not needed;
%                        v (v >= 2) for v-fold cross-validation.
%       probability      0: no probability estimation; 1: enable it.
%       shrinking        0: no shrinking; 1: use shrinking.
%       multiclass_type  0 -- one against one
%                        1 -- one against all
%                        2 -- dense
%                        3 -- sparse
%                        4 -- user-defined code matrix, stored in the file
%                             'UserCodeMat.dat' with the format:
%                               row(number of binary classifiers) col(number of classes)
%                               row*col code matrix
%       nr_weight        0 if no class weight needs to be changed; nonzero
%                        to use per-class weights given by the fields
%                        weight_label (class labels) and weight (weights).
%                        Ignored unless both of those fields are present
%                        and non-empty.
%
%   Output
%     OUTPUT   when cross-validation was run (first element of the result
%              vector returned by SVM_Train is nonzero): struct with fields
%              Accuracy, MSE and SCC from that vector; otherwise the value
%              returned by Read_Model for the model file.
%              On invalid input or a cancelled file dialog a message is
%              printed and OUTPUT is [].

% Assigned up front so every early return yields [] instead of raising
% "Output argument 'output' not assigned".
output = [];

if (nargin < 2) || (nargin > 6)
    % fprintf (not disp) so that '\n' is printed as a newline.
    fprintf(' Incorrect number of input variables.\n');
    help nu_RbfSVC_Train;
    return;
end

% Validate the data before asking for a file name, so an invalid call
% does not pop up a save dialog first.
if ~isfield(data, 'X') || ~isfield(data, 'y')
    fprintf(' Error: invalid data structure.\n');
    return;
end
if isempty(data.X)
    fprintf(' Error: sample field is empty.\n');
    return;
end
if size(data.X, 2) ~= size(data.y, 2)
    fprintf(' Error: sample and label do not match.\n');
    return;
end

if savefile == 0
    filename = 'CU_SVModel.dat';
else
    fprintf('Please select the model file name! \n');
    [fname, fdir] = uiputfile('*.*', 'Save model as');
    if isequal(fname, 0)
        % uiputfile returns 0 when the dialog is cancelled.
        fprintf('Error open file! \n');
        return;
    end
    filename = fullfile(fdir, fname);
end

[row, col] = size(data.X);
if ~libisloaded('CUSVM')
    loadlibrary('CUSVM', 'CUSVM.H');
end

% Default training parameters. svm_type = 1 / kernel_type = 2 presumably
% select nu-SVC and the RBF kernel (LIBSVM numbering) -- confirm against
% CUSVM.H.
params = libstruct('svm_parameter');
params.svm_type = 1;
params.kernel_type = 2;
params.degree = 3;
params.gamma = 1/row;
params.coef0 = 0;
params.nu = 0.5;
params.cache_size = 40;
params.C = 1;
params.eps = 1e-3;
params.p = 0.1;
params.shrinking = 1;
params.multiclass_type = 0;
params.probability = 0;
params.nr_weight = 0;

inputs = libstruct('svm_data_input');
inputs.m = row;
inputs.n = col;
inputs.labels = data.y;
inputs.samples = data.X;

% Vector exchanged with SVM_Train: on input, element 1 is the
% cross-validation flag and element 2 the fold count; on output,
% elements 3-5 are read back as MSE, SCC and Accuracy.
crossvalidation = zeros(5, 1);

% Optional positional arguments override the defaults cumulatively.
if nargin >= 3
    params.C = C;
end
if nargin >= 4
    params.gamma = gamma;
end
if nargin >= 5
    params.nu = nu;
end
if nargin >= 6
    if isfield(options, 'crossvalidation')
        % Nested (not &&) so an empty field is simply ignored.
        if options.crossvalidation >= 2
            crossvalidation(1) = 1;
            crossvalidation(2) = options.crossvalidation;
            if savefile ~= 0
                fprintf('Since you choosed crossvalidation! \n No file will be saved!\n');
            end
        end
    end
    if isfield(options, 'probability')
        params.probability = options.probability;
    end
    if isfield(options, 'shrinking')
        params.shrinking = options.shrinking;
    end
    if isfield(options, 'multiclass_type')
        params.multiclass_type = options.multiclass_type;
    end
    % Class weights are applied only when the weight fields actually exist,
    % instead of erroring when nr_weight is set on its own.
    if isfield(options, 'nr_weight') && isfield(options, 'weight_label') ...
            && isfield(options, 'weight')
        if (options.nr_weight ~= 0) && ~isempty(options.weight_label) ...
                && ~isempty(options.weight)
            params.nr_weight = options.nr_weight;
            params.weight_label = options.weight_label;
            params.weight = options.weight;
        end
    end
end

% The first return value of SVM_Train is not inspected here.
[status, crossvalidation] = calllib('CUSVM', 'SVM_Train', inputs, params, filename, crossvalidation); %#ok<ASGLU>
clear inputs params;
% The library is intentionally left loaded (unloadlibrary CUSVM was disabled).

if crossvalidation(1) ~= 0
    output = struct('Accuracy', crossvalidation(5), ...
                    'MSE',      crossvalidation(3), ...
                    'SCC',      crossvalidation(4));
else
    output = Read_Model(filename);
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -