
train_test.asv

The MATLAB code of GiniSVM - an SVM implementation well suited to probability regression.
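For reference, the script below expects labels as rows of a probability matrix: a hard label is a one-hot row such as [0 0 1] (class 3 of 3), and a soft label is a row of prior confidences such as [0.1 0.3 0.6]. A minimal sketch of building one-hot rows from integer class indices (the variable names here are illustrative, not part of the GiniSVM package):

% Sketch: turn integer class indices (1..M) into the one-hot label
% matrix format the script below expects. Illustrative names only.
classidx = [1; 3; 2; 3];    % example class indices for 4 samples
M = 3;                      % number of classes
Y = zeros(numel(classidx), M);
Y(sub2ind(size(Y), (1:numel(classidx))', classidx)) = 1;
% Y is now [1 0 0; 0 0 1; 0 1 0; 0 0 1]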
%----------------------------------------
% First read in the input data:
%   trainx -> input data
%   Ytrain -> input labels
%   crossx -> cross-validation data
%   Ycross -> cross-validation labels
% The training and test labels should be binary inputs
% or prior probabilities. An example for a three-class
% label is [0 0 1], to indicate that the training label
% belongs to class 3. Or it could be [0.1 0.3 0.6] to
% indicate prior confidence.
% KTYPE and KSCALE are the kernel parameters. For example,
% KTYPE = 6 is for Gaussian and KSCALE is the sigma
% parameter of the Gaussian.
% Cgini is the C parameter in SVM.
% B is the gamma parameter (see JMLR paper).
% Niter is the number of iterations of the randomized
% SMO algorithm.
%----------------------------------------
[N,D] = size(trainx);
[Ny,M] = size(Ytrain);
if Ny ~= N
   error('Training Data size neq labels');
end
[Ncross,D] = size(crossx);
[Nycross,M] = size(Ycross);
if Nycross ~= Ncross
   error('Cross-validation Data size neq labels');
end
global KTYPE;
global KSCALE;
%--------------------------------
% Parameters to tune
%--------------------------------
B = 0.01;            % gamma parameter
Niter = 100000;      % for better convergence
Cgini = 1;           % to prevent overfitting
%Cgini = 2.471/B;    % to prevent overfitting
KTYPE = 1;           % kernel type (see header: 6 = Gaussian)
KSCALE = 10000;      % kernel scale parameter
%--------------------------------
plotflag = 1;

% Recover the integer class index of each (hard or soft) label row.
yindex = zeros(1,N);
for i = 1:N
   yindex(i) = find(Ytrain(i,:) > 0.5);
end
ycrossindex = zeros(1,Ncross);
for i = 1:Ncross
   ycrossindex(i) = find(Ycross(i,:) > 0.5);
end
%---------------------------------------
% Train the GiniSVM
%---------------------------------------
fprintf('Starting SVM Training....');
[testcoeff,testbias] = ginitrain(trainx,Ytrain,Cgini*ones(N,1),Niter,B*ones(N,1));
fprintf('....Done\n');
%---------------------------------------
% Compute the sparsity index: the fraction of training
% vectors whose coefficients are numerically zero,
% i.e. the non-support vectors.
%---------------------------------------
spind = find(sum(abs(testcoeff),2) < 1e-5);
nsv = length(spind);
for k = 1:nsv
   testcoeff(spind(k),:) = zeros(1,M);
end
fprintf('Sparsity Index = %g percent\n',nsv/N*100);
%---------------------------------------
% Performance on training set
%---------------------------------------
fprintf('Evaluating Performance on Training set\n');
errordist = zeros(1,M);
nerr = 0;            % error count (avoids shadowing built-in error())
eflag = zeros(N,1);
for k = 1:N
   mvalue = kernel(trainx(k,:),trainx)*testcoeff + testbias;
   [result, resultmargin] = ginitest(mvalue,B);
   [maxval,ind] = max(result);
   if ind ~= yindex(k)
      nerr = nerr + 1;
      eflag(k) = 1;
      errordist(yindex(k)) = errordist(yindex(k)) + 1;
   end
end
fprintf('Multi-class Train Error = %g percent\n',(nerr/N)*100);
for i = 1:M
   fprintf('Class %d Error = %g percent\n',i,(errordist(i)/N)*100);
end
clear result resultmargin;
%---------------------------------------
% Performance on cross-validation set
%---------------------------------------
fprintf('Evaluating Performance on Cross-validation set\n');
errordist = zeros(1,M);
nerr = 0;
eflagcross = zeros(Ncross,1);
for k = 1:Ncross
   mvalue = kernel(crossx(k,:),trainx)*testcoeff + testbias;
   [result, resultmargin] = ginitest(mvalue,B);
   [maxval,ind] = max(result);
   if ind ~= ycrossindex(k)
      nerr = nerr + 1;
      eflagcross(k) = 1;
      errordist(ycrossindex(k)) = errordist(ycrossindex(k)) + 1;
   end
end
fprintf('Multi-class Cross-validation Error = %g percent\n',(nerr/Ncross)*100);
for i = 1:M
   fprintf('Class %d Error = %g percent\n',i,(errordist(i)/Ncross)*100);
end
clear result resultmargin;
%---------------------------------------
% Plot the probability contour
%---------------------------------------
if plotflag == 1
   fprintf('Plotting Contour ....');
   figure;
   giniplot(trainx,Ytrain,testcoeff,testbias',B);
   fprintf('....done\n');
end
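A .asv file is MATLAB's autosave copy of a .m script, so this listing corresponds to a train_test.m script. It is written to run against variables already in the workspace, and ginitrain, ginitest, kernel, and giniplot come from the GiniSVM package itself. A usage sketch on a synthetic two-class problem (the data and sizes below are purely illustrative):

% Sketch: set up the workspace variables the script expects, then run
% it as train_test (assuming the autosave has been restored to
% train_test.m and the GiniSVM functions are on the MATLAB path).
N = 100; Ncross = 40;
trainx = [randn(N/2,2) - 1; randn(N/2,2) + 1];            % 2-D inputs
Ytrain = [repmat([1 0],N/2,1); repmat([0 1],N/2,1)];      % one-hot labels
crossx = [randn(Ncross/2,2) - 1; randn(Ncross/2,2) + 1];
Ycross = [repmat([1 0],Ncross/2,1); repmat([0 1],Ncross/2,1)];
train_test;   % runs training, evaluation, and the contour plot

For a Gaussian kernel, the header comments suggest KTYPE = 6 with KSCALE as the sigma parameter; the values hard-coded in the script (KTYPE = 1, KSCALE = 10000) would need to be edited in place.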
