📄 competitive_learning.m
字号:
function [patterns, targets, label, W] = Competitive_learning(train_patterns, train_targets, params, plot_on)
% Perform preprocessing using a competitive learning network
%
% Every pattern is augmented with a constant 1 and scaled to unit length.
% N weight vectors are initialised from randomly chosen patterns; for each
% presented pattern the weight vector with the largest inner product (the
% winner) is moved towards the pattern and re-normalised.
%
% Inputs:
%   train_patterns - Train patterns, one pattern per column (c x r)
%   train_targets  - Train targets, one column per pattern
%   params         - [Number of partitions, learning rate]
%   plot_on        - Plot while performing processing? (optional, default 0)
%
% Outputs
%   patterns - New patterns (c x N): mean of the normalised patterns assigned
%              to each weight vector (NaN column if none were assigned)
%   targets  - New targets (1 x N): most frequent train target among the
%              patterns assigned to each weight vector (the first unique
%              target value when no pattern was assigned)
%   label    - The labels given for each of the original patterns (index of
%              the winning weight vector)
%   W        - Weights matrix ((c+1) x N), one unit-length column per partition

if (nargin < 4),
    plot_on = 0;
end

max_iter = 1000;    %Maximum number of passes over the data
decay    = 0.99;    %Learning-rate decay applied after every pass
tol      = 1e-4;    %Stop when the mean weight change per pattern is below this

[c, r]   = size(train_patterns);
[N, eta] = process_params(params);

if (N > r),
    error('Competitive_learning:tooManyPartitions', ...
          'Number of partitions (%d) exceeds the number of patterns (%d).', N, r);
end

%Preprocessing:
% x_i <- {x_i, 1}
x = [train_patterns ; ones(1,r)];
%x_i <- x_i./||x_i||
x = x ./ (ones(c+1,1) * sqrt(sum(x.^2, 1)));

%Initialize the W's from N distinct, randomly chosen patterns
perm = randperm(r);
W    = x(:,perm(1:N));

converged = false;
for iter = 1:max_iter,
    %Randomly order the patterns
    order  = randperm(r);
    change = 0;

    for k = 1:r,
        xk = x(:,order(k));

        %Pick exactly one winner; max returns the first index on ties, so
        %W(:,j) is always a single column
        [best, j] = max(W'*xk);

        old_W = W(:,j);

        %W_j <- W_j + eta*x
        W(:,j) = old_W + eta*xk;

        %W_j <- W_j/||W_j||
        W(:,j) = W(:,j) / sqrt(sum(W(:,j).^2));

        change = change + sum(abs(W(:,j) - old_W));

        if (plot_on > 0),
            %Assign each of the patterns to a center
            [m, cur_label] = max(W'*x, [], 1);
            centers = zeros(c,N);
            for n = 1:N,
                in = find(cur_label == n);
                if ~isempty(in)
                    %Average along the pattern dimension so that a single
                    %member still yields a c x 1 column
                    centers(:,n) = mean(x(1:c,in), 2);
                else
                    centers(:,n) = nan;
                end
            end

            %Plot the centers during the process
            plot_process(centers, plot_on)
        end
    end

    eta = eta * decay;

    if (change/r < tol),
        converged = true;
        break
    end
end

if converged,
    disp(['Finished after ' num2str(iter) ' iterations.'])
else
    disp(['Maximum iteration (' num2str(max_iter) ') reached']);
end

%Assign each of the patterns to a center. dist holds inner products of
%unit vectors, so the best match is the largest value. The explicit
%dimension keeps label 1 x r even when N == 1.
dist = W'*x;
[m, label] = max(dist, [], 1);

patterns = zeros(c,N);
for n = 1:N,
    in = find(label == n);
    if ~isempty(in)
        patterns(:,n) = mean(x(1:c,in), 2);
    else
        patterns(:,n) = nan;
    end
end

%Label the points: majority vote of the train targets in each partition
targets = zeros(1,N);
Uc      = unique(train_targets);
for n = 1:N,
    members = train_targets(:,label == n);
    counts  = zeros(1,length(Uc));
    for u = 1:length(Uc),
        counts(u) = sum(members(:) == Uc(u));
    end
    [m, max_l] = max(counts);
    targets(n) = Uc(max_l);
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -