⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 gd_dossl.m

📁 一个Matlab写的关于图理论以及其在机器学习中应用的教学用GUI软件
💻 M
字号:
function GD_DoSSL2(dataset, s, K_Diff, iteration, damping, timestep, GraphTypeDiff, GraphTypeClass, K_Class, gammaVal, runs)
% GD_DoSSL2  Semi-supervised classification experiment on a k-nearest-neighbor graph.
%
%   GD_DoSSL2(dataset, s, K_Diff, iteration, damping, timestep, ...
%             GraphTypeDiff, GraphTypeClass, K_Class, gammaVal, runs)
%
%   For every number of labeled points in num_classes*[1,2,5,10,20] the
%   function draws RUNS random labeled subsets (each containing at least one
%   point of every class), classifies all points with
%   getSSLClassificationFromKNN, and displays the mean error, the mean
%   fraction of unclassified points and the standard deviation of the error.
%   One scatter plot is opened per setting, and the statistics are saved to a
%   MAT-file whose name encodes the parameters.
%
%   Inputs
%     dataset        data set identifier. Values outside 10..12 are forwarded
%                    to Manifold_Denoising3, which returns the data and labels.
%     s, K_Diff, iteration, GraphTypeDiff
%                    forwarded to Manifold_Denoising3 and used in file names.
%     damping, timestep
%                    only used to build file names in this function
%                    (Manifold_Denoising3 is called with the constants 0 and 1
%                    in its fifth and sixth argument positions).
%     GraphTypeClass, gammaVal
%                    forwarded to getSSLClassificationFromKNN.
%     K_Class        number of neighbors requested from getKNN.
%     runs           number of random label draws per setting.
%
%   Saved variables (MAT-file 'SSL_data=..._Gamma=...'):
%     num_labeled_pos, error, not_classified, mean_error,
%     mean_not_classified, std_error
%
%   Errors
%     GD_DoSSL2:noData        no data matrix is available for DATASET.
%     GD_DoSSL2:emptyClass    a class index in 1..num_classes has no points.
%     GD_DoSSL2:tooFewPoints  more labeled points requested than points exist.

% Legacy seeding syntax kept on purpose: it fixes the sequence produced by
% randperm below so that the label draws are reproducible.
rand('state',0)

lambda  = 0.0;
MuRegul = 0.2;

if (dataset<10 || dataset>12)
  [X,y] = Manifold_Denoising3(dataset,s,K_Diff,iteration,0,1,GraphTypeDiff);
  num = size(X,2);
  % Force a column vector so that the comparison final~=y further down is
  % element-wise regardless of the orientation returned by the helper.
  y = y(:);
  % Largest label = number of classes. Equals the last label (the original
  % rule) whenever the labels are sorted in ascending order.
  num_classes = max(y);
else
  dataFile = '';
  if (dataset==11)
    num_classes=10;
    y = [ ones(6903,1); 2*ones(7877,1); 3*ones(6990,1);  4*ones(7141,1);  5*ones(6824,1);   6*ones(6313,1); 7*ones(6876,1);  8*ones(7293,1);  9*ones(6825,1);  10*ones(6958,1);];
    num=size(y,1);
    dataFile=['RESULT_data=',num2str(dataset),'_s=',num2str(s),'_KNN=',num2str(K_Diff),'_iter=',num2str(iteration),'_damp=',num2str(damping),'_timestep=',num2str(timestep),'_GraphType=',num2str(GraphTypeDiff),'.BIN'];
    % The loader for this file is disabled, so X is never read here:
    %dim=28*28;
    %fid=fopen(dataFile);
    %[X, c]=fread(fid,inf,'float32');
    %fclose(fid);
    %num=size(X,1)/(dim)
    %X=reshape(X,dim,num);
  end
end

% Fail early with a clear message instead of an "undefined variable" error
% halfway through the experiment (data sets 10..12 never define X).
if ~exist('X','var')
  if isempty(dataFile)
    error('GD_DoSSL2:noData', 'No data is available for dataset %g.', dataset);
  else
    error('GD_DoSSL2:noData', 'No data is available for dataset %g: the loader for file "%s" is disabled.', dataset, dataFile);
  end
end

% Every class needs at least one point, otherwise the label sampling below
% could never satisfy "one labeled point per class" and would loop forever.
for c=1:num_classes
  if ~any(y==c)
    error('GD_DoSSL2:emptyClass', 'Class %g has no points; cannot draw a labeled point from every class.', c);
  end
end

% transforming the multiclass labels into binary (-1,1) one-versus-all
% label set
Labels = zeros(num,num_classes);
for i=1:num_classes
  Labels(:,i) = 2*(y==i) - 1;
end

% computing the k nearest neighbors
tic
[KNN, KNNDist] = getKNN(X,K_Class);
t=toc;
disp(['Time for computation of k nearest neighbors: ', num2str(t)]);

% possible values for the number of labeled points
num_labeled_pos = num_classes*[1,2,5,10,20];
%num_labeled_pos = [3, 6, 12, 24, 48, 96, 192, 384];
%num_labeled_pos = [10, 20, 50, 100, 500, 1000, 5000, 20000, 50000];

if max(num_labeled_pos) > num
  error('GD_DoSSL2:tooFewPoints', 'Requested %g labeled points but the data set has only %g points.', max(num_labeled_pos), num);
end

% result arrays: one row (or entry) per setting, one column per run
num_settings        = size(num_labeled_pos,2);
errRate             = zeros(num_settings,runs);   % saved as 'error'
not_classified      = zeros(num_settings,runs);
mean_error          = zeros(1,num_settings);
mean_not_classified = zeros(1,num_settings);
std_error           = zeros(1,num_settings);

for tt=1:num_settings
  num_labeled = num_labeled_pos(tt);

  disp(['']);
  disp([num2str(runs),' runs of dataset ',num2str(dataset)]);
  disp([num2str(iteration),' Iterations with graph type ', num2str(GraphTypeDiff),' and K =',num2str(K_Diff)]);
  disp(['Classification with graph type ',num2str(GraphTypeClass),' and K =',num2str(K_Class)]);
  disp(['Number of points: ', num2str(num),', Number of labeled points: ', num2str(num_labeled)]);

  % select randomly labels from the labeled set
  tic
  % Allocate the sparse matrix directly; building a dense num-by-
  % (num_classes*runs) array first can take hundreds of megabytes.
  trainLabels = sparse(num,num_classes*runs);
  for i=1:runs
    cols = (i-1)*num_classes+1:i*num_classes;
    % Redraw until the labeled subset contains at least one point of each
    % class (a +1 entry in every one-versus-all column).
    accepted = false;
    while ~accepted
      indices = randperm(num);
      picked  = indices(1:num_labeled);
      accepted = all(any(Labels(picked,:)==1,1));
    end
    trainLabels(picked,cols) = Labels(picked,:);
  end
  t=toc;
  disp(['Time for generation of labels: ', num2str(t)]);

  % get classifier output (simultaneously for all runs) for all the labels
  tic
  [output,d]=getSSLClassificationFromKNN(KNNDist, KNN, trainLabels, 0, 1.0, lambda, gammaVal, GraphTypeClass,MuRegul);
  mean(d)   % diagnostic echo kept from the original (intentionally no semicolon)
  t=toc;
  disp(['Time for Classification: ', num2str(t)]);

  % compute final classification, final==0 corresponds to "not classified"
  % -> vertex is disconnected from labeled components
  for i=1:runs
    scores = full(output(:,(i-1)*num_classes+1:i*num_classes));
    % Predicted class = column with the largest score. When several classes
    % share the largest score, max returns the lowest class index.
    [best, final] = max(scores,[],2);
    % A largest score of exactly zero means the point is not classified.
    final(best==0) = 0;
    errRate(tt,i) = mean(final~=y);
    not_classified(tt,i) = mean(final==0);
  end

  % Scatter plot of the first two data coordinates, colored by the sign of
  % the first score column (first class of the first run).
  map=[ [0,0,1];[0,1,0]; [0,0,0]; [1,0,0]; [1,1,0] ; [0,1,1]; [1,0,1] ];
  figure
  hold on
  scatter(X(1,:),X(2,:),5,sign(output(:,1)),'filled');
  colormap(map);
  grid on
  view(2)
  hold off

  mean_error(tt) = mean(errRate(tt,:));
  disp(['Mean error: ', num2str(mean_error(tt))]);
  mean_not_classified(tt) = mean(not_classified(tt,:));
  disp(['Mean of not classified points: ', num2str(mean_not_classified(tt))]);
  std_error(tt)  = std(errRate(tt,:));
  disp(['Standard deviation of the error: ', num2str(std_error(tt))]);
end

filename = ['SSL_data=',num2str(dataset),'_s=',num2str(s),'_KNN=',num2str(K_Diff),'_iter=',num2str(iteration),'_damp=',num2str(damping),'_timestep=',num2str(timestep),'_GraphTypeDiff=',num2str(GraphTypeDiff),'_GraphTypeClass=',num2str(GraphTypeClass),'_KCLass=',num2str(K_Class),'_Gamma=',num2str(gammaVal)];

% Save through a struct so the MAT-file keeps the variable name 'error'
% without a local variable shadowing MATLAB's error function.
results = struct();
results.num_labeled_pos     = num_labeled_pos;
results.error               = errRate;
results.not_classified      = not_classified;
results.mean_error          = mean_error;
results.mean_not_classified = mean_not_classified;
results.std_error           = std_error;
save(filename,'-struct','results')

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -