⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 em.m

📁 最新的模式识别分类工具箱，希望对朋友们有用！
💻 M
字号:
function D = EM(train_features, train_targets, Ngaussians, region)
% EM  Classify using the expectation-maximization algorithm.
%
% Each class is modelled as a mixture of Gaussians. The mixture parameters
% are initialised with k-means and then refined with EM. The fitted mixtures
% of the first two classes are passed to decision_region.
%
% Inputs:
%   train_features - Train features, one column per sample (Dim x N)
%   train_targets  - Train targets, one label per column of train_features
%   Ngaussians     - Number of Gaussians for each class (vector, one entry
%                    per class, in the order returned by find_classes)
%   region         - Decision region vector: [-x x -y y number_of_points]
%
% Outputs:
%   D              - Decision surface
%
% Notes:
%   * thetas stores a square-root-of-covariance ("standard deviation") matrix
%     per Gaussian. The k-means init uses diag(std), the single-Gaussian path
%     uses the matrix square root sqrtm(cov), and the EM path uses an
%     element-wise sqrt(abs(cov)). This mix is kept from the original.
%     NOTE(review): confirm this is what p_single/decision_region expect.
%   * The final decision-region step only uses the first two classes.

[Nclasses, classes] = find_classes(train_targets);  % Number of classes in targets
Nalpha              = Ngaussians;                    % Number of Gaussians in each class
Dim                 = size(train_features, 1);       % Feature dimension

% The initial guess is based on k-means preprocessing. If EM does not converge
% after max_iter iterations, a random guess is used.
max_iter = 100;
weights  = zeros(Nclasses, max(Ngaussians));
thetas   = zeros(Nclasses, max(Ngaussians), Dim, Dim);
means    = zeros(Nclasses, max(Ngaussians), Dim);

disp('Using k-means for initial guess')
for c = 1:Nclasses
    in_c = find(train_targets == classes(c));
    % The second output of k_means (targets) is not needed here.
    % NOTE(review): assumes initial_m is Dim x Ngaussians(c), as the transpose
    % below implies.
    [initial_m, unused_targets, labels] = k_means(train_features(:,in_c), ...
        train_targets(:,in_c), Ngaussians(c), region);
    for i = 1:Ngaussians(c)
        gauss_labels = find(labels == i);
        weights(c,i) = length(gauss_labels) / length(labels);
        % std along dim 2 gives a per-feature std even for a one-point
        % cluster. std(X') would collapse to a scalar in that case.
        thetas(c,i,:,:) = diag(std(train_features(:,in_c(gauss_labels)), 0, 2));
    end
    means(c,1:Ngaussians(c),:) = initial_m';
end

% Estimate mean and covariance for each class
for c = 1:Nclasses
    train_idx = find(train_targets == classes(c));
    data_c    = train_features(:,train_idx);

    if (Ngaussians(c) == 1)
        % A single Gaussian needs no EM: use the sample statistics directly.
        thetas(c,1,:,:) = sqrtm(cov(data_c'));
        means(c,1,:)    = mean(data_c, 2);
    else
        theta     = squeeze(thetas(c,:,:,:));
        old_theta = zeros(size(theta));  % Used for the stopping criterion
        iter      = 0;                   % Iteration counter
        n         = length(train_idx);   % Number of training points
        qi        = zeros(Nalpha(c), n); % qi(i,t): responsibility of Gaussian i for point t
        P         = zeros(1, Nalpha(c));

        while sum(abs(theta(:) - old_theta(:))) > 1e-10
            old_theta = theta;

            % E-step: calculating qi's
            for t = 1:n
                data = data_c(:,t);
                for k = 1:Nalpha(c)
                    P(k) = weights(c,k) * p_single(data, squeeze(means(c,k,:)), ...
                                                   squeeze(theta(k,:,:)));
                end
                % NOTE(review): if every P(k) is 0 this yields NaN.
                qi(:,t) = P' / sum(P);
            end

            % M-step: calculating mu's (responsibility-weighted means)
            for i = 1:Nalpha(c)
                means(c,i,:) = sum(data_c .* (ones(Dim,1) * qi(i,:)), 2) / sum(qi(i,:));
            end

            % Calculating sigma's.
            % A bit different from the handouts, but much more efficient:
            % scale the centred points by sqrt(qi) and reuse cov.
            for i = 1:Nalpha(c)
                data_vec     = data_c - squeeze(means(c,i,:)) * ones(1,n);
                data_vec     = data_vec .* (ones(Dim,1) * sqrt(qi(i,:)));
                theta(i,:,:) = sqrt(abs(cov(data_vec') * n / sum(qi(i,:))));
            end

            % Calculating alpha's (mixture weights). Summing along dim 2 keeps
            % one weight per Gaussian even when n == 1.
            weights(c,1:Ngaussians(c)) = 1/n * sum(qi, 2)';

            iter = iter + 1;
            disp(['Iteration: ' num2str(iter)])

            if (iter > max_iter)
                % No convergence: restart from a random theta.
                theta = randn(size(theta));
                iter  = 0;
                disp('Redrawing weights.')
            end
        end

        thetas(c,:,:,:) = theta;
    end
end

% Find decision region (********** Made for only 2 classes ************)
% Prior of the first class. means(1,...) and thetas(1,...) were fitted for
% classes(1), so the prior must use the same label.
p0 = length(find(train_targets == classes(1))) / length(train_targets);

% reshape always yields Ngaussians x Dim, even for a single Gaussian.
% squeeze would return a column in that case.
m0 = reshape(means(1,1:Ngaussians(1),:), Ngaussians(1), Dim);
m1 = reshape(means(2,1:Ngaussians(2),:), Ngaussians(2), Dim);

D = decision_region(m0, squeeze(thetas(1,1:Ngaussians(1),:,:)), weights(1,1:Ngaussians(1)), ...
                    m1, squeeze(thetas(2,1:Ngaussians(2),:,:)), weights(2,1:Ngaussians(2)), p0, region);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -