📄 perceptronrule2.m
字号:
function PerceptronRule
%==========================================================
% 单层感知机分类程序
% 感知机的传输函数采用‘硬极限(hardlim)’函数,其输出‘非0即1’
% 对于 M 类 N 维样本集
% 感知机网络的输入数为 N 个,使用的神经元为 M 个
% 即一个感知机神经元对应一类样本的输出
%==========================================================
clear all;close all;clc;
%==========================================================
% 样本集
% 生成随机的样本矩阵
sigma = 0.005; % 样本点分布的标准方差
s1 = repmat([.3 .2]',1,5) + sqrt(sigma)*randn(2,5);
s2 = repmat([.4 -.4]',1,5) + sqrt(sigma)*randn(2,5);
s3 = repmat([-.1 .3]',1,5) + sqrt(sigma)*randn(2,5);
s4 = repmat([-.7 -.5]',1,5) + sqrt(sigma)*randn(2,5);
Nsi = [size(s1,2),size(s2,2),size(s3,2),size(s4,2)]; % 每类样本的个数
Patterns = [s1,s2,s3,s4];
[Ni,Np] = size(Patterns); % Ni - 输入元数 ; Np - 样本总数
%==========================================================
% 目标集
No = 2; % 神经元的个数
Targets = [zeros(1,Nsi(1)+Nsi(2)),ones(1,Nsi(3)+Nsi(4));...
zeros(1,Nsi(1)),ones(1,Nsi(2)),zeros(1,Nsi(3)),ones(1,Nsi(4))];
%==========================================================
% 初始化网络参数
Bias = 1; % 偏置值
Wio = rand(No,Ni+1); % 输入与神经元之间的连接权值
Wmat{1} = Wio; % 存放每次训练所得的连接权值 Wio
Ek = ones(No,1); % 单次训练误差
Ev = ones(No,Np); % 检验误差:将单次训练所得的连接权值 Wio 应用于感知机,对所有样本集进行检验
Err = []; % 记录训练误差变化情况
MaxIt = 2000; % 最大迭代数
iter = 1;
%==========================================================
% 训练感知机
% 结束条件:
% (1)当感知机对所有样本都能正确分类时,检验误差矩阵为全零矩阵,
% 即 isequal(Ev,zeros(No,Np)) 为真,此时可结束训练
% (2)当训练次数达到最大迭代数时,结束训练
while ~isequal(Ev,zeros(No,Np)) & iter<MaxIt
% 将样本集的每个样本依次输入到感知机中
k = mod(iter,Np); % 每次迭代输入一个样本
if k==0 % 当所有样本都已输入时,重新输入第1个样本,如此循环
k = Np;
end
Xk = Patterns(:,k); % 输入
Tk = Targets(:,k); % 目标
Ok = perceptron(Xk,Bias,Wio); % 输出
Ek = Tk - Ok; % 单次误差
Wio = Wio + Ek*[Xk;Bias]'; % 感知机学习规则( Hagan《神经网络设计》P47 )
Ev = Targets - perceptron(Patterns,Bias,Wio); % 整体误差
iter = iter+1;
Wmat{iter} = Wio;
Err = [Err,length(unique(ceil(find(Ev~=0)/No)))];
end
%==========================================================
% 显示训练所用次数
disp(['All patterns are classified correctly after ' num2str(iter) ' iterations.']);
%==========================================================
% 画出训练过程和训练结果
figure('Name','Training Process')
Nplot = 6;
Ip = fix(linspace (1,iter,Nplot));
for np =1:Nplot
subplot(2,3,np);
It = Ip(np);
Wio = Wmat{It};
plotresult(s1,s2,s3,s4,Bias,Wio,It);
end
figure('Name','Train Result')
subplot(121);
plot(Err); grid on;
xlabel('Iterations'); ylabel('Number of errors');
title('Training Errors');
axis([0 iter+20 0 21]);axis square;
subplot(122);
plotresult(s1,s2,s3,s4,Bias,Wio,iter);
%==========================================================
% 感知机神经元
%==========================================================
function y=perceptron(x,Bias,Wio)
[r,c] = size(x);
xi = [x;Bias*ones(1,c)]; % 加入偏置
net_in = Wio*xi; % 对加权输入求和
y = hardlim(net_in);
%==========================================================
% 绘图函数
% 功能:1、样本集的分布;2、判定边界;
% 3、分类区域; 4、X轴,Y轴。
%==========================================================
function plotresult(s1,s2,s3,s4,Bias,Wio,It)
% 画出各样本集的分布
plot(s1(1,:),s1(2,:),'o','MarkerFaceColor','r','MarkerSize',6); hold on;
plot(s2(1,:),s2(2,:),'o','MarkerFaceColor','g','MarkerSize',6);
plot(s3(1,:),s3(2,:),'o','MarkerFaceColor','b','MarkerSize',6);
plot(s4(1,:),s4(2,:),'o','MarkerFaceColor','m','MarkerSize',6);
% 画出判定边界 ( w'p + b = 0 )
% 即 w1*x1 + w2*x2 + w3*bias = 0
x = -1:0.1:1;
y1 = -(Wio(1,3).*Bias+Wio(1,1).*x)./Wio(1,2);
y2 = -(Wio(2,3).*Bias+Wio(2,1).*x)./Wio(2,2);
plot(x,y1,'r','LineWidth',2);
plot(x,y2,'g','LineWidth',2);
% 画出分类区域
% 按照‘从左到右,从下到上’的次序
% 依次将决策区域每个点的坐标输入到坐标序列 test_patterns 中
x1 = -1:0.05:1;
x2 = x1;
lx = length(x1);
i = 1:lx*lx;
ir = mod(i,lx); ir(find(ir==0)) = lx;
ic = ceil(i/lx);
test_patterns = [x1(ir);x2(ic)];
tested = perceptron(test_patterns,Bias,Wio);
% 找出属于同一类的点
idx1 = find( tested(1,:)==0 & tested(2,:)==0 );
idx2 = find( tested(1,:)==0 & tested(2,:)==1 );
idx3 = find( tested(1,:)==1 & tested(2,:)==0 );
idx4 = find( tested(1,:)==1 & tested(2,:)==1 );
% 画出可确认类别的区域
c1 = test_patterns(:,idx1);
c2 = test_patterns(:,idx2);
c3 = test_patterns(:,idx3);
c4 = test_patterns(:,idx4);
plot(c1(1,:),c1(2,:),'r.');
plot(c2(1,:),c2(2,:),'g.');
plot(c3(1,:),c3(2,:),'b.');
plot(c4(1,:),c4(2,:),'m.');
% 画出X轴,Y轴
plot(x1,zeros(1,lx),'k--','LineWidth',2);
plot(zeros(1,lx),x1,'k--','LineWidth',2);
xlabel('x1');ylabel('x2');
title(['After ' num2str(It) ' iterations.']);
axis([-1 1 -1 1]);axis square;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -