⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 irls.java

📁 Java实现的各种数学算法
💻 JAVA
字号:
// IRLS.java 
//
// Requires the old Visual Numerics math library (VisualNumerics.math).
// If you do not have it (it is no longer available), identify the
// DoubleMatrix calls and substitute another library, such as JAMA.
// You will also need to make these substitutions:
// -  change zliberror._assert to assert and delete the zlib import
//    (or just delete all the assert calls); zliberror.die, used in
//    l2solve, also needs a replacement;
// -  change or delete the matrix.print calls;
// -  provide a replacement for matrix.diagMul2, which refine() relies on.

// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Library General Public
// License as published by the Free Software Foundation; either
// version 2 of the License, or (at your option) any later version.
// 
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Library General Public License for more details.
// 
// You should have received a copy of the GNU Library General Public
// License along with this library; if not, write to the
// Free Software Foundation, Inc., 59 Temple Place - Suite 330,
// Boston, MA  02111-1307, USA.
//
// Primary author contact info: 
// j.p.lewis   www.idiom.com/~zilla  zilla@computer.org

package ZS;

import VisualNumerics.math.*;

import zlib.*;

public class IRLS
{
  final static int _verbose = 1;


  /**
   * Solves A x ~= b in the p-norm sense: starts from the ordinary
   * least-squares solution and then refines it with {@link #refine}.
   *
   * The system should be overdetermined (A has more rows than columns);
   * otherwise the solution will typically be exact/zero-residual and the
   * choice of p makes no difference.
   *
   * @param A     system matrix, nr x nc (must have at least one row)
   * @param At    transpose of A; taken from the caller because the caller
   *              may already have computed it
   * @param x     output, length nc; overwritten with the solution (its
   *              incoming contents are ignored)
   * @param b     right-hand side, length nr
   * @param p     desired norm power, e.g. slightly above 1 for an L1-like fit
   * @param niter number of reweighting iterations passed to refine
   */
  public static void solve(final double[][] A, final double[][] At,
                           double[] x, final double[] b,
                           double p, int niter)
  {
    zliberror._assert(A.length == b.length);
    zliberror._assert(A.length > 0);          // A[0] is read below
    zliberror._assert(A[0].length == x.length);

    // start from the ordinary least-squares solution, then reweight it
    double[] x1 = l2solve(A,At, b);
    refine(A,At,x1,b, p,niter);

    System.arraycopy(x1, 0, x, 0, x.length);
  } //solve

  //----------------------------------------------------------------

  /**
   * Refines x in place by iteratively reweighted least squares.
   *
   * Each iteration does the following:
   * 1. Compute the residual r = b - A x.
   * 2. Compute the per-row weights w_i = (|r_i| + eps)^((p-2)/2).
   * 3. Form At*W2 via matrix.diagMul2.
   * 4. Solve (At W2 A) dx = At W2 r and set x += dx.
   *
   * Returns early if every residual is exactly zero.
   *
   * @param A     system matrix, nr x nc
   * @param At    transpose of A
   * @param x     initial estimate, length nc; updated in place
   * @param b     right-hand side, length nr
   * @param p     target norm power
   * @param niter maximum number of reweighting iterations
   */
  public static void refine(final double[][] A, final double[][] At,
                            double[] x, final double[] b,
                            double p, int niter)
  {
    int nr = b.length;
    int nc = x.length;

    double[] r = new double[nr];    // residual b - A x
    double[] w = new double[nr];    // per-row weights

    // Exponent applied to |r|; it is half of (p-2).  Presumably
    // matrix.diagMul2 squares w, which gives the usual IRLS factor
    // |r|^(p-2) -- TODO confirm against diagMul2.
    double pm2 = (p-2.) / 2.;

    // Floor added to |r| so the weights stay finite when a residual is
    // (near) zero and p < 2.  This used to be 100*Double.MIN_VALUE.
    // MIN_VALUE is the smallest positive *subnormal* (~4.9e-324), not
    // machine epsilon, which caused two problems:
    //  - |r| + eps rounded back to |r| for any residual larger than
    //    about 1e-305, so the floor had no effect;
    //  - for r == 0 the weight exceeded 1e128 at p = 1.2.
    // Use double machine epsilon (2^-52) instead.
    double eps = 2.220446049250313e-16;
    if (_verbose > 0) System.out.println("eps = "+eps);

    for( int iter = 0; iter < niter; iter++ ) {

      if (_verbose > 0) printerr(A,x,b);

      double[] Ax = DoubleMatrix.multiply(A,x);

      // residual and weights; remember whether any residual is nonzero
      boolean nonzero = false;
      for( int ir = 0; ir < nr; ir++ ) {
        r[ir] = b[ir] - Ax[ir];
        if (r[ir] != 0.) nonzero = true;
        w[ir] = Math.pow( Math.abs(r[ir]) + eps, pm2 );
      }
      if (_verbose > 0) matrix.print("residual", r);
      if (_verbose > 0) matrix.print("weights", w);
      if (!nonzero) {
        System.out.println("IRLS returning - solution is exact");
        return;
      }

      // weighted normal equations for the update: (At W2 A) dx = At W2 r
      double[][] AtW2 = matrix.diagMul2(At, w);
      double[][] AtW2A = DoubleMatrix.multiply(AtW2, A);
      double[] AtW2r = DoubleMatrix.multiply(AtW2, r);

      // AtW2A should be nc x nc, so l2solve takes its square-system
      // branch and does not use its second argument.
      double[] dx = l2solve(AtW2A,AtW2A, AtW2r);
      zliberror._assert(dx.length == nc);
      for( int ic = 0; ic < nc; ic++ )  x[ic] += dx[ic];
      if (_verbose > 0) matrix.print("new x = ", x);
    }

    if (_verbose > 0) printerr(A,x,b);
  } //refine

  //----------------------------------------------------------------

  /**
   * Least-squares solve of A x = b.
   *
   * If A is square, the system is solved directly and At is not used.
   * Otherwise the normal equations (At A) x = At b are solved.
   *
   * @param A  system matrix
   * @param At transpose of A (only read when A is not square)
   * @param b  right-hand side
   * @return the solution vector.  On any exception zliberror.die is called;
   *         null is returned only if that call returns (presumably it
   *         terminates -- TODO confirm).
   */
  public static double[] l2solve(final double[][] A, final double[][] At,
                                 final double[] b)
  {
    try {
      if (_verbose > 0)
        System.out.println("  solving "+A.length+"x"+A[0].length);

      if (A.length == A[0].length) {
        return VisualNumerics.math.DoubleMatrix.solve(A, b);
      }
      else {
        double[][] AtA = DoubleMatrix.multiply(At, A);
        double[] Atb = DoubleMatrix.multiply(At, b);
        return VisualNumerics.math.DoubleMatrix.solve(AtA, Atb);
      }
    }
    catch(Exception ex) {
      zliberror.die(ex);
    }

    return null;
  } //l2solve

  //----------------------------------------------------------------

  /**
   * Prints the L1 norm and the sum of squared residuals of b - A x.
   * Note: the value labelled "L2" is the squared L2 norm (no sqrt).
   */
  static void printerr(double[][] A, double[] x, double[] b)
  {
    double[] Ax = DoubleMatrix.multiply(A,x);
    double sum2 = 0.;
    double sum1 = 0.;

    for( int i = 0; i < b.length; i++ ) {
      double r = b[i] - Ax[i];
      sum2 += (r*r);
      sum1 += Math.abs(r);
    }

    System.out.println(" irls  L1="+sum1+"  L2="+sum2);
  } //printerr

  //----------------------------------------------------------------

  /**
   * Test on a noisy line fit y = a*x + b.  Prints the least-squares
   * estimate and the IRLS (p = 1.2) estimate of (a, b).
   */
  public static void main(String[] cmdline)
  {
    int npts = 50;

    // ground-truth line parameters
    double a = 0.1;
    double b = 10.;

    // design matrix rows are [x, 1]; arrays are sized from npts
    double[][] A = new double[npts][2];
    double[] y = new double[npts];
    for( int i = 0; i < npts; i++ ) {
      double x = (100. * i) / npts;

      A[i][0] = x;
      A[i][1] = 1.;

      // line plus uniform noise in [-10, 10)
      y[i] = a * x + b;
      y[i] += (10. * (2.0*Math.random()-1.0));
    }
    matrix.print("y=",y);

    double[][] At = DoubleMatrix.transpose(A);

    double[] u = IRLS.l2solve(A,At, y);

    System.out.println("desired a,b="+a+","+b+
                       "    L2 recovered a,b="+u[0]+","+u[1]);

    // on this problem most of the reduction in L1 error happens in
    // the first ~5 iterations
    IRLS.solve(A,At, u, y, 1.2, 10);
    System.out.println("desired a,b="+a+","+b+
                       "    L1 recovered a,b="+u[0]+","+u[1]);
    System.exit(0);
  } //main


} //IRLS

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -