% File: demossl.m
end
% --- Executes on button press in PshNewLabels.
function PshNewLabels_Callback(hObject, eventdata, handles)
% PSHNEWLABELS_CALLBACK  Button callback of PshNewLabels.
%   Picks a new random subset of points to act as the labeled set and then
%   refreshes the whole GUI through UpdateALL.
%
%   hObject    handle to PshNewLabels (see GCBO)
%   eventdata  reserved - to be defined in a future version of MATLAB
%   handles    structure with handles and user data (see GUIDATA)
global ALLDATA_SSL;
% Shuffle the indices 1..Num and keep the first NumLabels of them.
shuffledIdx = randperm(ALLDATA_SSL.Num);
ALLDATA_SSL.Labeled = shuffledIdx(1:ALLDATA_SSL.NumLabels);
UpdateALL(handles);
% --- Executes on button press in PshUpdateALL.
function PshUpdateALL_Callback(hObject, eventdata, handles)
% PSHUPDATEALL_CALLBACK  Button callback of PshUpdateALL.
%   Forwards to UpdateALL, passing the GUI handles structure through.
%   hObject and eventdata are not used here.
% hObject handle to PshUpdateALL (see GCBO)
% eventdata reserved - to be defined in a future version of MATLAB
% handles structure with handles and user data (see GUIDATA)
UpdateALL(handles);
%--------------------------------------------------------------------------
function DrawData(Fig)
% DRAWDATA  Plot the data set stored in ALLDATA_SSL into the axes Fig.
%   Rows 1 and 2 of ALLDATA_SSL.x are used as plot coordinates. Every point
%   is drawn as a dot colored by its class in ALLDATA_SSL.y; the points
%   listed in ALLDATA_SSL.Labeled are overdrawn with larger filled circles.
%   When ALLDATA_SSL.Density equals 6 nothing is plotted and the axes are
%   only cleared.
%
%   Fig  axes handle to draw into
global ALLDATA_SSL;
if(ALLDATA_SSL.Density==6)
    cla(Fig);
end
if(ALLDATA_SSL.Density~=6)
    % only three colors are defined; the third doubles as the edge color
    % of the labeled markers
    classColors = {'red','blue','black'};
    numClasses = max(ALLDATA_SSL.y);
    hold(Fig,'on');

    % all points as dots in the color of their class
    for c=1:numClasses
        inClass = (ALLDATA_SSL.y==c);
        plot(Fig,ALLDATA_SSL.x(1,inClass),ALLDATA_SSL.x(2,inClass), ...
            'MarkerEdgeColor',classColors{c},'Marker','.','LineStyle','none')
    end

    % labeled points: class id at the labeled positions, 0 everywhere else
    knownLabels = zeros(ALLDATA_SSL.Num,1);
    knownLabels(ALLDATA_SSL.Labeled) = ALLDATA_SSL.y(ALLDATA_SSL.Labeled);
    for c=1:numClasses
        isLabeledInClass = (knownLabels==c);
        plot(Fig,ALLDATA_SSL.x(1,isLabeledInClass),ALLDATA_SSL.x(2,isLabeledInClass), ...
            'MarkerFaceColor',classColors{c},'MarkerEdgeColor',classColors{3}, ...
            'Marker','o','LineStyle','none','MarkerSize',8)
    end

    hold(Fig,'off');
    axis(Fig,'equal');
end
%--------------------------------------------------------------------------
function ShowOutput(Fig,output)
% SHOWOUTPUT  Plot the classifier result stored in SSLDATA.Output.
%   Clears the axes Fig and, unless ALLDATA_SSL.Density equals 6, draws
%     - every point as a dot colored by its predicted class,
%     - points whose prediction is 0 (no label assigned) in magenta,
%     - the labeled points as larger filled circles in their true class color.
%
%   Fig     axes handle to draw into
%   output  not used; the prediction is read from the global SSLDATA.Output
global ALLDATA_SSL;
global SSLDATA;
cla(Fig);
if(ALLDATA_SSL.Density~=6)
    classColors = {'red','blue','black'};
    numClasses = max(ALLDATA_SSL.y);
    hold(Fig,'on');

    % predicted class of every point
    for c=1:numClasses
        predictedAsC = (SSLDATA.Output==c);
        plot(Fig,ALLDATA_SSL.x(1,predictedAsC),ALLDATA_SSL.x(2,predictedAsC), ...
            'MarkerEdgeColor',classColors{c},'Marker','.','LineStyle','none')
    end

    % points that received no label
    unassigned = (SSLDATA.Output==0);
    plot(Fig,ALLDATA_SSL.x(1,unassigned),ALLDATA_SSL.x(2,unassigned), ...
        'MarkerEdgeColor','magenta','Marker','.','LineStyle','none')

    % labeled points: class id at the labeled positions, 0 everywhere else
    knownLabels = zeros(ALLDATA_SSL.Num,1);
    knownLabels(ALLDATA_SSL.Labeled) = ALLDATA_SSL.y(ALLDATA_SSL.Labeled);
    for c=1:numClasses
        isLabeledInClass = (knownLabels==c);
        plot(Fig,ALLDATA_SSL.x(1,isLabeledInClass),ALLDATA_SSL.x(2,isLabeledInClass), ...
            'MarkerFaceColor',classColors{c},'MarkerEdgeColor',classColors{3}, ...
            'Marker','o','LineStyle','none','MarkerSize',8)
    end

    hold(Fig,'off');
    axis(Fig,'equal');
end
%--------------------------------------------------------------------------
function BuildWeights()
% BUILDWEIGHTS  Build the graph weight matrix and store it in ALLDATA_SSL.K.
%   GraphType < 2 : k-nearest-neighbor graph. Each point is connected to its
%                   NumKNN nearest neighbors with the Gaussian weight
%                   exp(-d2/(2*Gamma^2)). For Density < 6 the neighbors are
%                   computed from the data; otherwise they are read from the
%                   MAT-file GD_USPSKNN (variables KNN and KNNDist).
%                   GraphType==0 then keeps the smaller and GraphType==1 the
%                   larger of K(i,j) and K(j,i), which makes K symmetric.
%   otherwise     : epsilon-graph. All pairs with 0 < d2 < Eps^2 are connected
%                   with the same Gaussian weight.
%
%   The pairwise distance matrix is only computed in the branches that use it,
%   so the precomputed-neighbor case no longer pays for it.
global ALLDATA_SSL;
if(ALLDATA_SSL.GraphType<2)
    if(ALLDATA_SSL.Density<6)
        dist2 = DistEuclideanPiotrDollar(ALLDATA_SSL.x',ALLDATA_SSL.x'); % squared distances
        [SD,IX]=sort(dist2,2);
        % the first sorted column is skipped; it is expected to be the point itself
        KNN = IX(:,2:ALLDATA_SSL.NumKNN+1)';
        KNNDist = SD(:,2:ALLDATA_SSL.NumKNN+1)';
    else
        % load into a struct instead of injecting variables into the workspace
        % NOTE(review): KNNDist in the file is assumed to hold squared distances,
        % like the computed branch - confirm
        stored = load('GD_USPSKNN');
        KNN = stored.KNN(1:ALLDATA_SSL.NumKNN,:);
        KNNDist = stored.KNNDist(1:ALLDATA_SSL.NumKNN,:);
    end
    % get kNN weight matrix; reserve room for all entries up front so the
    % column-wise fill does not repeatedly reallocate the sparse storage
    K = spalloc(ALLDATA_SSL.Num,ALLDATA_SSL.Num,numel(KNN));
    for i=1:ALLDATA_SSL.Num
        K(KNN(:,i),i)=exp(-1/(2*ALLDATA_SSL.Gamma^2)*KNNDist(:,i));
    end
    % note that K is not symmetric yet, now we symmetrize K
    % 0.5*((a+b)+|a-b|) is max(a,b); 0.5*((a+b)-|a-b|) is min(a,b)
    if(ALLDATA_SSL.GraphType==1)
        K=(K+K')+abs(K-K');
        K=0.5*K;
    end
    if(ALLDATA_SSL.GraphType==0)
        K=(K+K')-abs(K-K');
        K=0.5*K;
    end
    ALLDATA_SSL.K = K;
else
    dist2 = DistEuclideanPiotrDollar(ALLDATA_SSL.x',ALLDATA_SSL.x'); % squared distances
    % dist2~=0 removes self-connections (and exact duplicates)
    ALLDATA_SSL.K = exp(-1/(2*ALLDATA_SSL.Gamma^2)*dist2).*(dist2 < ALLDATA_SSL.Eps^2 & dist2~=0);
end
%--------------------------------------------------------------------------
function DrawWeights(Fig)
% DRAWWEIGHTS  Draw the edges of the graph ALLDATA_SSL.K into the axes Fig.
%   Clears Fig and, unless ALLDATA_SSL.Density equals 6, draws a green line
%   for every pair (i,j), j > i, with K(j,i) > 0, then overlays the data
%   points via DrawData.
%
%   All edges are drawn with a single plot call: each edge contributes its
%   two end points followed by a NaN, which breaks the line between edges.
%   This replaces one line object per vertex and removes the leftover
%   timing output that was echoed to the command window on every redraw.
%
%   Fig  axes handle to draw into
global ALLDATA_SSL
cla(Fig);
if(ALLDATA_SSL.Density~=6)
    hold(Fig,'on');
    % strictly lower triangle: every undirected edge is visited once
    [rowIdx,colIdx] = find(tril(ALLDATA_SSL.K,-1)>0);
    if(~isempty(rowIdx))
        gap = nan(1,numel(rowIdx));
        edgeX = [ALLDATA_SSL.x(1,colIdx); ALLDATA_SSL.x(1,rowIdx); gap];
        edgeY = [ALLDATA_SSL.x(2,colIdx); ALLDATA_SSL.x(2,rowIdx); gap];
        plot(Fig,edgeX(:),edgeY(:),'-g');
    end
    hold(Fig,'off');
    axis(Fig,'equal');
    DrawData(Fig);
end
%--------------------------------------------------------------------------
function numComps = CompComps()
% COMPCOMPS  Number of connected components of the graph ALLDATA_SSL.K.
%   Returns the maximum of the vector produced by GD_GetComps.
%   NOTE(review): this presumes GD_GetComps numbers the components
%   1..numComps, one entry per point - confirm against GD_GetComps.
%
%   numComps  largest value returned by GD_GetComps
global ALLDATA_SSL
componentIds = GD_GetComps(ALLDATA_SSL.K);
numComps = max(componentIds);
%--------------------------------------------------------------------------
function [EdInBet,EdBet,WBet,WInBet] = getEdgeStatistics()
% GETEDGESTATISTICS  Edge counts and weight sums within and across classes.
%   Uses the weight matrix ALLDATA_SSL.K and the class vector ALLDATA_SSL.y.
%   Every sum over K is halved, i.e. an entry pair K(i,j), K(j,i) counts as
%   one edge.
%
%   EdInBet : number of edges whose end points share a class
%   EdBet   : number of edges whose end points are in different classes
%   WBet    : total weight of the edges between different classes
%   WInBet  : total weight of the edges inside the classes
global ALLDATA_SSL
% 0/1 pattern of K, needed for the edge counts (same for every class)
adjacency = double(ALLDATA_SSL.K>0);
sameClassWeight = 0;
sameClassEdges = 0;
for c=1:max(ALLDATA_SSL.y)
    member = (ALLDATA_SSL.y==c);
    % member'*M*member sums M over all pairs of points inside class c
    sameClassWeight = sameClassWeight + member'*ALLDATA_SSL.K*member;
    sameClassEdges = sameClassEdges + member'*adjacency*member;
end
% whatever is not inside a class runs between classes
WBet = 0.5*(sum(sum(ALLDATA_SSL.K)) - sameClassWeight);
EdBet = 0.5*(nnz(ALLDATA_SSL.K) - sameClassEdges);
WInBet = 0.5*sameClassWeight;
EdInBet = 0.5*sameClassEdges;
%--------------------------------------------------------------------------
function UpdateSSL(handles)
% UPDATESSL  Run the semi-supervised learner and refresh the result widgets.
%   1. Builds a label vector (true class at the labeled points, 0 elsewhere)
%      and passes it with the graph ALLDATA_SSL.K to GD_PerformSSL.
%   2. Evaluates the result with GD_EvalSolution and stores the test error,
%      train error, unlabeled fraction and final output in SSLDATA.
%   3. Plots the output in handles.axes3 and updates the error texts, the
%      label/dimension sliders and their text fields.
%
%   handles  structure with handles and user data (see GUIDATA)
global ALLDATA_SSL
global SSLDATA
% fixed solver settings
% NOTE(review): the meaning of laplacian=2 and lambda=0 is defined by GD_PerformSSL
laplacian = 2;
lambda = 0;

knownLabels = zeros(ALLDATA_SSL.Num,1);
knownLabels(ALLDATA_SSL.Labeled) = ALLDATA_SSL.y(ALLDATA_SSL.Labeled);
numClasses = max(ALLDATA_SSL.y);
[output,d] = GD_PerformSSL(ALLDATA_SSL.K, knownLabels, numClasses, laplacian, lambda, SSLDATA.Regul);
[SSLDATA.TestError, SSLDATA.TrainError,SSLDATA.NotLabeled,SSLDATA.Output] = ...
    GD_EvalSolution(ALLDATA_SSL.y,output,ALLDATA_SSL.Labeled);
ShowOutput(handles.axes3);

% error rates are shown as percentages with one decimal
asPercent = @(fraction) [num2str(100*fraction,'%2.1f'),'%'];
set(handles.TxtTestError,'String',asPercent(SSLDATA.TestError));
set(handles.TxtTrainError,'String',asPercent(SSLDATA.TrainError));
set(handles.TxtNotLabeled,'String',asPercent(SSLDATA.NotLabeled));

% slider positions: current value mapped linearly onto [0,1] between min and max
labelFraction = (ALLDATA_SSL.NumLabels-ALLDATA_SSL.MinLabels)/(ALLDATA_SSL.MaxLabels-ALLDATA_SSL.MinLabels);
dimFraction = (ALLDATA_SSL.Dim-ALLDATA_SSL.MinDim)/(ALLDATA_SSL.MaxDim-ALLDATA_SSL.MinDim);
set(handles.SldNumLabels,'Value',labelFraction);
set(handles.SldDim,'Value',dimFraction);

% static text elements next to the sliders
set(handles.TxtNumLabels,'String',num2str(ALLDATA_SSL.NumLabels));
set(handles.TxtDim,'String',num2str(ALLDATA_SSL.Dim));
%--------------------------------------------------------------------------
function UpdateALL(handles)
% UPDATEALL  Rebuild the graph, redraw every plot and refresh all statistics.
%   Order of work: clear the three axes, draw the data (axes4), build and
%   draw the graph (axes2), update the component count and the edge
%   statistics, update the info texts, and finally rerun the learner via
%   UpdateSSL.
%
%   handles  structure with handles and user data (see GUIDATA)
global ALLDATA_SSL
global SSLDATA

% start from empty plots
cla(handles.axes2);
cla(handles.axes3);
cla(handles.axes4);

% data set
DrawData(handles.axes4);

% graph; drawnow flushes the plots before the statistics are computed
BuildWeights();
DrawWeights(handles.axes2);
drawnow;

% connected components
set(handles.TxtNumComps,'String',num2str(CompComps()));

% edge statistics
[EdInBet,EdBet,WBet,WInBet] = getEdgeStatistics();
set(handles.TxtWBet,'String',num2str(WBet,'%2.2f'));
set(handles.TxtWInBet,'String',num2str(WInBet,'%2.2f'));
set(handles.TxtEdgeBet,'String',num2str(EdBet,'%2.0d'));
set(handles.TxtEdgeInBet,'String',num2str(EdInBet,'%2.0d'));

% point count, with a hint appended when Density is not below 6
if(ALLDATA_SSL.Density<6)
    knnHint = '';
else
    knnHint = ', ONLY KNN POSSIBLE !';
end
set(handles.TxtTotalNumberPoints,'String',['Total number of points: ',num2str(ALLDATA_SSL.Num),knnHint]);

% regularization parameter
set(handles.TxtRegul,'String',num2str(SSLDATA.Regul,'%2.2e'));

UpdateSSL(handles);
% (Trailing text from the web code viewer - a keyboard-shortcut help panel - was not part of demossl.m.)