📄 hmemenu.m
字号:
else
error('Invalid data format: not a .mat or a .txt file')
end
% Validate the training matrix against the problem type:
%   type==1 (regression)     -> cov_dim input columns + res_dim target columns
%   type==2 (classification) -> cov_dim input columns + 1 class-label column,
%                               every label an integer in 1..res_dim
if (type==1)&(size(train_data,2)~=cov_dim+res_dim),
    error(['Invalid data matrix size: ', num2str(size(train_data,2)), ' columns rather than ',...
        num2str(cov_dim+res_dim),'!']);
elseif (type==2)&(size(train_data,2)~=cov_dim+1),
    error(['Invalid data matrix size: ', num2str(size(train_data,2)), ' columns rather than ',...
        num2str(cov_dim+1),'!']);
elseif (type==2)&any(~ismember(train_data(:,end), 1:res_dim)),
    error('Invalid class label');
end
% Split into inputs (train_d) and targets (train_t).
ntrain=size(train_data,1);
train_d=train_data(:,1:cov_dim);
if type~=2,
    % Regression: the targets are simply the trailing columns.
    train_t=train_data(:,cov_dim+1:end);
else
    % Classification: expand the label column into an ntrain x res_dim
    % one-hot matrix (column m is 1 where the label equals m).
    train_t=zeros(ntrain, res_dim);
    for m=1:res_dim,
        train_t(train_data(:,end)==m, m)=1;
    end
end
disp(' ')
% ------------------------------------------------------------------------------------------------
% Loading test data ------------------------------------------------------------------------------
% ------------------------------------------------------------------------------------------------
% The test set is optional: pressing return leaves test_path empty, and the
% later plotting, prediction and saving steps check test_path before using
% any test data. A test file has either the training layout (inputs plus
% targets/label) or only the cov_dim input columns; test_t is built only
% when the targets are present.
disp('(If you don''t want to specify a test-set press ''return'' only)');
test_path=input('Insert the complete (with extension) path of the test data file:\n >> ','s');
if ~isempty(test_path),
    % Choose the reader from the file's real extension (case-insensitive)
    % instead of searching for '.mat'/'.txt' anywhere in the path, which
    % misfires on paths such as 'runs.mat_dir/test.txt'.
    [test_dir, test_name, test_ext]=fileparts(test_path);
    if strcmpi(test_ext,'.mat'),
        % The matrix may be stored under any variable name: take the first
        % one. getfield avoids assembling code for eval.
        ap=load(test_path); app=fieldnames(ap);
        if isempty(app),
            error('Invalid data format: the .mat file contains no variables');
        end
        test_data=getfield(ap, app{1,1});
        clear ap app;
    elseif strcmpi(test_ext,'.txt'),
        test_data=load(test_path, '-ascii');
    else
        error('Invalid data format: not a .mat or a .txt file')
    end
    clear test_dir test_name test_ext;
    % Column-count check. Targets are optional in the test set, so both the
    % full layout and the inputs-only layout are accepted.
    if (type==1)&(size(test_data,2)~=cov_dim)&(size(test_data,2)~=cov_dim+res_dim),
        error(['Invalid data matrix size: ', num2str(size(test_data,2)), ' columns rather than ',...
            num2str(cov_dim+res_dim), ' or ', num2str(cov_dim), '!']);
    elseif (type==2)&(size(test_data,2)~=cov_dim)&(size(test_data,2)~=cov_dim+1),
        error(['Invalid data matrix size: ', num2str(size(test_data,2)), ' columns rather than ',...
            num2str(cov_dim+1), ' or ', num2str(cov_dim), '!']);
    elseif (type==2)&(size(test_data,2)==cov_dim+1)&any(~ismember(test_data(:,end), 1:res_dim)),
        % When labels are present, each must be an integer in 1..res_dim.
        error('Invalid class label');
    end
    ntest=size(test_data,1);
    test_d=test_data(:,1:cov_dim);
    if (type==2)&(size(test_data,2)>cov_dim),
        % One-hot encode the class-label column (column m is 1 where label==m).
        test_t=zeros(ntest, res_dim);
        for m=1:res_dim,
            test_t(test_data(:,end)==m, m)=1;
        end
    elseif (type==1)&(size(test_data,2)>cov_dim),
        test_t=test_data(:,cov_dim+1:end);
    end
    disp(' ');
end
else
% Built-in toy problems, used instead of user data files.
clc
disp('----------------------------------------------------');
disp(' Specify the Input ');
disp('----------------------------------------------------');
disp(' ')
ntrain = input('Insert the number of examples in training (<500): ');
if (isempty(ntrain)|(floor(ntrain)~=ntrain)|(ntrain<=0)|(ntrain>500)),
    error(['Invalid value: ', num2str(ntrain), ' is not a positive integer <500!']);
end
disp(' ')
% A non-empty test_path tells the later stages that a test set exists.
test_path='toy';
ntest = input('Insert the number of examples in test (<500): ');
if (isempty(ntest)|(floor(ntest)~=ntest)|(ntest<=0)|(ntest>500)),
    error(['Invalid value: ', num2str(ntest), ' is not a positive integer <500!']);
end
if type==2,
    % Classification toy: 2 input dimensions, 3 classes, data from gen_data.
    % The training set uses a fixed seed. The test set is generated without
    % a seed argument.
    cov_dim=2;
    res_dim=3;
    seed = 42;
    [train_d, ntrain1, ntrain2, train_t]=gen_data(ntrain, seed);
    % Preallocate: this is a script, so a train_data left in the workspace by
    % an earlier run would otherwise keep its extra (stale) rows.
    train_data=zeros(ntrain, size(train_d,2)+1);
    for m=1:ntrain
        q = find(train_t(m,:)==1);   % class index of this one-hot target row
        train_data(m,:)=[train_d(m,:) q];
    end
    [test_d, ntest1, ntest2, test_t]=gen_data(ntest);
    test_data=zeros(ntest, size(test_d,2)+1);
    for m=1:ntest
        q = find(test_t(m,:)==1);
        test_data(m,:)=[test_d(m,:) q];
    end
else
    % Regression toy: 1 input, 1 output, rows read from mixexp_data.txt under
    % the directory held in the global HOME (which must be set elsewhere).
    cov_dim=1;
    res_dim=1;
    global HOME
    %%%%%WARNING!%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    load([HOME '/examples/static/Misc/mixexp_data.txt'], '-ascii');
    %%%%%WARNING!%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % The first ntrain rows are used for training and the next ntest rows
    % for testing, so the file must provide at least ntrain+ntest rows.
    if size(mixexp_data,1)<ntrain+ntest,
        error(['Not enough toy data: ', num2str(size(mixexp_data,1)), ' rows available, ',...
            num2str(ntrain+ntest), ' requested!']);
    end
    train_data = mixexp_data(1:ntrain, :);
    train_d=train_data(:,1:cov_dim); train_t=train_data(:,cov_dim+1:end);
    test_data = mixexp_data(ntrain+1:ntrain+ntest, :);
    test_d=test_data(:,1:cov_dim);
    if size(test_data,2)>cov_dim,
        test_t=test_data(:,cov_dim+1:end);
    end
end
end
% Set the nodes dimension-----------------------------------
% Row 2 of nodes_info holds the node sizes:
%   - columns 2..num_glevel+1 (presumably the gating levels) get branch_fact;
%   - the first column gets the input dimension;
%   - the last column gets the output dimension.
if num_glevel>0,
    nodes_info(2,2:num_glevel+1)=branch_fact;
end
nodes_info(2,1)=cov_dim;
nodes_info(2,end)=res_dim;
%-----------------------------------------------------------
% Prepare the training data for the learning engine---------
%-----------------------------------------------------------
% Evidence layout: one column per training example. Row 1 holds the input
% vector and the last row the target(s) (the class label when type==2).
% Every other row is left empty ([]).
cases = cell(size(nodes_info,2), ntrain);
for m=1:ntrain,
    cases([1 end],m) = {train_data(m,1:cov_dim)'; train_data(m,cov_dim+1:end)'};
end
%-----------------------------------------------------------------------------------------------------
% Build the HME network from nodes_info and a junction-tree engine for it.
% Then add up the log-likelihood that enter_evidence returns for each
% training case under the initial parameters ("before learning").
[bnet onodes]=hme_topobuilder(nodes_info);
engine = jtree_inf_engine(bnet, onodes);
clc
disp('---------------------------------------------------------------------');
disp(' L E A R N I N G ');
disp('---------------------------------------------------------------------');
disp(' ')
ll = 0;
for l=1:ntrain
    % 'scritta' (Italian for "caption") is the per-example progress line.
    scritta=sprintf('example number: %d---------------------------------------------', l);
    disp(scritta);
    ev = cases(:,l);
    [engine, loglik] = enter_evidence(engine, ev);
    ll = ll + loglik;
end
disp(' ')
disp(['Log-likelihood before learning: ', num2str(ll)]);
disp(' ')
disp('(Press any key to continue)');
pause
%-----------------------------------------------------------
% Run EM for at most max_em_iter iterations. LL2 receives the log-likelihood
% trace (its length is reported as the number of iterations performed).
clc
disp('---------------------------------------------------------------------');
disp(' L E A R N I N G ');
disp('---------------------------------------------------------------------');
disp(' ')
max_em_iter=input('Insert the maximum number of the EM algorithm iterations: ');
% Accept only a single numeric integer greater than 1. The || chain stops at
% the first condition that holds, so the comparisons further right only ever
% see a numeric scalar. This rejects vectors and chars (e.g. 'a' == 97).
if isempty(max_em_iter) || ~isnumeric(max_em_iter) || (length(max_em_iter)~=1) ...
        || (floor(max_em_iter)~=max_em_iter) || (max_em_iter<=1),
    error(['Invalid value: ', num2str(max_em_iter), ' is not a positive integer >1!']);
end
disp(' ')
disp(['Log-likelihood before learning: ', num2str(ll)]);
disp(' ')
[bnet2, LL2] = learn_params_em(engine, cases, max_em_iter);
disp(' ')
fprintf('HME: loglik before learning %f, after %d iters %f\n', ll, length(LL2), LL2(end));
disp(' ')
disp('(Press any key to continue)');
pause
%-----------------------------------------------------------------------------------
% Classification problem: plot data & decision boundaries if the input data size = 2
% Regression problem: plot data & prediction if the input data size = 1
%-----------------------------------------------------------------------------------
% nodes_info(2,1) is the input dimension. The test set is passed to the
% plotting helper only when one was supplied (test_path non-empty).
if (type==2)&(nodes_info(2,1)==2),
    if isempty(test_path),
        fh1=hme_class_plot(bnet2, nodes_info, train_data);
    else
        fh1=hme_class_plot(bnet2, nodes_info, train_data, test_data);
    end
    disp(' ');
    disp('(See the figure)');
elseif (type==1)&(nodes_info(2,1)==1),
    if isempty(test_path),
        fh1=hme_reg_plot(bnet2, nodes_info, train_data);
    else
        fh1=hme_reg_plot(bnet2, nodes_info, train_data, test_data);
    end
    disp(' ');
    disp('(See the figure)');
end
%-----------------------------------------------------------------------------------
% Classification problem: plot confusion matrix
%-----------------------------------------------------------------------------------
if (type==2)
    % The training-set confusion matrix is always drawn.
    ztrain=fhme(bnet2, nodes_info, train_d, size(train_d,1));
    [Htrain, trainRate]=confmat(ztrain, train_t); % CM on the training set
    fh2=figure('Name','Confusion matrix', 'MenuBar', 'none', 'NumberTitle', 'off');
    % A test set with targets gets its own matrix in a second panel. The
    % training matrix then occupies the left half of the figure.
    if (~isempty(test_path))&(size(test_data,2)>cov_dim),
        ztest=fhme(bnet2, nodes_info, test_d, size(test_d,1));
        [Htest, testRate]=confmat(ztest, test_t); % CM on the test set
        subplot(1,2,1);
    end
    plotmat(Htrain,'b','k',12)
    % Tick positions 0.5, 1.5, ..., nodes_info(2,end)-0.5 (one per class).
    tick=(1:nodes_info(2,end))-0.5;
    set(gca,'XTick',tick,'YTick',tick)
    grid('off')
    ylabel('True')
    xlabel('Prediction')
    title(sprintf('Confusion Matrix: training set (%s%%)', num2str(trainRate(1))))
    if (~isempty(test_path))&(size(test_data,2)>cov_dim),
        subplot(1,2,2)
        plotmat(Htest,'b','k',12)
        set(gca,'XTick',tick,'YTick',tick)
        grid('off')
        ylabel('True')
        xlabel('Prediction')
        title(sprintf('Confusion Matrix: test set (%s%%)', num2str(testRate(1))))
    end
    disp(' ')
    disp('(Press any key to continue)');
    pause
end
%-----------------------------------------------------------------------------------
% Regression & Classification problem: calculate the predictions & plot the LL trace
%-----------------------------------------------------------------------------------
% fhme outputs for the training set and, if a test set was given, for the
% test set. These are the variables written out by the saving section.
train_result=fhme(bnet2,nodes_info,train_d,size(train_d,1));
if ~isempty(test_path),
    test_result=fhme(bnet2,nodes_info,test_d,size(test_d,1));
end
% The trailing semicolon keeps the figure handle from being echoed.
fh3=figure('Name','Log-likelihood trace', 'MenuBar', 'none', 'NumberTitle', 'off');
% One marker per EM iteration recorded in LL2.
plot(LL2,'-ro',...
    'MarkerEdgeColor','k',...
    'MarkerFaceColor',[1 1 0],...
    'MarkerSize',4)
title('Log-likelihood trace')
%-----------------------------------------------------------------------------------
% Regression & Classification problem: save the predictions
%-----------------------------------------------------------------------------------
clc
disp('------------------------------------------------------------------');
disp(' Save the results ');
disp('------------------------------------------------------------------');
disp(' ')
%-----------------------------------------------------------------------------------
% Optionally save the learned network bnet2 to a user-chosen .mat file.
save_quest_m=input('Do you want to save the HME model (Y/N)? [Y default]: ', 's');
if isempty(save_quest_m),
    save_quest_m='Y';
end
% Accept y/n in either case and reject anything else. A findstr-based test
% cannot do this: for an invalid answer findstr returns [], and an empty
% if-condition is false, so the error would never be raised.
save_quest_m=upper(save_quest_m);
if ~(strcmp(save_quest_m,'Y') || strcmp(save_quest_m,'N')), error('Invalid input'); end
if strcmp(save_quest_m,'Y'),
    disp(' ');
    m_save=input('Insert the complete path for save the HME model (.mat):\n >> ', 's');
    if isempty(m_save), error('You must specify a path!'); end
    save(m_save, 'bnet2');
end
%-----------------------------------------------------------------------------------
% Optionally save the predictions: train_result always, and test_result only
% when a test set was supplied (test_path non-empty).
disp(' ')
save_quest=input('Do you want to save the HME predictions (Y/N)? [Y default]: ', 's');
disp(' ')
if isempty(save_quest),
    save_quest='Y';
end
% Accept y/n in either case and reject anything else. With findstr an
% invalid answer yields [], which an if-condition treats as false, so the
% error would never be raised.
save_quest=upper(save_quest);
if ~(strcmp(save_quest,'Y') || strcmp(save_quest,'N')), error('Invalid input'); end
if strcmp(save_quest,'Y'),
    tr_save=input('Insert the complete path for save the training data prediction (.mat):\n >> ', 's');
    if isempty(tr_save), error('You must specify a path!'); end
    save(tr_save, 'train_result');
    if ~isempty(test_path),
        disp(' ')
        te_save=input('Insert the complete path for save the test data prediction (.mat):\n >> ', 's');
        if isempty(te_save), error('You must specify a path!'); end
        save(te_save, 'test_result');
    end
end
clc
% Farewell banner: shown for two seconds, then the command window is
% cleared. fprintf reuses '%s\n' for each string, printing one per line.
fprintf('%s\n', ...
    '----------------------------------------------------', ...
    ' B Y E ! ', ...
    '----------------------------------------------------');
pause(2)
%clear
clc
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -