📄 gmmtrain.html
字号:
0053 <span class="keyword">end</span>0054 0055 <span class="comment">% ====== Start EM iteration</span>0056 logProb = zeros(gmmTrainParam.maxIteration, 1); <span class="comment">% Array for objective function</span>0057 <span class="keyword">if</span> gmmTrainParam.dispOpt & dim==2, <a href="#_sub1" class="code" title="subfunction displayGmm(data, gmmParam)">displayGmm</a>(data, gmmParam); <span class="keyword">end</span>0058 0059 <span class="keyword">for</span> i = 1:gmmTrainParam.maxIteration0060 <span class="comment">% ====== Expectation step:</span>0061 <span class="comment">% P(i,j) is the probability of data(:,j) to the i-th Gaussian</span>0062 [logProbs, P]=<a href="gmmEval.html" class="code" title="function [logProb, gaussianProb] = gmmEval(data, gmmParam);">gmmEval</a>(data, gmmParam);0063 <span class="comment">% [logProbs, P]=gmmEvalMex(data, M, V, W);</span>0064 logProb(i)=sum(logProbs);0065 <span class="keyword">if</span> gmmTrainParam.dispOpt, fprintf(<span class="string">'\tGMM iteration: %d/%d, log prob. = %f\n'</span>, i-1, gmmTrainParam.maxIteration, logProb(i)); <span class="keyword">end</span>0066 W = [gmmParam.w]';0067 PW = repmat(W, 1, dataNum).*P;0068 sumPW = sum(PW, 1);0069 <span class="keyword">if</span> any(sumPW==0), keyboard; error(<span class="string">'some entries of sumPW==0! 
You need to make sure each point is at least covered by a Gaussian!'</span>); <span class="keyword">end</span>0070 BETA=PW./repmat(sumPW, gaussianNum, 1); <span class="comment">% BETA(i,j) is beta_i(x_j)</span>0071 sumBETA=sum(BETA,2);0072 0073 <span class="comment">% ====== Maximization step: eqns (2.96) to (2.98) from Bishop p.67:</span>0074 <span class="comment">% === Compute mu</span>0075 M = (data*BETA')./repmat(sumBETA', dim, 1);0076 <span class="comment">% Distribute the parameters to gmmParam</span>0077 meanCell=mat2cell(M, dim, ones(1, gaussianNum));0078 [gmmParam.mu]=deal(meanCell{:});0079 <span class="comment">% === Compute w</span>0080 W = (1/dataNum)*sumBETA; <span class="comment">% (2.98)</span>0081 wCell=mat2cell(W', 1, ones(1, gaussianNum));0082 [gmmParam.w]=deal(wCell{:});0083 <span class="comment">% === Compute sigma</span>0084 DISTSQ = pairwiseSqrDistance(M, data); <span class="comment">% Distance of M to data</span>0085 <span class="keyword">if</span> prod(size(gmmParam(1).sigma))==1 <span class="comment">% identity covariance matrix times a constant for each Gaussian</span>0086 V = max((sum(BETA.*DISTSQ, 2)./sumBETA)/dim, gmmTrainParam.minVariance); <span class="comment">% (2.97)</span>0087 <span class="keyword">for</span> j=1:gaussianNum0088 gmmParam(j).sigma=V(j);0089 <span class="keyword">end</span>0090 <span class="keyword">elseif</span> prod(size(gmmParam(1).sigma))==dim <span class="comment">% diagonal covariance matrix for each Gaussian</span>0091 <span class="comment">% This segment remains to be double checked</span>0092 <span class="keyword">for</span> j=1:gaussianNum0093 dataMinusMu = data-repmat(gmmParam(j).mu, 1, dataNum);0094 weight = repmat(BETA(j,:), dim, 1);0095 gmmParam(j).sigma=sum((weight.*dataMinusMu).*dataMinusMu, 2)/sum(BETA(j,:));0096 <span class="keyword">end</span>0097 <span class="keyword">else</span> <span class="comment">% full covariance matrix for each Gaussian</span>0098 <span class="keyword">for</span> 
j=1:gaussianNum0099 dataMinusMu = data-repmat(gmmParam(j).mu, 1, dataNum);0100 weight = repmat(BETA(j,:), dim, 1);0101 gmmParam(j).sigma=(weight.*dataMinusMu)*dataMinusMu'/sum(BETA(j,:));0102 <span class="keyword">end</span>0103 <span class="keyword">end</span>0104 <span class="comment">% === Animation</span>0105 <span class="keyword">if</span> gmmTrainParam.dispOpt & dim==2, <a href="#_sub1" class="code" title="subfunction displayGmm(data, gmmParam)">displayGmm</a>(data, gmmParam); <span class="keyword">end</span>0106 0107 <span class="comment">% ====== Check stopping criterion</span>0108 <span class="keyword">if</span> i>1, <span class="keyword">if</span> logProb(i)-logProb(i-1)<gmmTrainParam.minImprove, <span class="keyword">break</span>; <span class="keyword">end</span>; <span class="keyword">end</span>0109 <span class="keyword">end</span>0110 [logProbs, P]=<a href="gmmEval.html" class="code" title="function [logProb, gaussianProb] = gmmEval(data, gmmParam);">gmmEval</a>(data, gmmParam);0111 logProb(i)=sum(logProbs);0112 0113 <span class="keyword">if</span> gmmTrainParam.dispOpt, fprintf(<span class="string">'\tGMM total iteration count = %d, log prob. 
= %f\n'</span>,i, logProb(i)); <span class="keyword">end</span>0114 logProb(i+1:gmmTrainParam.maxIteration) = [];0115 0116 <span class="comment">% ====== Subfunctions ======</span>0117 <a name="_sub1" href="#_subfunctions" class="code">function displayGmm(data, gmmParam)</a>0118 <span class="comment">% Display function for EM algorithm</span>0119 figureH=findobj(0, <span class="string">'tag'</span>, mfilename);0120 <span class="keyword">if</span> isempty(figureH)0121 figureH=figure;0122 set(figureH, <span class="string">'tag'</span>, mfilename);0123 plot(data(1,:), data(2,:),<span class="string">'.r'</span>); axis image0124 <span class="keyword">for</span> i=1:length(gmmParam)0125 [xData, yData]=<a href="#_sub2" class="code" title="subfunction [xData, yData]=halfHeightContour(gaussianParam)">halfHeightContour</a>(gmmParam(i));0126 circleH(i)=line(xData, yData, <span class="string">'color'</span>, <span class="string">'k'</span>, <span class="string">'linewidth'</span>, 3);0127 <span class="keyword">end</span>0128 set(circleH, <span class="string">'tag'</span>, <span class="string">'circleH'</span>, <span class="string">'erasemode'</span>, <span class="string">'xor'</span>);0129 <span class="keyword">else</span>0130 circleH=findobj(figureH, <span class="string">'tag'</span>, <span class="string">'circleH'</span>);0131 <span class="keyword">for</span> i=1:length(gmmParam)0132 [xData, yData]=<a href="#_sub2" class="code" title="subfunction [xData, yData]=halfHeightContour(gaussianParam)">halfHeightContour</a>(gmmParam(i));0133 set(circleH(i), <span class="string">'xdata'</span>, xData, <span class="string">'ydata'</span>, yData);0134 <span class="keyword">end</span>0135 drawnow0136 <span class="keyword">end</span>0137 0138 <span class="comment">% ====== Obtain the contour data at half height of a Gaussian</span>0139 <a name="_sub2" href="#_subfunctions" class="code">function [xData, yData]=halfHeightContour(gaussianParam)</a>0140 dim=length(gaussianParam.mu);0141 
theta=linspace(-pi, pi, 21);0142 <span class="keyword">if</span> prod(size(gaussianParam.sigma))==1 <span class="comment">% identity covariance matrix times a constant for each Gaussian</span>0143 r=sqrt(2*log(2)*gaussianParam.sigma); <span class="comment">% Gaussian reaches its 50% height at this distance from the mean</span>0144 xData=r*cos(theta)+gaussianParam.mu(1);0145 yData=r*sin(theta)+gaussianParam.mu(2);0146 <span class="keyword">elseif</span> prod(size(gaussianParam.sigma))==dim <span class="comment">% diagonal covariance matrix for each Gaussian</span>0147 r1=sqrt(2*log(2)*gaussianParam.sigma(1));0148 r2=sqrt(2*log(2)*gaussianParam.sigma(2));0149 xData=r1*cos(theta)+gaussianParam.mu(1);0150 yData=r2*sin(theta)+gaussianParam.mu(2);0151 <span class="keyword">else</span> <span class="comment">% full covariance matrix for each Gaussian</span>0152 [V, D]=eig(gaussianParam.sigma);0153 r1=sqrt(2*log(2)*D(1,1));0154 r2=sqrt(2*log(2)*D(2,2));0155 rotatedData=[r1*cos(theta); r2*sin(theta)];0156 origData=V*rotatedData;0157 xData=origData(1,:)+gaussianParam.mu(1);0158 yData=origData(2,:)+gaussianParam.mu(2);0159 <span class="keyword">end</span>0160 0161 <span class="comment">% ====== Self Demo ======</span>0162 <a name="_sub3" href="#_subfunctions" class="code">function selfdemo</a>0163 <a href="gmmTrainDemo2dCovType01.html" class="code" title="">gmmTrainDemo2dCovType01</a>;</pre></div><hr><address>Generated on Thu 30-Oct-2008 12:53:56 by <strong><a href="http://www.artefact.tk/software/matlab/m2html/">m2html</a></strong> © 2003</address></body></html>
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -