⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 xvalidate_svm.m

📁 支持向量机应用的实例,希望对大家有用!
💻 M
字号:
function [model,predict_label,accuracy]=xvalidate_svm(data,y,v,class)
% XVALIDATE_SVM  v-fold cross validation of a linear two-class SVM (LIBSVM).
%
% Samples whose label equals class(1) are relabelled +1, samples whose label
% equals class(2) are relabelled -1, and all other samples are discarded.
% The remaining samples are split, in their original order (no shuffling),
% into v contiguous folds. For each fold a linear SVM ('-t 0 -b 1') is
% trained on the other v-1 folds and evaluated on the held-out fold; the
% fold accuracy is printed and the fold is plotted with plotxval_svm.
% Assumes the samples are independent.
%
% Inputs:
%   data:  N-by-d matrix of training/testing data, one feature vector per row.
%   y:     length-N vector of class labels (row or column).
%   v:     number of folds; must satisfy 2 <= v <= number of samples kept.
%   class: vector whose first two entries are the labels of the two classes
%          to compare (class(1) -> +1, class(2) -> -1).
%
% Outputs (these refer to the LAST fold, i == v, only):
%   model:         LIBSVM model struct returned by svmtrain for the last fold.
%   predict_label: labels predicted by svmpredict on the last test fold.
%   accuracy:      vector returned by svmpredict for the last fold; per the
%                  LIBSVM MATLAB README, accuracy(1) is the classification
%                  accuracy (the value printed for each fold).
%
% Algorithm:
%   1. Divide the data into v training/testing splits (contiguous folds).
%   2. Train a linear SVM on each training split.
%   3. Report accuracy on the corresponding test split; repeat v times.

if numel(class) < 2
    error('xvalidate_svm:class', 'class must contain the labels of two classes.');
end

% Convert class data to -1/+1 form (samples of other classes become 0).
% y(:) forces a column vector, which svmtrain/svmpredict expect and which
% the row-wise fold indexing below relies on.
y = (y(:)==class(1)) - (y(:)==class(2));

% Remove data not in these two classes.
keep = (y ~= 0);
data = data(keep,:);
y = y(keep);

[N,nfeatures] = size(data); % size of the filtered data matrix
if v < 2 || v > N
    error('xvalidate_svm:folds', ...
        'v must be between 2 and the number of samples in the two classes (%d).', N);
end

% Fold boundaries: fold i holds samples edges(i)+1 .. edges(i+1).
% Spreading the remainder across folds guarantees no fold is empty when
% v <= N (a fixed fold size of ceil(N/v) can leave trailing folds empty).
% When v divides N this equals folds of exactly N/v samples.
edges = floor((0:v) * N / v);

% One column per fold: [feature weights; bias] of the separating hyperplane.
w = zeros(nfeatures+1, v);

for i = 1:v
    % 1. Held-out test fold, and training set = every other fold.
    test_idx = edges(i)+1 : edges(i+1);
    is_train = true(N,1);
    is_train(test_idx) = false;

    test_data  = data(test_idx,:);
    test_y     = y(test_idx);
    train_data = data(is_train,:);
    train_y    = y(is_train);

    % 2. Train the classifier.
    model = svmtrain(train_y,train_data,'-t 0 -b 1');

    % Recover the primal hyperplane w'*x + b of the linear kernel.
    % LIBSVM's decision values are positive for model.Label(1), so flip the
    % sign when that label is -1 to keep w oriented toward class +1.
    w(1:nfeatures,i) = full(model.SVs' * model.sv_coef);
    w(end,i) = -model.rho;
    if model.Label(1) == -1
        w(:,i) = -w(:,i);
    end
    disp(w(:,i)');

    % 3. Test classifier performance.
    [predict_label, accuracy] = svmpredict(test_y, test_data, model, '-b 1');

    % Print results for this data set.
    str = ['SVM accuracy cross validation set',num2str(i),' is ',num2str(accuracy(1))];
    disp(str);
    plotxval_svm(train_data,test_data,train_y,test_y,w(:,i),str,model.SVs);
end

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -