⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 em.java

📁 Java实现的各种数学算法
💻 JAVA
字号:
// EM.java// This library is free software; you can redistribute it and/or// modify it under the terms of the GNU Library General Public// License as published by the Free Software Foundation; either// version 2 of the License, or (at your option) any later version.// // This library is distributed in the hope that it will be useful,// but WITHOUT ANY WARRANTY; without even the implied warranty of// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU// Library General Public License for more details.// // You should have received a copy of the GNU Library General Public// License along with this library; if not, write to the// Free Software Foundation, Inc., 59 Temple Place - Suite 330,// Boston, MA  02111-1307, USA.//// Primary author contact info:  www.idiom.com/~zilla  zilla@computer.orgpackage ZS;import java.awt.*;		// for algorithm debug onlyimport VisualNumerics.math.*;	// for testing inverse import zlib.*;/** * P(m|d) = P(d|m) P(m) / P(d) * * P(d|m) is the likelihood, it is the probability of the data assigned *   by this model. * P(m|d) is the posterior, it is the probability of *   this model given that the data has been observed. * * ndata - number of data items, each of nd dimensions * nmix - number of gaussians in the mixture model * @param data[ndata][nd] * @param covariances[nmix][nd][nd] * @param priors[nmix] * @param posteriors[ndata][nmix] */final public class EM{  final static int	verbose = 1;  public static void em(double[][] data,			int nsteps,			double[][] means,			double[][][] covariances,			double[] priors,			double[][] posteriors)    throws Exception  {    int ndata = data.length;    int nmix = means.length;    zliberror.assert(covariances.length == nmix, "em1");    zliberror.assert(priors.length == nmix, "em2");    zliberror.assert(posteriors.length == ndata, "em3");    zliberror.assert(posteriors[0].length == nmix, "em4");    init(data, means, covariances, priors);    for( int step=0; step < nsteps; step++ ) {      estep(data, means, covariances, priors, posteriors);      mstep(data, means, covariances, priors, posteriors);    }  } //em  /**   * initialize means, covariances, priors to reasonable starting values   */  public static void init(double[][] data,			  double[][] means,			  double[][][] covariances,			  double[] priors)  {    int ndata = data.length;    int nmix = means.length;    int nd = means[0].length;    zliberror.assert(covariances[0].length == nd);    zliberror.assert(covariances[0][0].length == nd);    double[][] bounds = new double[nd][2];    matrix.getNDBounds(data, bounds);    double[] del = new double[nd];    for( int id=0; id < nd; id++ ) {      if (verbose > 0)	System.out.println("data bounds: " + bounds[id][0]+".."+bounds[id][1]);      del[id] = bounds[id][1] - bounds[id][0];    }    for( int im=0; im < nmix; im++ ) {      for( int id=0; id < nd; id++ )  {	means[im][id] = bounds[id][0] + rnd.rndf() * del[id];      }      matrix.setIdentity(covariances[im]);      priors[im] = 1. / nd;    }  } //init  public static void estep(double[][] data,			   double[][] means,			   double[][][] covariances,			   double[] priors,			   double[][] posteriors)    throws Exception  {    int ndata = data.length;    int nmix = means.length;    zliberror.assert(posteriors.length == ndata, "estep0");    zliberror.assert(posteriors[0].length == nmix, "estep1");    for( int ip=0; ip < ndata; ip++ ) {      double norm = 0.f;      for( int im=0; im < nmix; im++ ) {	// TODO: pull the im loop outside ip, then can evaluate gscale	// once per model	// double gscale = gaussEvalScale(means[im], covariances[im]);	double g = zmath.gaussEval(data[ip], means[im], covariances[im]);	norm += priors[im] * g;      }      for( int im=0; im < nmix; im++ ) {	double g = zmath.gaussEval(data[ip], means[im], covariances[im]);	posteriors[ip][im] = (priors[im] * g) / norm;      }    }  } //estep  public static void mstep(double[][] data,			   double[][] means,			   double[][][] covariances,			   double[] priors,			   double[][] posteriors)    throws Exception  {    int ndata = data.length;    int nmix = means.length;    int nd = means[0].length;    // prior probability for a particular gaussian is its    // posterior probability summed over all data points.    for( int im=0; im < nmix; im++ ) {      double sum = 0.;      for( int ip=0; ip < ndata; ip++ ) {	sum += posteriors[ip][im];      }      priors[im] = sum / ndata;    }    // each mean is a weighted sum of the data points,    // weighted by the convex normalized posterior probability for that G    for( int im=0; im < nmix; im++ ) {      // zero this mean      for( int id=0; id < nd; id++ ) {	means[im][id] = 0.;      }      // accumulate      double sumposterior = 0.;      for( int ip=0; ip < ndata; ip++ ) {	double pw = posteriors[ip][im];	sumposterior += pw;	for( int id=0; id < nd; id++ ) {	  means[im][id] += (pw * data[ip][id]);	}      } //ip      double oosumposterior = 1. / sumposterior;      for( int id=0; id < nd; id++ ) {	means[im][id] *= oosumposterior;      } //id    } //im    // lastly, covariance is weighted sum of outer products    // of the de-meaned data, as weighted by the normalized  posteriors    for( int im=0; im < nmix; im++ ) {      double[][] cov = covariances[im];      double[] mean = means[im];      // zero this covariance      for( int id=0; id < nd; id++ ) {	for( int jd=0; jd < nd; jd++ ) {	  cov[id][jd] = 0.;	}      }      // accumulated weighted sum of outer product of de-meaned data      double sumposterior = 0.;      for( int ip=0; ip < ndata; ip++ ) {	double[] d = data[ip];	double pw = posteriors[ip][im];	sumposterior += pw;	for( int id=0; id < nd; id++ ) {	  for( int jd=0; jd < nd; jd++ ) {	    cov[id][jd] += (pw * (d[id] - mean[id]) * (d[jd] - mean[jd]));	  }	}      } //ip      // normalize the posterior weighting      double oosumposterior = 1. / sumposterior;      for( int id=0; id < nd; id++ ) {	for( int jd=0; jd < nd; jd++ ) {	  cov[id][jd] *= oosumposterior;	}      }      // and invert, because in e step it is ony used in the gaussEval      // TODO: yuck garbage      covariances[im] = DoubleMatrix.inverse(cov);    } //im  } //mstep  //----------------------------------------------------------------  // plotting code below.  //----------------------------------------------------------------  static final int 	RES = 400;  static final int 	RESo2 = RES/2;  static Frame		_f;  static Graphics	_g;  static void plot(double[][] data, double[][] means, double[][][] covariances)  {    int ndata = data.length;    int nmix = means.length;    int nd = means[0].length;    _g.clearRect(0,0,RES,RES);    _g.setColor(Color.blue);    for( int ip=0; ip < ndata; ip++ ) {      drawpt(data[ip][0],data[ip][1]);    }    _g.setColor(Color.red);    for( int im=0; im < nmix; im++ ) {      drawpt(means[im][0],means[im][1]);      // vis of covariance:      // cov is pos def, find set of x s.t. x C x = k      int npts = 21;      double k = 1.;      for( int ip=0; ip < npts; ip++ ) {	double angle = 2. * Math.PI * (ip / (double)(npts-1));	double x = Math.cos(angle);	double y = Math.sin(angle);	double r = findroot(covariances[im], x, y, k);	drawto(means[im][0] + r*x, means[im][1] + r*y, (ip==0));      }          } //nmix  } //plot  // x ( C11 x + C12 y ) + y ( C21 x + C22 y ) = k  // y/x = s  //  // y = x*s  // x ( C11 x + C12 x*s ) + x*s ( C21 x + C22 x*s ) = k  // C11 x^2 + C12 s x^2  +  C21 x^2 s + C22 x^2 s^2 = k  // x^2 ( C11 + C12 s + C21 s + C22 s^2 ) = k  // x = sqrt( k / ( C11 + C12 s + C21 s + C22 s^2 ) )  static double findroot(double[][] cov, double x, double y, double k)  {    double s = y / x;	// beware x 0    double den = cov[0][0] + cov[0][1]*s + cov[1][0]*s + cov[1][1]*s*s;    if (den < 0.) den = - den;    double x1 = Math.sqrt(k / den);    return x1 / Math.abs(x);	// ratio  } //findroot  static final int toscreen(double d)  {    return RESo2 + (int)(RESo2 * d);  }  static void drawpt(double dx, double dy)  {    int ix = toscreen(dx);    int iy = toscreen(dy);    _g.drawLine(ix-5, iy-5, ix+5, iy+5);    _g.drawLine(ix+5, iy-5, ix-5, iy+5);  }  static int _lx = 0;  static int _ly = 0;  static void drawto(double dx, double dy, boolean moveto)  {    int ix = toscreen(dx);    int iy = toscreen(dy);    if (!moveto) {      _g.drawLine(_lx, _ly, ix, iy);    }    _lx = ix;    _ly = iy;  }  static void drawpt(int ix, int iy)  {    _g.drawLine(ix-5, iy-5, ix+5, iy+5);    _g.drawLine(ix+5, iy-5, ix-5, iy+5);  }		     static void simpleTest()    throws Exception  {    int ndata = 100;    int nd = 2;    int nmix = 2;    double[][] data = new double[ndata][2];    for( int ip=0; ip < ndata; ip++ ) {      double angle = 2. * Math.PI * (ip / (double)ndata);      data[ip][0] = 0.5 * Math.cos(angle) + 0.05 * rnd.rndf11();      data[ip][1] = 0.5 * Math.sin(angle) + 0.05 * rnd.rndf11();    }    double[][] means = new double[nmix][nd];    double[][][] covariances = new double[nmix][nd][nd];    double[] priors = new double[nmix];    double[][] posteriors = new double[ndata][nmix];    int nsteps = 100;    init(data, means, covariances, priors);    for( int istep=0; istep < nsteps; istep++ ) {      estep(data, means, covariances, priors, posteriors);      mstep(data, means, covariances, priors, posteriors);      if (istep%3 == 0) {	plot(data, means, covariances);	zlib.more();      }    }  } //simpleTest  public static void main(String[] cmdline)  {    try {      _f = new Frame();      _f.setSize(RES,RES);      _f.setVisible(true);      _g = _f.getGraphics();      simpleTest();    }    catch(Exception x) { zliberror.die(x); }  } //main} //EM

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -