📄 xvalidate_svm.m
字号:
function [model,predict_label,accuracy,fold_accuracy]=xvalidate_svm(data,y,v,class)
% XVALIDATE_SVM  v-fold cross validation of a linear two-class LIBSVM classifier.
%
% The samples labelled class(1) or class(2) are extracted and relabelled
% +1 / -1 respectively; all other samples are discarded. The remaining data
% is split into v contiguous folds (no shuffling). For each fold a linear
% SVM ('-t 0 -b 1') is trained on the other folds and evaluated on the
% held-out fold; the fold accuracy is printed and the fold is plotted with
% plotxval_svm. Samples are assumed to be independent.
%
% Inputs:
%   data  : matrix of training/testing data, one feature vector per row
%   y     : vector of class labels, one per row of data (row or column)
%   v     : number of folds, an integer with 2 <= v <= number of kept samples
%   class : two-element vector [c1 c2]; c1 -> +1, c2 -> -1
%
% Outputs:
%   model         : LIBSVM model trained in the LAST fold
%   predict_label : predicted labels for the LAST fold's test set
%   accuracy      : svmpredict accuracy output for the LAST fold
%                   (accuracy(1) is the classification accuracy)
%   fold_accuracy : v-by-1 vector holding accuracy(1) of every fold
%
% Algorithm:
%   1. Divide the data into v training and testing sets
%   2. Train
%   3. Evaluate on the held-out fold; repeat v times.
%
% NOTE(review): folds are contiguous, so if data is sorted by class a
% training set may contain a single class; shuffle beforehand if needed.

if numel(class) < 2
    error('xvalidate_svm:class', 'class must contain two class labels.');
end

% svmtrain expects a column label vector; forcing a column also keeps the
% row selections below consistent regardless of how y was passed in.
y = y(:);

% convert class data to +1/-1 form and drop samples from other classes
y = (y==class(1)) - (y==class(2));
keep = (y ~= 0);
data = data(keep,:);
y = y(keep);

[N,nfeatures] = size(data); % size of the filtered data matrix
if ~isscalar(v) || v ~= fix(v) || v < 2 || v > N
    error('xvalidate_svm:folds', 'v must be an integer with 2 <= v <= %d.', N);
end

% One column per fold: rows 1..nfeatures hold the primal weights and the
% last row holds the bias, so the boundary is w(1:end-1)'*x + w(end) = 0.
w = zeros(nfeatures+1,v);
fold_accuracy = zeros(v,1);

% 1. Fold i covers samples edges(i)+1 .. edges(i+1). Rounding multiples of
% N/v keeps fold sizes within one of each other and, since v <= N, no fold
% is empty. The last edge is exactly N.
edges = round((0:v) * N / v);
for i = 1:v
    is_test = false(N,1);
    is_test(edges(i)+1:edges(i+1)) = true;
    test_data = data(is_test,:);
    test_y = y(is_test);
    % training set = every sample outside the held-out fold (no overlap)
    train_data = data(~is_test,:);
    train_y = y(~is_test);

    % 2. Train the classifier (linear kernel, probability estimates)
    model = svmtrain(train_y,train_data,'-t 0 -b 1');
    % For a linear kernel the primal weight vector is sum_k sv_coef(k)*SV(k,:).
    % full() in case SVs is stored as a sparse matrix.
    w(1:nfeatures,i) = full(model.SVs)' * model.sv_coef;
    w(nfeatures+1,i) = -model.rho;
    disp(w(:,i)');

    % 3. Test classifier performance
    [predict_label, accuracy] = svmpredict(test_y, test_data, model, '-b 1');
    fold_accuracy(i) = accuracy(1);
    % print results for this data set
    str =(['SVM accuracy cross validation set',num2str(i),' is ',num2str(accuracy(1))]);
    disp(str);
    plotxval_svm(train_data,test_data,train_y,test_y,w(:,i),str,model.SVs);
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -