📄 em.java
字号:
// EM.java// This library is free software; you can redistribute it and/or// modify it under the terms of the GNU Library General Public// License as published by the Free Software Foundation; either// version 2 of the License, or (at your option) any later version.// // This library is distributed in the hope that it will be useful,// but WITHOUT ANY WARRANTY; without even the implied warranty of// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU// Library General Public License for more details.// // You should have received a copy of the GNU Library General Public// License along with this library; if not, write to the// Free Software Foundation, Inc., 59 Temple Place - Suite 330,// Boston, MA 02111-1307, USA.//// Primary author contact info: www.idiom.com/~zilla zilla@computer.orgpackage ZS;import java.awt.*; // for algorithm debug onlyimport VisualNumerics.math.*; // for testing inverse import zlib.*;/** * P(m|d) = P(d|m) P(m) / P(d) * * P(d|m) is the likelihood, it is the probability of the data assigned * by this model. * P(m|d) is the posterior, it is the probability of * this model given that the data has been observed. * * ndata - number of data items, each of nd dimensions * nmix - number of gaussians in the mixture model * @param data[ndata][nd] * @param covariances[nmix][nd][nd] * @param priors[nmix] * @param posteriors[ndata][nmix] */final public class EM{ final static int verbose = 1; public static void em(double[][] data, int nsteps, double[][] means, double[][][] covariances, double[] priors, double[][] posteriors) throws Exception { int ndata = data.length; int nmix = means.length; zliberror.assert(covariances.length == nmix, "em1"); zliberror.assert(priors.length == nmix, "em2"); zliberror.assert(posteriors.length == ndata, "em3"); zliberror.assert(posteriors[0].length == nmix, "em4"); init(data, means, covariances, priors); for( int step=0; step < nsteps; step++ ) { estep(data, means, covariances, priors, posteriors); mstep(data, means, covariances, priors, posteriors); } } //em /** * initialize means, covariances, priors to reasonable starting values */ public static void init(double[][] data, double[][] means, double[][][] covariances, double[] priors) { int ndata = data.length; int nmix = means.length; int nd = means[0].length; zliberror.assert(covariances[0].length == nd); zliberror.assert(covariances[0][0].length == nd); double[][] bounds = new double[nd][2]; matrix.getNDBounds(data, bounds); double[] del = new double[nd]; for( int id=0; id < nd; id++ ) { if (verbose > 0) System.out.println("data bounds: " + bounds[id][0]+".."+bounds[id][1]); del[id] = bounds[id][1] - bounds[id][0]; } for( int im=0; im < nmix; im++ ) { for( int id=0; id < nd; id++ ) { means[im][id] = bounds[id][0] + rnd.rndf() * del[id]; } matrix.setIdentity(covariances[im]); priors[im] = 1. / nd; } } //init public static void estep(double[][] data, double[][] means, double[][][] covariances, double[] priors, double[][] posteriors) throws Exception { int ndata = data.length; int nmix = means.length; zliberror.assert(posteriors.length == ndata, "estep0"); zliberror.assert(posteriors[0].length == nmix, "estep1"); for( int ip=0; ip < ndata; ip++ ) { double norm = 0.f; for( int im=0; im < nmix; im++ ) { // TODO: pull the im loop outside ip, then can evaluate gscale // once per model // double gscale = gaussEvalScale(means[im], covariances[im]); double g = zmath.gaussEval(data[ip], means[im], covariances[im]); norm += priors[im] * g; } for( int im=0; im < nmix; im++ ) { double g = zmath.gaussEval(data[ip], means[im], covariances[im]); posteriors[ip][im] = (priors[im] * g) / norm; } } } //estep public static void mstep(double[][] data, double[][] means, double[][][] covariances, double[] priors, double[][] posteriors) throws Exception { int ndata = data.length; int nmix = means.length; int nd = means[0].length; // prior probability for a particular gaussian is its // posterior probability summed over all data points. for( int im=0; im < nmix; im++ ) { double sum = 0.; for( int ip=0; ip < ndata; ip++ ) { sum += posteriors[ip][im]; } priors[im] = sum / ndata; } // each mean is a weighted sum of the data points, // weighted by the convex normalized posterior probability for that G for( int im=0; im < nmix; im++ ) { // zero this mean for( int id=0; id < nd; id++ ) { means[im][id] = 0.; } // accumulate double sumposterior = 0.; for( int ip=0; ip < ndata; ip++ ) { double pw = posteriors[ip][im]; sumposterior += pw; for( int id=0; id < nd; id++ ) { means[im][id] += (pw * data[ip][id]); } } //ip double oosumposterior = 1. / sumposterior; for( int id=0; id < nd; id++ ) { means[im][id] *= oosumposterior; } //id } //im // lastly, covariance is weighted sum of outer products // of the de-meaned data, as weighted by the normalized posteriors for( int im=0; im < nmix; im++ ) { double[][] cov = covariances[im]; double[] mean = means[im]; // zero this covariance for( int id=0; id < nd; id++ ) { for( int jd=0; jd < nd; jd++ ) { cov[id][jd] = 0.; } } // accumulated weighted sum of outer product of de-meaned data double sumposterior = 0.; for( int ip=0; ip < ndata; ip++ ) { double[] d = data[ip]; double pw = posteriors[ip][im]; sumposterior += pw; for( int id=0; id < nd; id++ ) { for( int jd=0; jd < nd; jd++ ) { cov[id][jd] += (pw * (d[id] - mean[id]) * (d[jd] - mean[jd])); } } } //ip // normalize the posterior weighting double oosumposterior = 1. / sumposterior; for( int id=0; id < nd; id++ ) { for( int jd=0; jd < nd; jd++ ) { cov[id][jd] *= oosumposterior; } } // and invert, because in e step it is ony used in the gaussEval // TODO: yuck garbage covariances[im] = DoubleMatrix.inverse(cov); } //im } //mstep //---------------------------------------------------------------- // plotting code below. //---------------------------------------------------------------- static final int RES = 400; static final int RESo2 = RES/2; static Frame _f; static Graphics _g; static void plot(double[][] data, double[][] means, double[][][] covariances) { int ndata = data.length; int nmix = means.length; int nd = means[0].length; _g.clearRect(0,0,RES,RES); _g.setColor(Color.blue); for( int ip=0; ip < ndata; ip++ ) { drawpt(data[ip][0],data[ip][1]); } _g.setColor(Color.red); for( int im=0; im < nmix; im++ ) { drawpt(means[im][0],means[im][1]); // vis of covariance: // cov is pos def, find set of x s.t. x C x = k int npts = 21; double k = 1.; for( int ip=0; ip < npts; ip++ ) { double angle = 2. * Math.PI * (ip / (double)(npts-1)); double x = Math.cos(angle); double y = Math.sin(angle); double r = findroot(covariances[im], x, y, k); drawto(means[im][0] + r*x, means[im][1] + r*y, (ip==0)); } } //nmix } //plot // x ( C11 x + C12 y ) + y ( C21 x + C22 y ) = k // y/x = s // // y = x*s // x ( C11 x + C12 x*s ) + x*s ( C21 x + C22 x*s ) = k // C11 x^2 + C12 s x^2 + C21 x^2 s + C22 x^2 s^2 = k // x^2 ( C11 + C12 s + C21 s + C22 s^2 ) = k // x = sqrt( k / ( C11 + C12 s + C21 s + C22 s^2 ) ) static double findroot(double[][] cov, double x, double y, double k) { double s = y / x; // beware x 0 double den = cov[0][0] + cov[0][1]*s + cov[1][0]*s + cov[1][1]*s*s; if (den < 0.) den = - den; double x1 = Math.sqrt(k / den); return x1 / Math.abs(x); // ratio } //findroot static final int toscreen(double d) { return RESo2 + (int)(RESo2 * d); } static void drawpt(double dx, double dy) { int ix = toscreen(dx); int iy = toscreen(dy); _g.drawLine(ix-5, iy-5, ix+5, iy+5); _g.drawLine(ix+5, iy-5, ix-5, iy+5); } static int _lx = 0; static int _ly = 0; static void drawto(double dx, double dy, boolean moveto) { int ix = toscreen(dx); int iy = toscreen(dy); if (!moveto) { _g.drawLine(_lx, _ly, ix, iy); } _lx = ix; _ly = iy; } static void drawpt(int ix, int iy) { _g.drawLine(ix-5, iy-5, ix+5, iy+5); _g.drawLine(ix+5, iy-5, ix-5, iy+5); } static void simpleTest() throws Exception { int ndata = 100; int nd = 2; int nmix = 2; double[][] data = new double[ndata][2]; for( int ip=0; ip < ndata; ip++ ) { double angle = 2. * Math.PI * (ip / (double)ndata); data[ip][0] = 0.5 * Math.cos(angle) + 0.05 * rnd.rndf11(); data[ip][1] = 0.5 * Math.sin(angle) + 0.05 * rnd.rndf11(); } double[][] means = new double[nmix][nd]; double[][][] covariances = new double[nmix][nd][nd]; double[] priors = new double[nmix]; double[][] posteriors = new double[ndata][nmix]; int nsteps = 100; init(data, means, covariances, priors); for( int istep=0; istep < nsteps; istep++ ) { estep(data, means, covariances, priors, posteriors); mstep(data, means, covariances, priors, posteriors); if (istep%3 == 0) { plot(data, means, covariances); zlib.more(); } } } //simpleTest public static void main(String[] cmdline) { try { _f = new Frame(); _f.setSize(RES,RES); _f.setVisible(true); _g = _f.getGraphics(); simpleTest(); } catch(Exception x) { zliberror.die(x); } } //main} //EM
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -