⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 s6_3_q4_4class.m

📁 Duda《模式分类》第二版第1、3、5章部分课后习题和上机题的解答和程序代码
💻 M
字号:
clear all;close all;clc
% 输入训练参数  
s1 = sqrt(0.01)*randn(2,5) + [0.1*ones(1,5); 0.2*ones(1,5)]; 
s2 = sqrt(0.01)*randn(2,5) + [-0.1*ones(1,5); -0.6*ones(1,5)]; 
s3 = sqrt(0.01)*randn(2,5) + [-0.6*ones(1,5); 0.6*ones(1,5)]; 
s4 = sqrt(0.01)*randn(2,5) + [-0.5*ones(1,5); -0.1*ones(1,5)]; 
train_patterns = [s1,s2,s3,s4];
% 第一类s1的期望输出模式是[-1,-1],s2:[-1,1],s3:[1,-1],s4:[1,1]
train_targets = [-ones(1,10),ones(1,10);-ones(1,5),ones(1,5),-ones(1,5),ones(1,5)];

params = [4 2 1e-3 0.3 20000]; 
%	params - Number of hidden units and output units,
%            Convergence criterion, Convergence	rate, Maximum iterations

% 将纵横范围[-1,1]的正方形区域作为测试区,生成测试点集
% 按照‘从左到右,从下到上’的次序
% 依次将测试区域每个点的坐标输入到坐标序列 test_patterns 中
x1=-1:0.05:1;
x2=x1;
lx=length(x1);
i=1:lx*lx;
ir=mod(i,lx);  ir(find(ir==0))=lx;
ic=ceil(i/lx);
test_patterns = [x1(ir);x2(ic)];

% 利用随机反向传播算法计算网络权值(Wh,Wo)和分类结果(test_targets)
[test_targets, train_verify, Wh, Wo, J] = Backpropagation_Stochastic_MultiOutput(train_patterns, train_targets, test_patterns, params);

% 显示网络各层的连接权值
disp('           ')
disp('input_to_hidden_unit_weights = ')
disp('           ')
for j=1:params(1)
disp('w1j          w2j           w0j')
disp(num2str(Wh(j,:)))
disp('           ')
end
disp('hidden_to_output_unit_weights = ')
disp('           ')
disp(num2str(Wo(1,:)))
disp('           ')

% 训练性能检验
% 首先画出期望输出train-targets
% 蓝线对应输出单元1,绿线对应输出单元2
figure;
subplot(211);
plot(train_targets(1,:),'b');hold on;
plot(train_targets(2,:),'g');hold off;
legend(['Expected N1';'Expected N2'],'Location','BestOutside');
xlabel('Train Patterns'); ylabel('Output Value');
title('Train Targets');
axis([0 20 -2 2]);
% 然后画出实际输出train_verify
subplot(212);
plot(train_verify(1,:),'b--','LineWidth',2);hold on;
plot(train_verify(2,:),'g--','LineWidth',2);
% 将神经元的输出值转换为‘-1’或‘1’,与期望输出曲线比较
train_verify = (train_verify>0)*2-1;
plot(train_verify(1,:),'b');
plot(train_verify(2,:),'g');hold off;
legend(['Ture N1    ';'Ture N2    ';'Rescaled N1';'Rescaled N2'],'Location','BestOutside');
xlabel('Train Patterns'); ylabel('Output Value');
title('Train Verify');
axis([0 20 -2 2]);

% 显示分类效果
% 将神经元的输出值转换为‘-1’或‘1’
test_targets = (test_targets>0)*2-1;
figure;
% 找出每一类所包含的点
c1=find(test_targets(1,:)==-1 & test_targets(2,:)==-1);
c2=find(test_targets(1,:)==-1 & test_targets(2,:)==1);
c3=find(test_targets(1,:)==1 & test_targets(2,:)==-1);
c4=find(test_targets(1,:)==1 & test_targets(2,:)==1);
% 不同类别的点用不同颜色显示
plot(test_patterns(1,c1),test_patterns(2,c1),'r.');hold on;
plot(test_patterns(1,c2),test_patterns(2,c2),'b.');
plot(test_patterns(1,c3),test_patterns(2,c3),'m.');
plot(test_patterns(1,c4),test_patterns(2,c4),'g.');
% 画出训练样本点集
plot(s1(1,:),s1(2,:),'o','MarkerFaceColor','r','MarkerSize',6);
plot(s2(1,:),s2(2,:),'o','MarkerFaceColor','b','MarkerSize',6);
plot(s3(1,:),s3(2,:),'o','MarkerFaceColor','m','MarkerSize',6);
plot(s4(1,:),s4(2,:),'o','MarkerFaceColor','g','MarkerSize',6);hold off;
xlabel('x1');ylabel('x2');title('Test Result')
axis([-1 1 -1 1]);axis square;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -