📄 probnn.m
字号:
function [winout,routput,poutput,winner,mispat] = probnn(train,ncls,clsize,clsinfo,pred,pinfo,smooth)
% PROBNN probabilistic neural network (PNN) classifier
% Author: Ron Shaffer
% Naval Research Laboratory
% shaffer@ccf.nrl.navy.mil
% Revisions: 1/22/96 Version 1.0 Original code
%            1/23/96 Version 1.1 Vectorized code 400% improvement in speed
%            1/24/96 Version 1.2 Improved output for mispat and fixed bug
%                    caused when the test set had a different number
%                    of patterns than the training set
%            4/5/96  Version 1.3 Modified user input so that the training and
%                    prediction set do not have to be input in order.
%                    Speed improvement of 25% over version 1.2
%
% [winout,routput,poutput,winner,mispat] = probnn(train,ncls,clsize,clsinfo,pred,pinfo,smooth)
%
% Inputs:
%   train:   training set patterns (number of training patterns x number of sensors)
%   ncls:    number of classes (number of outputs)
%   clsize:  number of training-set members in each class (ncls elements, row or column)
%   clsinfo: class label (integer 1..ncls) of each training pattern
%   pred:    test set patterns (number of test patterns x number of sensors)
%   pinfo:   class label (integer 1..ncls) of each test pattern
%   smooth:  smoothing factor (width of the Gaussian kernel)
%
% Outputs:
%   winout:  predicted class for each test set pattern (column vector)
%   routput: raw summation-layer outputs (number of test patterns x ncls)
%   poutput: class probabilities (number of test patterns x ncls); each row
%            sums to 1 (NaN if every kernel response of a row underflows to 0,
%            e.g. for a very small smooth)
%   winner:  number of correctly classified test patterns in each class (1 x ncls)
%   mispat:  one row per misclassified test pattern:
%            [test pattern index, true class, raw outputs (1 x ncls),
%             normalized test pattern (1 x number of sensors)]
%            An empty 0 x (ncls+2+number of sensors) matrix when nothing is
%            misclassified.
%
% Also prints the percentage correct for each class; a class that never
% occurs in pinfo prints NaN (0/0).
%
% Requires normr (normalizes each row to unit length).
%
% NOTE: Use this code at your own risk because the author assumes no liability!
%
% set constants
%
[npat_t,ndim] = size(train);
npat_p = size(pred,1);
clsinfo = clsinfo(:);              % accept row or column label vectors
pinfo = pinfo(:);
clsize = reshape(clsize,1,ncls);   % row vector so it divides every row of the outputs
smooth_sqr = smooth * smooth;
%
% normalize training and prediction patterns to unit length;
% the normalized training patterns are the hidden (pattern-layer) units
%
workt = normr(train);
workp = normr(pred);
%
% pattern layer: for unit-length rows x and w, x*w' - 1 = -||x - w||^2/2, so
% exp((x*w' - 1)/smooth^2) is the Gaussian kernel exp(-||x - w||^2/(2*smooth^2))
%
propmat = exp((workp * workt' - 1) / smooth_sqr);   % npat_p x npat_t
%
% summation layer: add up the kernel responses of the hidden units of each
% class.  member(j,k) is 1 when training pattern j belongs to class k.
%
member = zeros(npat_t,ncls);
member(sub2ind([npat_t ncls],(1:npat_t)',clsinfo)) = 1;
raw = propmat * member;                              % npat_p x ncls
%
% output layer: divide by the class sizes to get the mean kernel response of
% each class (corrects for unequal class sizes), then normalize each row so
% the class outputs become probabilities that sum to one
%
output = raw ./ repmat(clsize,npat_p,1);
output = output ./ repmat(sum(output,2),1,ncls);
%
% compute winner: the most probable class (first one on ties); reducing
% along dimension 2 keeps winout a column even when ncls == 1
%
[pmax,winout] = max(output,[],2);
%
% tally correct predictions per class and record misclassified patterns
%
ncorrect = zeros(1,ncls);
class_count = zeros(1,ncls);
mispat = zeros(0,ncls+2+ndim);     % always assigned, even with no misses
for i = 1:npat_p
    truecls = pinfo(i);
    class_count(truecls) = class_count(truecls) + 1;
    if winout(i) == truecls
        ncorrect(truecls) = ncorrect(truecls) + 1;
    else
        mispat(end+1,:) = [i, truecls, raw(i,1:ncls), workp(i,1:ndim)];
    end
end
winner = ncorrect;
routput = raw;
poutput = output;
%
% report the percentage correct for each output class
%
pct = 100 * winner ./ class_count;
for i = 1:ncls
    fprintf('Class %d %8.4f correct \n',i,pct(i));
end
%
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -