📄 lm.java
字号:
double[] s = new double[npts]; for( int i = 0; i < npts; i++ ) { x[i][0] = (double)i / npts; y[i] = val(x[i], a); s[i] = 1.; } Object[] o = new Object[4]; o[0] = x; o[1] = a; o[2] = y; o[3] = s; return o;
    } //test
  } //SineTest

  //----------------------------------------------------------------
  /**
   * Quadratic test function z = (p - o)' S' S (p - o).
   *
   * <p>The point is p = {x[0], x[1]}, the offset is o = {a[0], a[1]} and S is a
   * single scalar scale factor a[2], so
   * z = (s*(px - ox))^2 + (s*(py - oy))^2. The fit solves for o and S.
   */
  static class LMQuadTest implements LMfunc {

    /**
     * Evaluates the scaled squared distance of the point from the offset.
     *
     * @param x the point {px, py}
     * @param a the parameters {ox, oy, s}
     * @return (s*(px - ox))^2 + (s*(py - oy))^2
     */
    public double val(double[] x, double[] a) {
      assert a.length == 3;
      assert x.length == 2;
      final double ox = a[0], oy = a[1], scale = a[2];
      final double u = scale * (x[0] - ox);
      final double v = scale * (x[1] - oy);
      return u * u + v * v;
    } //val

    /**
     * Partial derivative of z with respect to parameter a[a_k].
     * <pre>
     *   z      = (s*dx)^2 + (s*dy)^2,   dx = px - ox,  dy = py - oy
     *   dz/dox = -2 s^2 dx
     *   dz/doy = -2 s^2 dy
     *   dz/ds  =  2 s (dx^2 + dy^2)
     * </pre>
     * Only a_k in {0, 1, 2} is meaningful. The assertion checks just the upper
     * bound, and every index other than 0 or 1 yields the dz/ds value.
     *
     * @param x   the point {px, py}
     * @param a   the parameters {ox, oy, s}
     * @param a_k index of the parameter to differentiate by
     * @return the partial derivative dz/da[a_k]
     */
    public double grad(double[] x, double[] a, int a_k) {
      assert a.length == 3;
      assert x.length == 2;
      assert a_k < 3 : "a_k=" + a_k;
      final double ox = a[0], oy = a[1], scale = a[2];
      final double dx = x[0] - ox;
      final double dy = x[1] - oy;
      switch (a_k) {
        case 0:
          return -2. * scale * scale * dx;
        case 1:
          return -2. * scale * scale * dy;
        default:
          // a_k == 2 (or, with assertions off, any other index): d/ds
          return 2. * scale * (dx * dx + dy * dy);
      }
    } //grad

    /**
     * Starting guess for the fit.
     *
     * @return a fresh array {ox, oy, s} = {0.05, 0.1, 1.0}
     */
    public double[] initial() {
      return new double[] {0.05, 0.1, 1.0};
    } //initial

    /**
     * Builds synthetic data from the true parameters {0, 0, 0.9}.
     *
     * <p>Samples the 25 integer points of the grid [-2,2] x [-2,2], row by row,
     * with unit weights s[i] = 1. Each sample is printed, followed by the x
     * and y matrices.
     *
     * @return {x (25x2 points), true parameters, y (25 values), s (25 weights)}
     */
    public Object[] testdata() {
      final int npts = 25;
      final double[][] x = new double[npts][2];
      final double[] y = new double[npts];
      final double[] s = new double[npts];
      final double[] a = {0., 0., 0.9};

      int k = 0;
      for (int row = -2; row <= 2; row++) {
        for (int col = -2; col <= 2; col++, k++) {
          x[k][0] = col;
          x[k][1] = row;
          y[k] = val(x[k], a);
          System.out.println("Quad " + col + "," + row + " -> " + y[k]);
          s[k] = 1.;
        }
      }

      System.out.print("quad x= ");
      new Matrix(x).print(10, 2);
      System.out.print("quad y= ");
      new Matrix(y, npts).print(10, 2);

      return new Object[] {x, a, y, s};
    } //testdata
  } //LMQuadTest
//---------------------------------------------------------------- /** * Replicate the example in NR, fit a sum of Gaussians to data. * y(x) = \sum B_k exp(-((x - E_k) / G_k)^2) * minimize chisq = \sum { y[j] - \sum B_k exp(-((x_j - E_k) / G_k)^2) }^2 * * B_k, E_k, G_k are stored in that order * * Works, results are close to those from the NR example code. */ static class LMGaussTest implements LMfunc { static double SPREAD = 0.001; // noise variance public double val(double[] x, double[] a) { assert x.length == 1; assert (a.length%3) == 0; int K = a.length / 3; int i = 0; double y = 0.; for( int j = 0; j < K; j++ ) { double arg = (x[0] - a[i+1]) / a[i+2]; double ex = Math.exp(- arg*arg); y += (a[i] * ex); i += 3; } return y; } //val /** * <pre> * y(x) = \sum B_k exp(-((x - E_k) / G_k)^2) * arg = (x-E_k)/G_k * ex = exp(-arg*arg) * fac = B_k * ex * 2 * arg * * d/dB_k = exp(-((x - E_k) / G_k)^2) * * d/dE_k = B_k exp(-((x - E_k) / G_k)^2) . -2((x - E_k) / G_k) . -1/G_k * = 2 * B_k * ex * arg / G_k * d/E_k[-((x - E_k) / G_k)^2] = -2((x - E_k) / G_k) d/dE_k[(x-E_k)/G_k] * d/dE_k[(x-E_k)/G_k] = -1/G_k * * d/G_k = B_k exp(-((x - E_k) / G_k)^2) . -2((x - E_k) / G_k) . -(x-E_k)/G_k^2 * = B_k ex -2 arg -arg / G_k * = fac arg / G_k * d/dx[1/x] = d/dx[x^-1] = -x[x^-2] */ public double grad(double[] x, double[] a, int a_k) { assert x.length == 1; // i - index one of the K Gaussians int i = 3*(a_k / 3); double arg = (x[0] - a[i+1]) / a[i+2]; double ex = Math.exp(- arg*arg); double fac = a[i] * ex * 2. 
* arg; if (a_k == i) return ex; else if (a_k == (i+1)) { return fac / a[i+2]; } else if (a_k == (i+2)) { return fac * arg / a[i+2]; } else { System.err.println("bad a_k"); return 1.; } } //grad public double[] initial() { double[] a = new double[6]; a[0] = 4.5; a[1] = 2.2; a[2] = 2.8; a[3] = 2.5; a[4] = 4.9; a[5] = 2.8; return a; } //initial public Object[] testdata() { Object[] o = new Object[4]; int npts = 100; double[][] x = new double[npts][1]; double[] y = new double[npts]; double[] s = new double[npts]; double[] a = new double[6]; a[0] = 5.0; // values returned by initial a[1] = 2.0; // should be fairly close to these a[2] = 3.0; a[3] = 2.0; a[4] = 5.0; a[5] = 3.0; for( int i = 0; i < npts; i++ ) { x[i][0] = 0.1*(i+1); // NR always counts from 1 y[i] = val(x[i], a); s[i] = SPREAD * y[i]; System.out.println(i+": x,y= "+x[i][0]+", "+y[i]); } o[0] = x; o[1] = a; o[2] = y; o[3] = s; return o; } //testdata } //LMGaussTest //---------------------------------------------------------------- // test program public static void main(String[] cmdline) { LMfunc f = new LMQuadTest(); //LMfunc f = new LMSineTest(); // works //LMfunc f = new LMGaussTest(); // works double[] aguess = f.initial(); Object[] test = f.testdata(); double[][] x = (double[][])test[0]; double[] areal = (double[])test[1]; double[] y = (double[])test[2]; double[] s = (double[])test[3]; boolean[] vary = new boolean[aguess.length]; for( int i = 0; i < aguess.length; i++ ) vary[i] = true; assert aguess.length == areal.length; try { solve( x, aguess, y, s, vary, f, 0.001, 0.01, 100, 2); } catch(Exception ex) { System.err.println("Exception caught: " + ex.getMessage()); System.exit(1); } System.out.print("desired solution "); (new Matrix(areal, areal.length)).print(10, 2); System.exit(0); } //main} //LM
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -