% perceptron.m
%
% From "利用感知机的线性分类功能" (Using the perceptron for linear classification)
% · MATLAB code · 78 lines
clc; clear;   % 'clear' empties the workspace; 'clear all' would also flush cached functions/MEX for no benefit
% Training data. Each column is one sample [x1; x2; 1]: rows 1-2 are the two
% features and row 3 is a constant 1, so the third weight acts as a bias term.
A=[0 0 1;
   1 2 2;
   1 1 1];   % Pattern 1, one sample per column
Num_A=size(A,2);   % number of Pattern 1 samples (derived from A so it stays correct if A is edited)
B=[2 3 3;
   0 0 1;
   1 1 1];   % Pattern 2, one sample per column
Num_B=size(B,2);   % number of Pattern 2 samples (derived from B)
X=[A,B];     % all training samples: columns 1..Num_A are Pattern 1, the rest Pattern 2
Num_sample=Num_A+Num_B;  % total number of training samples

% Initial weight vector: three independent uniform draws on (0,1).
% W is a history matrix: training appends one column per presented sample,
% so W(:,n) is always the current weight vector.
W=rand(3,1);
w_new=zeros(3,1);   % scratch column holding the next weight vector
ta=1;    % desired output for Pattern 1 samples
tb=-1;   % desired output for Pattern 2 samples
c=0.5;   % learning rate
j=1;     % index of the sample (column of X) currently being presented
n=1;     % iteration counter; W(:,n) holds the current weights
% Perceptron training. Samples are presented one at a time in column order.
% After each one the (possibly updated) weight vector is appended to W, so
% W(:,n) is always the current weights. One complete pass over all samples is
% an epoch. Training stops after the first epoch in which no sample needed a
% weight update, i.e. every sample is classified correctly by the final weights.
% (Checking only the last few steps of an epoch is not enough: earlier samples
% of the epoch may still be misclassified by the final weights.)
max_epochs=1000;   % safety cap: without it the loop never ends on non-separable data
epoch=0;
converged=false;
while ~converged && epoch<max_epochs
    epoch=epoch+1;
    num_updates=0;                  % number of weight updates made during this epoch
    for j=1:Num_sample
        x=X(:,j);
        y=sign(W(:,n)'*x);          % weighted sum through sign() activation; sign(0)=0 never matches a target
        if j<=Num_A
            t=ta;                   % columns 1..Num_A are Pattern 1
        else
            t=tb;                   % remaining columns are Pattern 2
        end
        if y==t
            w_new=W(:,n);           % correctly classified: keep the weights
        else
            w_new=W(:,n)+2*c*(t-y)*x;   % misclassified: move the weights toward the target
            num_updates=num_updates+1;
        end
        W=[W,w_new]; %#ok<AGROW> keep the full weight history, one column per step
        n=n+1;
    end
    converged=(num_updates==0);
end
if ~converged
    warning('Perceptron did not converge within %d epochs; the classes may not be linearly separable.',max_epochs);
end
% Plot the training samples and the learned decision boundary.
figure(1)
% Plot the samples straight from A and B (the features are rows 1 and 2), so
% the markers always match the training data.
plot(A(1,:),A(2,:),'o')
hold on
grid on
plot(B(1,:),B(2,:),'v')   % mark the Pattern 2 samples
% Final weights; the decision boundary is the line w1*x1 + w2*x2 + w3 = 0.
w1=W(1,n);
w2=W(2,n);
w3=W(3,n);
if w2~=0
    xx=-5:0.1:5;
    yy=-(w1*xx)/w2-w3/w2;     % solve the boundary equation for x2
else
    xx=[1 1]*(-w3/w1);        % w2==0: the boundary is the vertical line x1 = -w3/w1
    yy=[-5 5];
end
plot(xx,yy,'r');  % draw the decision boundary
xlim([-0.5 3.5]);
ylim([-0.5 3.5]);
title('Result of Pattern Classification'); % add figure annotations
xlabel('x1'),ylabel('x2');
legend('Pattern 1','Pattern 2');

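% (Residue from the source web page's code viewer - keyboard shortcuts:
%  copy code Ctrl+C; search code Ctrl+F; full screen F11;
%  increase font size Ctrl+=; decrease font size Ctrl+-; show shortcuts ?)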