datatraintest.m
function [ntrain,ntest]=datatraintest(datas,randoms,kfold,bi,ks)
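% DATATRAINTEST  Build training/test index sets for k-fold cross-validation.
%   datas   - data matrix, one sample per row
%   randoms - nonzero: random split per fold; zero: contiguous (sequential) folds
%   kfold   - number of folds
%   bi, ks  - only referenced in an unreachable branch below (discrete case); otherwise unused
%   ntrain  - 1-by-kfold cell array of training row indices into datas
%   ntest   - 1-by-kfold cell array of test row indices into datas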
if randoms
    % Random split: for each fold, draw a training set of na = nsample - round(nsample/kfold)
    % samples; the remaining samples form the test set.
    for i=1:kfold
        [nsample,n]=size(datas);
        nt=round(nsample/kfold);            % target test-set size
        na=nsample-nt;                      % target training-set size
        n3=rand(1,nsample);                 % one uniform random number per sample
        nb=1/kfold;
        n1=find(n3>nb);                     % candidate training indices
        n2=find(n3<=nb);                    % candidate test indices
        n3=length(n1);                      % actual number of training indices drawn
        if n3<na                            % too few training samples: move some over from the test set
            n1=[n1 n2(1:(na-n3))];
            n2=n2((na-n3)+1:length(n2));
        elseif n3>na                        % too many training samples: move the surplus to the test set
            n2=[n1(na+1:n3) n2];
            n1=n1(1:na);
        end
        ntrain{i}=n1;
        ntest{i}=n2;
        if 1
            dtrain{i}=datas(n1,:);          % training data (not returned by this function)
            dtest{i}=datas(n2,:);           % test/validation data (not returned by this function)
        else                                % unreachable branch (discrete case); relies on undefined S and T
            k3=setdiff(1:n,bi);
            for j=1:length(ks)
                k=ks(j);
                Q=find(S(:,bi)==k);
                Xc=S(Q,:);                  % training data
                X=Xc;
                Q=find(T(:,bi)==k);
                x=T(Q,:);                   % test/validation data
                dtest{j}=x(:,k3);
                dtrain{j}=Xc(:,k3);
            end
        end % end if
    end % end i
    % data are split for k-fold cross-validation
else
    % Sequential split: cut the data into kfold contiguous blocks; each block is one
    % test fold and the remaining blocks form the corresponding training set.
    nsample=size(datas,1);
    neach=fix(nsample/kfold);
    for ncut=1:kfold
        ni(ncut,1:2)=[neach*(ncut-1)+1 neach*ncut];     % [start end] row range of block ncut
        if ncut==kfold, ni(ncut,2)=nsample; end         % last block absorbs the remainder
    end
    for ncut=1:kfold
        if kfold>1
            if ncut==1
                ntrain{ncut}=ni(ncut+1,1):nsample;
            elseif ncut==kfold
                ntrain{ncut}=1:ni(ncut-1,2);
            else
                ntrain{ncut}=[1:ni(ncut-1,2) ni(ncut+1,1):nsample];
            end
        else
            ntrain{ncut}=ni(ncut,1):ni(ncut,2);         % single fold: train and test on the same block
        end
        ntest{ncut}=ni(ncut,1):ni(ncut,2);
    end
end
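A minimal usage sketch follows. The matrix X, its size, and the fold count are illustrative placeholders, not part of the original listing; bi and ks are passed as empty because the active code path never reads them.

% Example: 5-fold splits of a hypothetical 100-sample, 3-feature matrix
X = rand(100,3);                                   % one sample per row (illustrative data)
[ntrain,ntest] = datatraintest(X,1,5,[],[]);       % random split into 5 folds
% [ntrain,ntest] = datatraintest(X,0,5,[],[]);     % contiguous (sequential) folds instead
Xtr = X(ntrain{1},:);                              % training rows of fold 1
Xte = X(ntest{1},:);                               % test rows of fold 1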