⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 lm.java

📁 Java实现的各种数学算法
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
// levenberg-marquardt in java 
//
// To use this, implement the functions in the LMfunc interface.
//
// This library uses simple matrix routines from the JAMA java matrix package,
// which is in the public domain.  Reference:
//    http://math.nist.gov/javanumerics/jama/
// (JAMA has a matrix object class.  An earlier library JNL, which is no longer
// available, represented matrices as low-level arrays.  Several years 
// ago the performance of JNL matrix code was better than that of JAMA,
// though improvements in java compilers may have fixed this by now.)
//
// One further recommendation would be to use an inverse based
// on Choleski decomposition, which is easy to implement and
// suitable for the symmetric inverse required here.  There is a choleski
// routine at idiom.com/~zilla.
//
// If you make an improved version, please consider adding your
// name to it ("modified by ...") and send it back to me
// (and put it on the web).
//
// ----------------------------------------------------------------
// 
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Library General Public
// License as published by the Free Software Foundation; either
// version 2 of the License, or (at your option) any later version.
// 
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  
See the GNU// Library General Public License for more details.// // You should have received a copy of the GNU Library General Public// License along with this library; if not, write to the// Free Software Foundation, Inc., 59 Temple Place - Suite 330,// Boston, MA  02111-1307, USA.//// initial author contact info:  // jplewis  www.idiom.com/~zilla  zilla # computer.org,   #=at//// Improvements by:// dscherba  www.ncsa.uiuc.edu/~dscherba  // Jonathan Jackson   j.jackson # ucl.ac.ukpackage ZS.Solve;// see comment aboveimport Jama.*;/** * Levenberg-Marquardt, implemented from the general description * in Numerical Recipes (NR), then tweaked slightly to mostly * match the results of their code. * Use for nonlinear least squares assuming Gaussian errors. * * TODO this holds some parameters fixed by simply not updating them. * this may be ok if the number if fixed parameters is small, * but if the number of varying parameters is larger it would * be more efficient to make a smaller hessian involving only * the variables. * * The NR code assumes a statistical context, e.g. returns * covariance of parameter errors; we do not do this. */public final class LM{  /**   * calculate the current sum-squared-error   * (Chi-squared is the distribution of squared Gaussian errors,   * thus the name)   */  static double chiSquared(double[][] x, double[] a, double[] y, double[] s, 			   LMfunc f)  {    int npts = y.length;    double sum = 0.;    for( int i = 0; i < npts; i++ ) {      double d = y[i] - f.val(x[i], a);      d = d / s[i];      sum = sum + (d*d);    }    return sum;  } //chiSquared  /**   * Minimize E = sum {(y[k] - f(x[k],a)) / s[k]}^2   * The individual errors are optionally scaled by s[k].   * Note that LMfunc implements the value and gradient of f(x,a),   * NOT the value and gradient of E with respect to a!   
*    * @param x array of domain points, each may be multidimensional   * @param y corresponding array of values   * @param a the parameters/state of the model   * @param vary false to indicate the corresponding a[k] is to be held fixed   * @param s2 sigma^2 for point i   * @param lambda blend between steepest descent (lambda high) and   *	jump to bottom of quadratic (lambda zero).   * 	Start with 0.001.   * @param termepsilon termination accuracy (0.01)   * @param maxiter	stop and return after this many iterations if not done   * @param verbose	set to zero (no prints), 1, 2   *   * @return the new lambda for future iterations.   *  Can use this and maxiter to interleave the LM descent with some other   *  task, setting maxiter to something small.   */  public static double solve(double[][] x, double[] a, double[] y, double[] s,			     boolean[] vary, LMfunc f,			     double lambda, double termepsilon, int maxiter,			     int verbose)    throws Exception  {    int npts = y.length;    int nparm = a.length;    assert s.length == npts;    assert x.length == npts;    if (verbose > 0) {      System.out.print("solve x["+x.length+"]["+x[0].length+"]" );      System.out.print(" a["+a.length+"]");      System.out.println(" y["+y.length+"]");    }    double e0 = chiSquared(x, a, y, s, f);    //double lambda = 0.001;    boolean done = false;    // g = gradient, H = hessian, d = step to minimum    // H d = -g, solve for d    double[][] H = new double[nparm][nparm];    double[] g = new double[nparm];    //double[] d = new double[nparm];    double[] oos2 = new double[s.length];    for( int i = 0; i < npts; i++ )  oos2[i] = 1./(s[i]*s[i]);    int iter = 0;    int term = 0;	// termination count test    do {      ++iter;      // hessian approximation      for( int r = 0; r < nparm; r++ ) {	for( int c = 0; c < nparm; c++ ) {	  for( int i = 0; i < npts; i++ ) {	    if (i == 0) H[r][c] = 0.;	    double[] xi = x[i];	    H[r][c] += (oos2[i] * f.grad(xi, a, r) * f.grad(xi, a, c));	  }  
//npts	} //c      } //r      // boost diagonal towards gradient descent      for( int r = 0; r < nparm; r++ )	H[r][r] *= (1. + lambda);      // gradient      for( int r = 0; r < nparm; r++ ) {	for( int i = 0; i < npts; i++ ) {	  if (i == 0) g[r] = 0.;	  double[] xi = x[i];	  g[r] += (oos2[i] * (y[i]-f.val(xi,a)) * f.grad(xi, a, r));	}      } //npts      // scale (for consistency with NR, not necessary)      if (false) {	for( int r = 0; r < nparm; r++ ) {	  g[r] = -0.5 * g[r];	  for( int c = 0; c < nparm; c++ ) {	    H[r][c] *= 0.5;	  }	}      }      // solve H d = -g, evaluate error at new location      //double[] d = DoubleMatrix.solve(H, g);      double[] d = (new Matrix(H)).lu().solve(new Matrix(g, nparm)).getRowPackedCopy();      //double[] na = DoubleVector.add(a, d);      double[] na = (new Matrix(a, nparm)).plus(new Matrix(d, nparm)).getRowPackedCopy();      double e1 = chiSquared(x, na, y, s, f);      if (verbose > 0) {	System.out.println("\n\niteration "+iter+" lambda = "+lambda);	System.out.print("a = ");        (new Matrix(a, nparm)).print(10, 2);	if (verbose > 1) {          System.out.print("H = ");           (new Matrix(H)).print(10, 2);          System.out.print("g = ");           (new Matrix(g, nparm)).print(10, 2);          System.out.print("d = ");           (new Matrix(d, nparm)).print(10, 2);	}	System.out.print("e0 = " + e0 + ": ");	System.out.print("moved from ");        (new Matrix(a, nparm)).print(10, 2);	System.out.print("e1 = " + e1 + ": ");	if (e1 < e0) {	  System.out.print("to ");          (new Matrix(na, nparm)).print(10, 2);	}	else {	  System.out.println("move rejected");	}      }      // termination test (slightly different than NR)      if (Math.abs(e1-e0) > termepsilon) {	term = 0;      }      else {	term++;	if (term == 4) {	  System.out.println("terminating after " + iter + " iterations");	  done = true;	}      }      if (iter >= maxiter) done = true;      // in the C++ version, found that changing this to e1 >= e0      // was not a 
good idea.  See comment there.      //      if (e1 > e0 || Double.isNaN(e1)) { // new location worse than before	lambda *= 10.;      }      else {		// new location better, accept new parameters	lambda *= 0.1;	e0 = e1;	// simply assigning a = na will not get results copied back to caller	for( int i = 0; i < nparm; i++ ) {	  if (vary[i]) a[i] = na[i];	}      }    } while(!done);    return lambda;  } //solve  //----------------------------------------------------------------  /**   * solve for phase, amplitude and frequency of a sinusoid   */  static class LMSineTest implements LMfunc  {    static final int	PHASE = 0;    static final int	AMP = 1;    static final int	FREQ = 2;    public double[] initial()    {      double[] a = new double[3];      a[PHASE] = 0.;      a[AMP] = 1.;      a[FREQ] = 1.;      return a;    } //initial    public double val(double[] x, double[] a)    {      return a[AMP] * Math.sin(a[FREQ]*x[0] + a[PHASE]);    } //val    public double grad(double[] x, double[] a, int a_k)    {      if (a_k == AMP)	return Math.sin(a[FREQ]*x[0] + a[PHASE]);      else if (a_k == FREQ)	return a[AMP] * Math.cos(a[FREQ]*x[0] + a[PHASE]) * x[0];      else if (a_k == PHASE)	return a[AMP] * Math.cos(a[FREQ]*x[0] + a[PHASE]);      else {	assert false;	return 0.;      }    } //grad    public Object[] testdata() {      double[] a = new double[3];      a[PHASE] = 0.111;      a[AMP] = 1.222;      a[FREQ] = 1.333;      int npts = 10;      double[][] x = new double[npts][1];      double[] y = new double[npts];

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -