
bpmxor.m

BP (back-propagation) algorithm with a momentum term for solving the XOR problem
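For reference, each iteration of the program can be summarized by the equations below. This is a sketch in conventional notation, not notation taken from the program: σ is the logistic function computed by logsig; h, τ, H, y, t, Δ and δ stand for hp, tau, HM, out, expected, DELTA and delta; W_hid and W_out stand for wex and WEX (weights including the bias column), W for the output weights without the bias column; x̃ and τ̃ are the input and hidden outputs augmented with a constant bias input; G_hid and G_out stand for dwex and dWEX. The momentum line assumes the usual form in which the momentum term carries the previous weight change D, as in the gradient-descent-with-momentum rule described for MATLAB's traingdm.

```latex
\begin{aligned}
h &= W_{\mathrm{hid}}\,\tilde{x}, & \tau &= \sigma(h), \\
H &= W_{\mathrm{out}}\,\tilde{\tau}, & y &= \sigma(H), \qquad e = t - y, \\
\Delta &= e \odot y \odot (1 - y), & \delta &= \bigl(W^{\top}\Delta\bigr) \odot \tau \odot (1 - \tau), \\
G_{\mathrm{out}} &= \Delta\,\tilde{\tau}^{\top}, & G_{\mathrm{hid}} &= \delta\,\tilde{x}^{\top}, \\
D^{(k)} &= (1 - m_c)\,\eta\,G^{(k)} + m_c\,D^{(k-1)}, & W^{(k+1)} &= W^{(k)} + D^{(k)}.
\end{aligned}
```

Here G = -∂E/∂W for E = ½ Σ e², so adding ηG is a gradient-descent step; in the first iteration D = ηG, and σ'(h) = σ(h)(1 - σ(h)) gives the factors y(1 - y) and τ(1 - τ).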
% Back-propagation algorithm for the XOR problem
% (batch mode with a momentum term)
function bpmxor()

clc

nSampNum = 4;
nSampDim = 2;
% number of hidden units; change as desired
nHidden = 3;   
nOut = 1;
%-----------------------------------------------
% generate the samples and expected outputs
SampIn = [];
SampOut = [];

for x = 0 : 1
    for y = 0 : 1
        samp = [x;y];
        SampIn = [SampIn,samp];
        SampOut = [SampOut,xor(x,y)];
    end
end
% augment each sample with a constant bias input of 1
% SampIn = SampIn + 1;
SampInEx = [SampIn; ones(1,nSampNum)];

%-----------------------------------------------
% initialize the weight matrices with uniform random values in (-1, 1)
w = 2*(rand(nHidden,nSampDim)-1/2);
b = 2*(rand(nHidden,1)-1/2);
wex = [w,b];

W = 2*(rand(nOut,nHidden)-1/2);
B = 2*(rand(nOut,1)-1/2);
WEX = [W,B];


eb = 0.01;                   % error bound (stop when SSE <= eb)
eta = 0.6;                   % learning rate
mc = 0.8;                    % momentum coefficient
maxiter = 10000;             % maximum number of iterations; adjust as needed
iteration = 0;

errRec = [];
outRec = [];
% seqRec = [];

for i = 1 : maxiter
    % batch mode: all samples are used in every iteration
    sampex = SampInEx;
    expected = SampOut;
    
    hp = wex*sampex;        % net input for the hidden layer nodes
    tau = logsig(hp);       % output of the hidden layer nodes
    tauex = [tau; ones(1,nSampNum)];       % hidden-layer output augmented with a constant bias input of 1
    
    HM = WEX*tauex;         % net input for the output layer nodes
    out = logsig(HM);       % output of the network
    outRec = [outRec,out'];
    
    err = expected - out;
    sse = sumsqr(err);      
    errRec = [errRec,sse];  % save the square errors
    fprintf('sse = %10.8f \n',sse )         % disp(['sum square error is:',num2str(sse)])
    
    iteration = iteration + 1;              % count before the stop test so the final iteration is included
    if sse<=eb, break,end
    
    % back propagation from output layer
    DELTA = err .* out .* (1 - out);            % logsig derivative: out = g(HM), g' = out.*(1-out)
    delta = (W' * DELTA) .* tau .* (1 - tau);   % logsig derivative: tau = g(hp)
    
    % weight-update directions (negative gradients of half the SSE)
    dWEX = DELTA*tauex';
    dwex = delta*sampex';
    
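    % adjust the weights; the momentum term reuses the previous weight change
    if i == 1
        dWEXStep = eta * dWEX;
        dwexStep = eta * dwex;
    else
        dWEXStep = (1 - mc)*eta*dWEX + mc*dWEXOld;
        dwexStep = (1 - mc)*eta*dwex + mc*dwexOld;
    end
    WEX = WEX + dWEXStep;
    wex = wex + dwexStep;
    % save the weight changes for the momentum term of the next iteration
    dWEXOld = dWEXStep;
    dwexOld = dwexStep;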
     
    % output-layer weights without the bias column, used to back-propagate delta
    W  = WEX(:,1:nHidden);
   
end      % end for iteration

% display the results
disp(['iteration = ',num2str(iteration)])

W = WEX(:,1:nHidden)
B = WEX(:,1+nHidden)
w = wex(:,1:nSampDim)
b = wex(:,1+nSampDim)

disp(out);  % real output
% plot the sum of squared errors against the iteration number
figure
hold on
grid on
plot(1:numel(errRec), errRec, 'b-', 'LineWidth', 1.5);
legend('SumSqr Errors');
xlabel('iteration','FontName','Times','FontSize',10);
ylabel('sum of squared errors','FontName','Times','FontSize',10);
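The listing relies on logsig and sumsqr from MATLAB's Deep Learning Toolbox (formerly the Neural Network Toolbox), and its back-propagation step uses the logsig derivative a.*(1-a), written in terms of the layer output a. Where the toolbox is not installed, the anonymous functions below are one possible base-MATLAB substitute, assuming only the forms used in this program are needed; the names logsigFn, dlogsigFn and sumsqrFn are hypothetical and not part of any toolbox. The script also checks the derivative formula against a central finite difference.

```matlab
% Hypothetical base-MATLAB stand-ins for the toolbox functions (assumption:
% only the call forms used in bpmxor.m are needed).
logsigFn  = @(n) 1 ./ (1 + exp(-n));   % logistic sigmoid, values in (0, 1)
dlogsigFn = @(n, a) a .* (1 - a);      % logsig derivative from its output a (n unused)
sumsqrFn  = @(e) sum(e(:) .^ 2);       % sum of squared elements of a matrix

% Compare the analytic derivative with a central finite difference.
n = linspace(-4, 4, 9);
h = 1e-6;
numeric  = (logsigFn(n + h) - logsigFn(n - h)) / (2 * h);
analytic = dlogsigFn(n, logsigFn(n));
fprintf('max derivative error = %g\n', max(abs(numeric - analytic)));
```

Inside bpmxor, with these definitions placed near the top of the function, logsig(hp) would become logsigFn(hp) and sumsqr(err) would become sumsqrFn(err).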
