📄 irls.java
字号:
// IRLS.java
//
// Requires the old Visual Numerics math library. If you do not have it
// (it is no longer available), identify the matrix calls and substitute
// another library, such as JAMA.
// Also you will need to make these substitutions:
// - change zliberror._assert to assert and delete the zlib import
// (or just delete all the assert calls)
// - change or delete the matrix.print calls.
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Library General Public
// License as published by the Free Software Foundation; either
// version 2 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Library General Public License for more details.
//
// You should have received a copy of the GNU Library General Public
// License along with this library; if not, write to the
// Free Software Foundation, Inc., 59 Temple Place - Suite 330,
// Boston, MA 02111-1307, USA.
//
// Primary author contact info:
// j.p.lewis www.idiom.com/~zilla zilla@computer.org
package ZS;
import VisualNumerics.math.*;
import zlib.*;
public class IRLS
{
    /** Diagnostic level: any value greater than 0 prints progress to stdout. */
    static final int _verbose = 1;

    /**
     * Solves A x ~= b in the Lp sense by iteratively reweighted least
     * squares (IRLS): the L2 (least-squares) solution is computed first and
     * then refined with {@link #refine} for {@code niter} iterations.
     *
     * The system should be overdetermined (A has more rows than columns),
     * otherwise the solution will typically be zero
     * (and L1 of zero is no different than L2 zero).
     *
     * @param A     system matrix with b.length rows and x.length columns
     * @param At    A transpose; taken from the caller because the caller may
     *              already have computed it
     * @param x     output: receives the solution (only its length is read on
     *              entry, its incoming contents are ignored)
     * @param b     right-hand side, one entry per row of A
     * @param p     desired power, e.g. slightly above 1
     * @param niter number of IRLS refinement iterations
     */
    public static void solve(final double[][] A, final double[][] At,
                             double[] x, final double[] b,
                             double p, int niter)
    {
        zliberror._assert(A.length == b.length);
        zliberror._assert(A[0].length == x.length);

        // Start from the ordinary least-squares solution, then reweight.
        double[] x1 = l2solve(A, At, b);
        refine(A, At, x1, b, p, niter);
        System.arraycopy(x1, 0, x, 0, x.length);
    } //solve

    //----------------------------------------------------------------
    /**
     * Refines the estimate {@code x} in place with {@code niter} IRLS
     * iterations for the Lp norm of the residual.
     *
     * Each iteration:
     *  - forms the residual r = b - A x;
     *  - computes per-row weights w = (|r| + eps)^((p-2)/2);
     *  - solves (At W2 A) dx = At W2 r, where At W2 is built by
     *    matrix.diagMul2(At, w) (presumably At * diag(w)^2 -- confirm);
     *  - updates x += dx.
     * If every residual is exactly zero, a message is printed and the method
     * returns early.
     *
     * @param A     system matrix
     * @param At    A transpose
     * @param x     in/out: current estimate, updated in place
     * @param b     right-hand side
     * @param p     desired power, e.g. slightly above 1
     * @param niter maximum number of iterations
     */
    public static void refine(final double[][] A, final double[][] At,
                              double[] x, final double[] b,
                              double p, int niter)
    {
        int nr = b.length;
        int nc = x.length;
        double[] r = new double[nr]; // residual b - A x
        double[] w = new double[nr]; // per-row weights
        double pm2 = (p - 2.) / 2.;

        // Keeps pow() finite when a residual is exactly zero and p < 2.
        // NOTE(review): Double.MIN_VALUE is the smallest positive *subnormal*
        // double (~4.9e-324), so an exactly-zero residual gets an
        // astronomically large weight (roughly 1e128 for p = 1.2, and
        // +Infinity for p below about 0.08). This can make the weighted
        // normal equations badly conditioned. Consider a scale-aware eps.
        double eps = 100. * Double.MIN_VALUE;
        if (_verbose > 0) {
            System.out.println("eps = " + eps);
        }

        for (int iter = 0; iter < niter; iter++) {
            if (_verbose > 0) {
                printerr(A, x, b);
            }

            double[] Ax = DoubleMatrix.multiply(A, x);
            boolean nonzero = false;
            for (int ir = 0; ir < nr; ir++) {
                r[ir] = b[ir] - Ax[ir];
                if (r[ir] != 0.) {
                    nonzero = true;
                }
                w[ir] = Math.pow(Math.abs(r[ir]) + eps, pm2);
            }
            if (_verbose > 0) {
                matrix.print("residual", r);
                matrix.print("weights", w);
            }
            if (!nonzero) {
                System.out.println("IRLS returning - solution is exact");
                return;
            }

            double[][] AtW2 = matrix.diagMul2(At, w);
            double[][] AtW2A = DoubleMatrix.multiply(AtW2, A);
            double[] AtW2r = DoubleMatrix.multiply(AtW2, r);
            // AtW2A is expected to be square, so l2solve takes its direct
            // branch and never reads its At argument; AtW2A merely fills it.
            double[] dx = l2solve(AtW2A, AtW2A, AtW2r);
            zliberror._assert(dx.length == nc);
            for (int ic = 0; ic < nc; ic++) {
                x[ic] += dx[ic];
            }
            if (_verbose > 0) {
                matrix.print("new x = ", x);
            }
        }
        if (_verbose > 0) {
            printerr(A, x, b);
        }
    } //refine

    //----------------------------------------------------------------
    /**
     * Least-squares solve of A x = b.
     *
     * A square A is passed straight to DoubleMatrix.solve (At is not used);
     * otherwise the normal equations (At A) x = At b are solved.
     * Any exception is handed to zliberror.die(ex). null is returned only
     * if die returns (NOTE(review): die presumably terminates -- confirm).
     *
     * @param A  system matrix
     * @param At A transpose (only read when A is not square)
     * @param b  right-hand side
     * @return the solution vector, or null after a reported failure
     */
    public static double[] l2solve(final double[][] A, final double[][] At,
                                   final double[] b)
    {
        try {
            if (_verbose > 0) {
                System.out.println(" solving " + A.length + "x" + A[0].length);
            }
            if (A.length == A[0].length) {
                return VisualNumerics.math.DoubleMatrix.solve(A, b);
            }
            double[][] AtA = DoubleMatrix.multiply(At, A);
            double[] Atb = DoubleMatrix.multiply(At, b);
            return VisualNumerics.math.DoubleMatrix.solve(AtA, Atb);
        }
        catch (Exception ex) {
            zliberror.die(ex);
        }
        return null;
    } //l2solve

    //----------------------------------------------------------------
    /**
     * Prints the L1 norm of the residual b - A x and its sum of squares.
     * Note that the value labelled "L2" is the squared L2 norm.
     */
    static void printerr(double[][] A, double[] x, double[] b)
    {
        double[] Ax = DoubleMatrix.multiply(A, x);
        double sum2 = 0.;
        double sum1 = 0.;
        for (int i = 0; i < b.length; i++) {
            double r = b[i] - Ax[i];
            sum2 += (r * r);
            sum1 += Math.abs(r);
        }
        System.out.println(" irls L1=" + sum1 + " L2=" + sum2);
    } //printerr

    //----------------------------------------------------------------
    /**
     * Test on a line fit: y = a*x + b plus uniform noise in [-10, 10),
     * fitted first by least squares and then by IRLS with p = 1.2.
     */
    public static void main(String[] cmdline)
    {
        int npts = 50;
        double a = 0.1;
        double b = 10.;
        // Size the arrays from npts so changing the point count stays safe.
        double[][] A = new double[npts][2];
        double[] y = new double[npts];
        for (int i = 0; i < npts; i++) {
            double x = (100. * i) / npts;
            A[i][0] = x;
            A[i][1] = 1.;
            y[i] = a * x + b;
            y[i] += (10. * (2.0 * Math.random() - 1.0));
        }
        matrix.print("y=", y);

        double[][] At = DoubleMatrix.transpose(A);
        double[] u = IRLS.l2solve(A, At, y);
        System.out.println("desired a,b=" + a + "," + b +
                           " L2 recovered a,b=" + u[0] + "," + u[1]);

        // on this problem most of the reduction in L1 error happens in
        // the first ~5 iterations
        IRLS.solve(A, At, u, y, 1.2, 10);
        System.out.println("desired a,b=" + a + "," + b +
                           " L1 recovered a,b=" + u[0] + "," + u[1]);
        System.exit(0);
    } //main
} //IRLS
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -