⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 svm_predict_with_stat.m

📁 CU SVM Classifier Matlab Toolbox
💻 M
字号:
function [output, stat]=SVM_Predict_With_Stat(usemodfil, data, model, probability)
%SVM_PREDICT_WITH_STAT Run CUSVM prediction and compute accuracy/error statistics.
%   [OUTPUT, STAT] = SVM_Predict_With_Stat(USEMODFIL, DATA, MODEL, PROBABILITY)
%   packs DATA and an SVM model into libstructs, calls SVM_Test from the
%   CUSVM shared library (declared in CUSVM.H), and summarizes the
%   predictions against the labels stored in DATA.
%
%   Inputs
%     USEMODFIL   - 1: ask the user to pick a model file (uigetfile) and load
%                   it with Read_Model; the MODEL argument is then ignored.
%                   Any other value: MODEL must be supplied.
%     DATA        - struct with fields
%                     X : samples, one sample per column (dim x N)
%                     y : one label/target per column of X (1 x N)
%     MODEL       - struct; whichever of the fields data_dim, svm_type,
%                   kernel_type, degree, gamma, coef0, nr_class, total_sv,
%                   multiclass_type, nr_binary, rhos, labels, probA, probB,
%                   nSV, sv_coef, SV, I, nSV_binary, sv_ind are present are
%                   copied into an 'svm_model_input' libstruct.
%     PROBABILITY - nonzero to request probability outputs (default 0).
%
%   Outputs
%     OUTPUT.params    - first output of calllib('CUSVM','SVM_Test',...);
%                        its meaning depends on the SVM_Test prototype in
%                        CUSVM.H (TODO confirm).
%     OUTPUT.predlabel - predicted label/value for each of the N samples.
%     OUTPUT.predvalue - N x nr_class matrix when PROBABILITY is nonzero,
%                        otherwise N x nr_binary.
%     STAT             - for svm_type 0 or 1 (treated as classification):
%                          Ns         : number of samples per model label
%                          confmatrix : confusion counts with row k divided
%                                       by Ns(k) (NaN row when Ns(k) is 0)
%                          classrate  : fraction of samples whose predicted
%                                       label equals the true label
%                        for any other svm_type:
%                          MSE        : mean squared error of predlabel vs y
%
%   On invalid input a message is printed and OUTPUT and STAT are returned
%   empty, so callers requesting outputs do not hit "output not assigned".

output = [];
stat = [];

if (nargin < 2) || (nargin > 4)
    fprintf(' Incorrect number of input variables.\n');
    help SVM_Predict;
    return;
end

% Unless the model is read from a file, it has to be passed in directly.
if (usemodfil ~= 1) && (nargin < 3)
    fprintf(' Since you do not want to use model file. \n You need to input a SVM model directly.\n');
    help SVM_Predict;
    return;
end

if ~isfield(data, 'X') || ~isfield(data, 'y')
    fprintf(' Error: invalid data structure.\n');
    return;
end
if isempty(data.X)
    fprintf(' Error: sample field is empty.\n');
    return;
end
if size(data.X, 2) ~= size(data.y, 2)
    fprintf(' Error: sample and label do not match.\n');
    return;
end

if nargin < 4
    probability = 0;
end

if usemodfil == 1
    fprintf('Please select the model file name! \n');
    [fname, pathname] = uigetfile('*.*', 'Pick a SVM model file');
    % uigetfile returns 0 when the dialog is cancelled.
    if isequal(fname, 0)
        fprintf('Error open file! \n');
        return;
    end
    model = Read_Model(fullfile(pathname, fname));
end

if ~libisloaded('CUSVM')
    loadlibrary('CUSVM', 'CUSVM.H');
end

% The libstruct types are only known once the library has been loaded.
modelStr = build_model_struct(model);
nrClass  = modelStr.params(7);   % nr_class
nrBinary = modelStr.params(10);  % nr_binary

[dim, nSamples] = size(data.X);
inputs = libstruct('svm_data_input');
inputs.m = dim;
inputs.n = nSamples;
inputs.labels = data.y;
inputs.samples = data.X;

params = zeros(4, 1);
params(1) = probability;
predlabel = zeros(nSamples, 1);
% Width of the per-sample score block: nr_class columns in probability
% mode, nr_binary columns otherwise (presumably one per class vs. one per
% binary sub-classifier).
if probability ~= 0
    nCols = nrClass;
else
    nCols = nrBinary;
end
predvalue = zeros(nSamples * nCols, 1);

[output.params, output.predlabel, predvalue] = calllib('CUSVM', 'SVM_Test', ...
    inputs, modelStr, params, predlabel, predvalue);

% Reshape with the same width used to allocate the buffer.
output.predvalue = reshape(predvalue, [nSamples, nCols]);

stat = compute_stats(output.predlabel, data.y, modelStr);
end


function modelStr = build_model_struct(model)
% Copy whichever MODEL fields are present into an 'svm_model_input' libstruct.
modelStr = libstruct('svm_model_input');

% Scalar settings, stored by position in modelStr.params (index k <-> name k).
paramFields = {'data_dim', 'svm_type', 'kernel_type', 'degree', 'gamma', ...
               'coef0', 'nr_class', 'total_sv', 'multiclass_type', 'nr_binary'};
for k = 1:numel(paramFields)
    if isfield(model, paramFields{k})
        modelStr.params(k) = model.(paramFields{k});
    end
end

% Array-valued fields, copied under the same field name.
arrayFields = {'rhos', 'labels', 'probA', 'probB', 'nSV', 'sv_coef', ...
               'SV', 'I', 'nSV_binary', 'sv_ind'};
for k = 1:numel(arrayFields)
    if isfield(model, arrayFields{k})
        modelStr.(arrayFields{k}) = model.(arrayFields{k});
    end
end
end


function stat = compute_stats(predlabel, y, modelStr)
% Summarize predictions: confusion statistics for svm_type 0/1, MSE otherwise.
nSamples = size(y, 2);
svmType = modelStr.params(2);

if (svmType == 0) || (svmType == 1)
    labels  = modelStr.labels;
    nrClass = modelStr.params(7);

    stat.Ns = zeros(1, nrClass);
    for k = 1:nrClass
        stat.Ns(k) = nnz(y == labels(k));
    end

    % Rows index the true label, columns the predicted label. Samples whose
    % true or predicted label is not in the model's label list are skipped
    % (find returns empty, so the update is a no-op).
    confmatrix = zeros(nrClass, nrClass);
    for s = 1:nSamples
        trueIdx = find(labels == y(s));
        predIdx = find(labels == predlabel(s));
        confmatrix(trueIdx, predIdx) = confmatrix(trueIdx, predIdx) + 1;
    end

    % Row-normalize by the per-class sample count (0/0 gives a NaN row).
    stat.confmatrix = bsxfun(@rdivide, confmatrix, stat.Ns(:));
    stat.classrate = trace(confmatrix) / nSamples;
else
    err = predlabel(:) - y(:);
    stat.MSE = sum(err .^ 2) / nSamples;
end
end

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -