⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 demossl.m

📁 一个Matlab写的关于图理论以及其在机器学习中应用的教学用GUI软件
💻 M
📖 第 1 页 / 共 3 页
字号:
end


% --- Executes on button press in PshNewLabels.
function PshNewLabels_Callback(hObject, eventdata, handles)
% Callback of the "new labels" push button.
% Draws a fresh random subset of ALLDATA_SSL.NumLabels point indices to act
% as the labeled set and then refreshes every panel of the GUI.
%   hObject   - handle of the PshNewLabels button (not used)
%   eventdata - reserved by MATLAB (not used)
%   handles   - GUIDATA structure with the GUI handles
global ALLDATA_SSL;

shuffled = randperm(ALLDATA_SSL.Num);                      % random order of all point indices
ALLDATA_SSL.Labeled = shuffled(1:ALLDATA_SSL.NumLabels);   % first NumLabels become the labeled set

UpdateALL(handles);

% --- Executes on button press in PshUpdateALL.
function PshUpdateALL_Callback(hObject, eventdata, handles)
% Callback of the "update" push button.
% Delegates to UpdateALL, which rebuilds the graph from the current settings,
% recomputes its statistics and the SSL result, and redraws all panels.
%   hObject   - handle of the PshUpdateALL button (not used)
%   eventdata - reserved by MATLAB (not used)
%   handles   - GUIDATA structure with the GUI handles
UpdateALL(handles);


%--------------------------------------------------------------------------
function DrawData(Fig)
% DrawData  Scatter-plot the current 2-D dataset ALLDATA_SSL.x into axes Fig.
%   Every point is drawn as a dot coloured by its true class ALLDATA_SSL.y.
%   The labeled points (indices ALLDATA_SSL.Labeled) are additionally drawn
%   as larger filled circles with a black edge, coloured by their class.
%   The axes are NOT cleared before plotting, so the points are drawn on top
%   of whatever Fig already shows (DrawWeights relies on this to overlay the
%   data on the graph edges). For Density == 6 nothing is plotted and the
%   axes are only cleared.
global ALLDATA_SSL;

% Density 6 is the dataset whose kNN lists are loaded from GD_USPSKNN in
% BuildWeights; it is not drawn.
if(ALLDATA_SSL.Density==6)
  cla(Fig);
  return;
end

num_classes = max(ALLDATA_SSL.y);
colors      = {'red','blue','black'};
nColors     = numel(colors);

hold(Fig,'on');
% all points, coloured by true class
for i=1:num_classes
  % cycle the palette so that more than 3 classes no longer cause an index error
  col = colors{mod(i-1,nColors)+1};
  plot(Fig,ALLDATA_SSL.x(1,ALLDATA_SSL.y==i),ALLDATA_SSL.x(2,ALLDATA_SSL.y==i),'MarkerEdgeColor',col,'Marker','.','LineStyle','none')
end

% labeled points: LabelVector holds the class of each labeled point, 0 elsewhere
LabelVector = zeros(ALLDATA_SSL.Num,1);
LabelVector(ALLDATA_SSL.Labeled) = ALLDATA_SSL.y(ALLDATA_SSL.Labeled);
for i=1:num_classes
  col = colors{mod(i-1,nColors)+1};
  plot(Fig,ALLDATA_SSL.x(1,LabelVector==i),ALLDATA_SSL.x(2,LabelVector==i),'MarkerFaceColor',col,'MarkerEdgeColor','black','Marker','o','LineStyle','none','MarkerSize',8)
end
hold(Fig,'off');
axis(Fig,'equal');

%--------------------------------------------------------------------------
function ShowOutput(Fig,output)
% ShowOutput  Plot the result of the SSL classifier into axes Fig.
%   Fig is cleared first. Each point is then drawn as a dot coloured by its
%   predicted class SSLDATA.Output. Points whose output is 0 (no label
%   assigned) are drawn in magenta. The labeled training points are
%   overlaid as filled circles with a black edge, coloured by their true
%   class. For Density == 6 the axes are only cleared.
%   The second argument OUTPUT is not used (the prediction is read from the
%   global SSLDATA); it is kept so existing call sites stay valid.
global ALLDATA_SSL;
global SSLDATA;

cla(Fig);
if(ALLDATA_SSL.Density==6)
  return;   % this dataset is not plotted
end

num_classes = max(ALLDATA_SSL.y);
colors      = {'red','blue','black'};
nColors     = numel(colors);
pred        = SSLDATA.Output;

hold(Fig,'on');
% classifier output, one colour per predicted class
for i=1:num_classes
  % cycle the palette so that more than 3 classes no longer cause an index error
  col = colors{mod(i-1,nColors)+1};
  plot(Fig,ALLDATA_SSL.x(1,pred==i),ALLDATA_SSL.x(2,pred==i),'MarkerEdgeColor',col,'Marker','.','LineStyle','none')
end
% points the classifier left without a label (output 0)
plot(Fig,ALLDATA_SSL.x(1,pred==0),ALLDATA_SSL.x(2,pred==0),'MarkerEdgeColor','magenta','Marker','.','LineStyle','none')

% labeled training points: LabelVector holds their true class, 0 elsewhere
LabelVector = zeros(ALLDATA_SSL.Num,1);
LabelVector(ALLDATA_SSL.Labeled) = ALLDATA_SSL.y(ALLDATA_SSL.Labeled);
for i=1:num_classes
  col = colors{mod(i-1,nColors)+1};
  plot(Fig,ALLDATA_SSL.x(1,LabelVector==i),ALLDATA_SSL.x(2,LabelVector==i),'MarkerFaceColor',col,'MarkerEdgeColor','black','Marker','o','LineStyle','none','MarkerSize',8)
end
hold(Fig,'off');
axis(Fig,'equal');




%--------------------------------------------------------------------------
function BuildWeights()
% BuildWeights  Build the weighted graph ALLDATA_SSL.K on the points ALLDATA_SSL.x.
%   Edge weights are Gaussian: w = exp(-d^2/(2*Gamma^2)), d = Euclidean distance.
%   GraphType 0    : mutual kNN graph; an edge is kept only if each endpoint is
%                    among the other's NumKNN nearest neighbours, K = min(Kd,Kd')
%   GraphType 1    : symmetric kNN graph; an edge is kept if either endpoint is
%                    among the other's neighbours, K = max(Kd,Kd')
%   other GraphType < 2 : directed kNN matrix Kd is left unsymmetrized
%   GraphType >= 2 : epsilon graph; edge if 0 < d < Eps (dense matrix)
%   Kd(j,i) holds the weight from point i to its neighbour j (sparse).
%   For Density >= 6 the kNN lists are read from GD_USPSKNN.mat (variables
%   KNN and KNNDist, neighbours stored column-wise per point) instead of being
%   computed from x, so the n-by-n distance matrix is not formed there.
global ALLDATA_SSL;

n     = ALLDATA_SSL.Num;
scale = 1/(2*ALLDATA_SSL.Gamma^2);

if(ALLDATA_SSL.GraphType<2)
  if(ALLDATA_SSL.Density<6)
    % squared distances; after sorting each row, column 1 is the point itself
    dist2   = DistEuclideanPiotrDollar(ALLDATA_SSL.x',ALLDATA_SSL.x');
    [SD,IX] = sort(dist2,2);
    KNN     = IX(:,2:ALLDATA_SSL.NumKNN+1)';
    KNNDist = SD(:,2:ALLDATA_SSL.NumKNN+1)';
  else
    % precomputed neighbour lists (loaded into a struct, not the workspace)
    S       = load('GD_USPSKNN','KNN','KNNDist');
    KNN     = S.KNN(1:ALLDATA_SSL.NumKNN,:);
    KNNDist = S.KNNDist(1:ALLDATA_SSL.NumKNN,:);
  end
  % Directed kNN weight matrix, built in one sparse() call instead of
  % growing the sparse matrix column by column.
  % NOTE: sparse() sums duplicate (row,col) pairs; each point's neighbour
  % list is assumed to contain distinct indices (true for the sort above).
  k    = size(KNN,1);
  rows = double(KNN(:,1:n));
  cols = repmat(1:n,k,1);
  vals = exp(-scale*double(KNNDist(:,1:n)));
  ALLDATA_SSL.K = sparse(rows(:),cols(:),vals(:),n,n);
  % symmetrize: max(a,b) == ((a+b)+|a-b|)/2 and min(a,b) == ((a+b)-|a-b|)/2,
  % without the cancellation error of the explicit formulas
  if(ALLDATA_SSL.GraphType==1) ALLDATA_SSL.K = max(ALLDATA_SSL.K,ALLDATA_SSL.K'); end
  if(ALLDATA_SSL.GraphType==0) ALLDATA_SSL.K = min(ALLDATA_SSL.K,ALLDATA_SSL.K'); end
else
  % epsilon graph; dist2 ~= 0 removes self-loops
  dist2 = DistEuclideanPiotrDollar(ALLDATA_SSL.x',ALLDATA_SSL.x'); % squared distances
  ALLDATA_SSL.K = exp(-scale*dist2).*(dist2 < ALLDATA_SSL.Eps^2 & dist2~=0);
end

% if(ALLDATA_SSL.GraphType<2)
%   if(ALLDATA_SSL.Density<6)
%     [KNN,KNNDist]=getKNN(ALLDATA_SSL.x,ALLDATA_SSL.NumKNN);
%   else
%     load USPSKNN;
%     KNN    =KNN(1:ALLDATA_SSL.NumKNN,:);
%     KNNDist=KNNDist(1:ALLDATA_SSL.NumKNN,:);
%   end
%   % get kNN weight matrix
%   ALLDATA_SSL.K = getSparseWeightMatrixFromKNN(KNNDist,KNN,1/(2*ALLDATA_SSL.Gamma^2),ALLDATA_SSL.GraphType,1); 
%   % note that K is not symmetric yet (in the case of the symmetric KNN), now we symmetrize K 
%   if(ALLDATA_SSL.GraphType==1) ALLDATA_SSL.K=(ALLDATA_SSL.K+ALLDATA_SSL.K')+abs(ALLDATA_SSL.K-ALLDATA_SSL.K'); ALLDATA_SSL.K=0.5*ALLDATA_SSL.K; end
% else
%   dist=getDistanceMatrix(ALLDATA_SSL.x);
%   ALLDATA_SSL.K = exp(-1/(2*ALLDATA_SSL.Gamma^2)*dist.^2).*(dist < ALLDATA_SSL.Eps & dist~=0);
% end

%--------------------------------------------------------------------------
function DrawWeights(Fig)
% DrawWeights  Draw the edges of the graph ALLDATA_SSL.K as green lines in
%   axes Fig, then overlay the data points with DrawData.
%   Each undirected edge is taken once, from the strictly lower triangle of K
%   (entries K(j,i) > 0 with j > i). For Density == 6 the axes are only cleared.
global ALLDATA_SSL

cla(Fig);

if(ALLDATA_SSL.Density~=6)
  % edge endpoints: row index r > column index c
  [r,c] = find(tril(ALLDATA_SSL.K,-1) > 0);
  m = numel(r);
  % All edges go into ONE line object: each edge contributes its two
  % endpoints followed by NaN, which breaks the polyline between edges.
  % (Previously: one plot call and one line object per vertex, plus a
  % debug timing print on every redraw.)
  xx = [ALLDATA_SSL.x(1,c); ALLDATA_SSL.x(1,r); NaN(1,m)];
  yy = [ALLDATA_SSL.x(2,c); ALLDATA_SSL.x(2,r); NaN(1,m)];
  hold(Fig,'on');
  if(m>0)
    plot(Fig,xx(:),yy(:),'-g');
  end
  hold(Fig,'off');
  axis(Fig,'equal');
  DrawData(Fig);
end

%--------------------------------------------------------------------------
function numComps = CompComps()
% CompComps  Number of connected components of the graph ALLDATA_SSL.K.
%   GD_GetComps assigns a component index to every vertex; the largest index
%   is returned (assumes indices run 1..C; TODO confirm against GD_GetComps).
global ALLDATA_SSL
numComps = max(GD_GetComps(ALLDATA_SSL.K));

% disp(['Number of connected components: ', num2str(max(compvec))]);
% numPointsInComps = zeros(numComps,1);
%     for i=1:numComps
%       numPointsInComps(i)=nnz(compvec==i);
%     end      
% numPointsInComps=sort(numPointsInComps,'descend');
% disp(['Number of points in the ',num2str(min(numComps,10)),' largest components']);
% for i=1:min(numComps,10)
%   disp(['Cluster ',num2str(i),' with ',num2str(numPointsInComps(i)),' points']);
% end

%--------------------------------------------------------------------------
function [EdInBet,EdBet,WBet,WInBet] = getEdgeStatistics()
% getEdgeStatistics  Edge counts and weight sums of ALLDATA_SSL.K, split by
% whether an edge joins two points of the same true class ALLDATA_SSL.y.
%   EdInBet : number of edges within classes (both endpoints in the same class)
%   EdBet   : number of edges between different classes
%   WBet    : total weight of the edges between different classes
%   WInBet  : total weight of the edges within classes
% Assumes K is symmetric, so every undirected edge appears twice in it;
% hence all four totals are halved. Classes are 1..max(y); edges touching a
% point with any other label end up in the "between" totals.
global ALLDATA_SSL

y = ALLDATA_SSL.y(:);        % column view, so row-vector labels also work
K = ALLDATA_SSL.K;
A = double(K>0);             % edge indicator, computed once (loop invariant)
num_classes = max(y);

WInBet = 0; EdInBet = 0;
for i=1:num_classes
  mask    = (y==i);
  WInBet  = WInBet  + mask'*K*mask;   % sum of K over pairs inside class i
  EdInBet = EdInBet + mask'*A*mask;   % number of nonzeros of K inside class i
end

WBet    = 0.5*(sum(sum(K)) - WInBet);
EdBet   = 0.5*(nnz(K) - EdInBet);
WInBet  = 0.5*WInBet;
EdInBet = 0.5*EdInBet;

%--------------------------------------------------------------------------
function UpdateSSL(handles)
% UpdateSSL  Run semi-supervised learning on the current graph and labels,
% plot the result into axes3 and refresh the error read-outs and controls.
%   Training labels are the true classes of the points in ALLDATA_SSL.Labeled
%   and 0 for every other point. GD_PerformSSL is called with laplacian = 2,
%   lambda = 0 and regularization SSLDATA.Regul (their meaning is defined in
%   GD_PerformSSL). GD_EvalSolution fills SSLDATA.TestError, TrainError,
%   NotLabeled and Output, which are displayed as percentages.
global ALLDATA_SSL
global SSLDATA

% training label vector: true class at labeled points, 0 elsewhere
idx = ALLDATA_SSL.Labeled;
trainLabels = zeros(ALLDATA_SSL.Num,1);
trainLabels(idx) = ALLDATA_SSL.y(idx);

laplacian = 2;
lambda    = 0;
[output,d] = GD_PerformSSL(ALLDATA_SSL.K, trainLabels, max(ALLDATA_SSL.y), laplacian, lambda, SSLDATA.Regul);
[SSLDATA.TestError, SSLDATA.TrainError, SSLDATA.NotLabeled, SSLDATA.Output] = ...
    GD_EvalSolution(ALLDATA_SSL.y, output, ALLDATA_SSL.Labeled);
ShowOutput(handles.axes3);

% read-outs: value*100 with one decimal, followed by a percent sign
pct = @(v) [num2str(100*v,'%2.1f'),'%'];
set(handles.TxtTestError, 'String', pct(SSLDATA.TestError));
set(handles.TxtTrainError,'String', pct(SSLDATA.TrainError));
set(handles.TxtNotLabeled,'String', pct(SSLDATA.NotLabeled));

% sliders show each parameter normalized to [0,1] between its min and max
%set(handles.SldNumNeighbors,'Value',(ALLDATA_SSL.NumKNN-ALLDATA_SSL.MinKNN)/(ALLDATA_SSL.MaxKNN-ALLDATA_SSL.MinKNN));
labelRange = ALLDATA_SSL.MaxLabels-ALLDATA_SSL.MinLabels;
dimRange   = ALLDATA_SSL.MaxDim-ALLDATA_SSL.MinDim;
set(handles.SldNumLabels,'Value',(ALLDATA_SSL.NumLabels-ALLDATA_SSL.MinLabels)/labelRange);
set(handles.SldDim,      'Value',(ALLDATA_SSL.Dim-ALLDATA_SSL.MinDim)/dimRange);

% static text next to the sliders
% if(ALLDATA_SSL.GraphType<2)
%  set(handles.TxtNumNeighbors,'String',num2str(ALLDATA_SSL.NumKNN));
% else
%  set(handles.TxtNumNeighbors,'String',num2str(ALLDATA_SSL.Eps,'%2.2f'));
% end
set(handles.TxtNumLabels,'String',num2str(ALLDATA_SSL.NumLabels));
set(handles.TxtDim,      'String',num2str(ALLDATA_SSL.Dim));


%--------------------------------------------------------------------------
function UpdateALL(handles)
% UpdateALL  Recompute and redraw the whole GUI.
%   Clears axes2/axes3/axes4, draws the dataset into axes4, rebuilds the
%   graph and draws it into axes2, then shows:
%   - the number of connected components,
%   - the within/between-class edge counts and weights,
%   - the total number of points,
%   - the regularization parameter.
%   Finally it runs the SSL step via UpdateSSL.
global ALLDATA_SSL
global SSLDATA

% start from empty axes
for ax = [handles.axes2, handles.axes3, handles.axes4]
  cla(ax);
end

% dataset
DrawData(handles.axes4);

% graph construction and display
BuildWeights();
DrawWeights(handles.axes2);
drawnow;

% connected components
set(handles.TxtNumComps,'String',num2str(CompComps()));

% within-/between-class graph statistics
[nEdgesIn,nEdgesBet,wBet,wIn] = getEdgeStatistics();
set(handles.TxtWBet,      'String',num2str(wBet,'%2.2f'));
set(handles.TxtWInBet,    'String',num2str(wIn,'%2.2f'));
set(handles.TxtEdgeBet,   'String',num2str(nEdgesBet,'%2.0d'));
set(handles.TxtEdgeInBet, 'String',num2str(nEdgesIn,'%2.0d'));

% point count, with a warning when only kNN graphs are available
pointsText = ['Total number of points: ',num2str(ALLDATA_SSL.Num)];
if(~(ALLDATA_SSL.Density<6))
  pointsText = [pointsText,', ONLY KNN POSSIBLE !'];
end
set(handles.TxtTotalNumberPoints,'String',pointsText);

% regularization parameter
set(handles.TxtRegul,'String',num2str(SSLDATA.Regul,'%2.2e'));

UpdateSSL(handles);









⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -