📄 svm_predict_with_stat.m
字号:
function [output, stat]=SVM_Predict_With_Stat(usemodfil, data, model, probability)
%SVM_PREDICT_WITH_STAT Predict with an SVM model through the CUSVM library and report statistics.
%
%   [output, stat] = SVM_Predict_With_Stat(usemodfil, data)
%   [output, stat] = SVM_Predict_With_Stat(usemodfil, data, model)
%   [output, stat] = SVM_Predict_With_Stat(usemodfil, data, model, probability)
%
%   Inputs:
%     usemodfil   - 1: ask the user for a model file (uigetfile) and load it
%                   with Read_Model, overriding any MODEL argument;
%                   any other value: the MODEL argument is required.
%     data        - struct with fields:
%                     X : samples, one sample per column (size(X,2) samples)
%                     y : labels/targets, one entry per column of X
%     model       - model struct. Recognised fields (data_dim, svm_type,
%                   kernel_type, degree, gamma, coef0, nr_class, total_sv,
%                   multiclass_type, nr_binary, rhos, labels, probA, probB,
%                   nSV, sv_coef, SV, I, nSV_binary, sv_ind) are copied into
%                   an 'svm_model_input' libstruct; missing fields are skipped.
%     probability - nonzero requests the probability output layout (default 0).
%
%   Outputs:
%     output.params    - params vector as returned by CUSVM's SVM_Test
%                        (contents defined by the library - TODO confirm).
%     output.predlabel - predicted label for each sample.
%     output.predvalue - values returned by SVM_Test, reshaped to
%                        numSamples-by-nr_class when probability ~= 0 and
%                        numSamples-by-nr_binary otherwise (presumably class
%                        probabilities / binary decision values - confirm).
%     stat             - when svm_type is 0 or 1 (treated as classification):
%                          Ns         : number of samples per model label
%                          confmatrix : confusion matrix, rows = true label,
%                                       columns = predicted label, in the
%                                       order of model.labels, each row
%                                       divided by its Ns (all-zero row when
%                                       Ns is 0)
%                          classrate  : diagonal count / number of samples
%                        for any other svm_type:
%                          MSE        : mean squared error of predlabel vs y
%
%   On invalid input a message is printed and OUTPUT and STAT are returned
%   empty.

% Give both outputs a value up front so the early "return" paths below do
% not raise "Output argument not assigned" when the caller requests them.
output = [];
stat = [];

% Messages use fprintf: disp would print the '\n' escapes literally.
if (nargin < 2) || (nargin > 4)
    fprintf(' Incorrect number of input variables.\n');
    help(mfilename);
    return;
end
% The model file is only read when usemodfil == 1, so every other value
% needs the model passed in directly.
if (usemodfil ~= 1) && (nargin < 3)
    fprintf(' Since you do not want to use model file. \n You need to input a SVM model directly.\n');
    help(mfilename);
    return;
end
if ~isfield(data, 'X') || ~isfield(data, 'y')
    fprintf(' Error: invalid data structure.\n');
    return;
end
if isempty(data.X)
    fprintf(' Error: sample field is empty.\n');
    return;
end
if size(data.X,2) ~= size(data.y,2)
    fprintf(' Error: sample and label do not match.\n');
    return;
end
if (nargin < 4)
    probability = 0;
end

if (usemodfil == 1)
    fprintf('Please select the model file name! \n');
    [fname, fpath] = uigetfile('*.*', 'Pick a SVM model file');
    % uigetfile returns 0 (not a char array) when the dialog is cancelled.
    if isequal(fname, 0)
        fprintf('Error open file! \n');
        return;
    end
    model = Read_Model(fullfile(fpath, fname));
end

% Load the CUSVM shared library once per session.
hfile = 'CUSVM.H';
if ~libisloaded('CUSVM')
    loadlibrary('CUSVM', hfile);
end

% Scalar model settings are packed into modelStr.params in this fixed order
% (params(k) <- model.(paramFields{k})).
paramFields = {'data_dim', 'svm_type', 'kernel_type', 'degree', 'gamma', ...
    'coef0', 'nr_class', 'total_sv', 'multiclass_type', 'nr_binary'};
% Array-valued model data is copied into the libstruct field of the same name.
arrayFields = {'rhos', 'labels', 'probA', 'probB', 'nSV', 'sv_coef', ...
    'SV', 'I', 'nSV_binary', 'sv_ind'};
modelStr = libstruct('svm_model_input');
for k = 1:numel(paramFields)
    if isfield(model, paramFields{k})
        modelStr.params(k) = model.(paramFields{k});
    end
end
for k = 1:numel(arrayFields)
    if isfield(model, arrayFields{k})
        modelStr.(arrayFields{k}) = model.(arrayFields{k});
    end
end

[row, col] = size(data.X);   % col = number of samples
inputs = libstruct('svm_data_input');
inputs.m = row;
inputs.n = col;
inputs.labels = data.y;
inputs.samples = data.X;

% params(1) carries the probability flag; the other entries stay zero
% (their meaning is defined by CUSVM).
params = zeros(4, 1);
params(1) = probability;

% Per-sample width of predvalue: nr_class (params(7)) with probabilities,
% nr_binary (params(10)) otherwise.
if (probability ~= 0)
    nOut = modelStr.params(7);
else
    nOut = modelStr.params(10);
end
predlabel = zeros(col, 1);
predvalue = zeros(col*nOut, 1);

output = struct();
[output.params, output.predlabel, predvalue] = calllib('CUSVM', 'SVM_Test', inputs, modelStr, params, predlabel, predvalue);
output.predvalue = reshape(predvalue, [col, nOut]);

nrClass = modelStr.params(7);
svmType = modelStr.params(2);
labels = modelStr.labels;
stat = struct();
if (svmType == 0) || (svmType == 1)
    % Number of true samples carrying each model label.
    stat.Ns = zeros(1, nrClass);
    for c = 1:nrClass
        stat.Ns(c) = nnz(data.y == labels(c));
    end
    % Count (true label, predicted label) pairs; labels not present in the
    % model produce an empty index and are silently skipped.
    confmatrix = zeros(nrClass, nrClass);
    for k = 1:col
        idxTrue = find(labels == data.y(k));
        idxPred = find(labels == output.predlabel(k));
        confmatrix(idxTrue, idxPred) = confmatrix(idxTrue, idxPred) + 1;
    end
    % Row-normalise; a class with no samples keeps an all-zero row instead
    % of 0/0 = NaN.
    stat.confmatrix = zeros(nrClass, nrClass);
    for c = 1:nrClass
        if stat.Ns(c) > 0
            stat.confmatrix(c, :) = confmatrix(c, :) / stat.Ns(c);
        end
    end
    stat.classrate = trace(confmatrix) / col;
else
    % Force both vectors to columns so the residual is never broadcast
    % into a matrix.
    err = output.predlabel(:) - data.y(:);
    stat.MSE = (err' * err) / col;
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -