datatraintest.m

来自「这是一个支持向量机的工具」· M 代码 · 共 71 行

M
71
字号
function [ntrain,ntest]=datatraintest(datas,randoms,kfold,bi,ks)
% DATATRAINTEST  Build row-index sets for training/test splits of a data matrix.
%
%   [ntrain,ntest] = datatraintest(datas,randoms,kfold,bi,ks)
%
%   Inputs
%     datas   - data matrix with one sample per row; only its number of
%               rows is used.
%     randoms - true : draw kfold independent random splits. Every split
%                      has exactly nsample-round(nsample/kfold) training
%                      rows; the remaining rows form the test set. The
%                      splits are drawn independently, so test sets of
%                      different splits may overlap. With kfold==1 the
%                      training set is empty and the test set is all rows.
%               false: cut the rows into kfold contiguous folds of
%                      fix(nsample/kfold) rows each (the last fold also
%                      takes the remainder). Fold k is the test set of
%                      split k and all other rows are its training set.
%                      With kfold==1 training and test set are both all
%                      rows.
%     kfold   - number of splits, scalar >= 1.
%     bi, ks  - not used by this function; kept in the signature so that
%               existing callers keep working.
%
%   Outputs
%     ntrain  - 1-by-kfold cell array; ntrain{k} is a row vector of
%               training row indices for split k.
%     ntest   - 1-by-kfold cell array; ntest{k} is a row vector of
%               test row indices for split k.

% Fail early with a clear message instead of leaving the outputs unassigned.
if ~isscalar(kfold) || ~(kfold>=1)
  error('datatraintest:invalidKfold','kfold must be a scalar >= 1.');
end

nsample=size(datas,1);
nfold=floor(kfold);          % number of loop iterations of 1:kfold
ntrain=cell(1,nfold);
ntest=cell(1,nfold);

if randoms
  % Repeated random hold-out: sizes and threshold do not depend on the split.
  na=nsample-round(nsample/kfold);   % required number of training rows
  nb=1/kfold;                        % expected fraction of test rows

  for i=1:nfold
    r=rand(1,nsample);
    n1=find(r>nb);           % tentative training rows
    n2=find(r<=nb);          % tentative test rows
    n3=length(n1);

    % The random draw only gives the right sizes on average; move rows
    % between the two sets until the training set has exactly na rows.
    if n3<na
      nmove=na-n3;
      n1=[n1 n2(1:nmove)];
      n2=n2(nmove+1:end);
    elseif n3>na
      n2=[n1(na+1:n3) n2];
      n1=n1(1:na);
    end

    ntrain{i}=n1;
    ntest{i}=n2;
  end

else
  % Contiguous folds for k-fold cross validation.
  neach=fix(nsample/kfold);

  for ncut=1:nfold
    first=neach*(ncut-1)+1;
    if ncut==nfold
      last=nsample;          % last fold absorbs the remainder rows
    else
      last=neach*ncut;
    end

    ntest{ncut}=first:last;
    if nfold>1
      % training set = every row outside the test fold
      ntrain{ncut}=[1:first-1 last+1:nsample];
    else
      % a single fold: train and test on the same rows
      ntrain{ncut}=first:last;
    end
  end

end

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?