📄 gradientminimizer.java
字号:
import java.util.*;public class GradientMinimizer implements OneVarFn{ public static final int ALG_SD = 0; public static final int ALG_CGFR = 1; public static final int ALG_CGPR = 2; private int algo = ALG_SD; /* Algorithm to be used */ private int n; private Variable X[]; /* Expressions for the function, gradient, and Hessian in postfix */ private Vector fnPF; private Vector gradPF[]; private Vector hessPF[][]; /* What information do we have about the function? */ private boolean haveFunction; private boolean haveGradient; private boolean haveHessian; private double xpos[]; private double grad[]; private double ograd[]; private double h[]; private double xtmp[]; private double hess[][]; private int nIter; private int nFnEval; private OneVarMinimizer ovmz; private double fnVal; public GradientMinimizer( int dim ) { n = dim; X = new Variable[n]; int i; for( i = 0; i < dim; i++ ) X[i] = Variable.create( "x" + (i+1), 0.0 ); // The first variable is called x1 xpos = new double[n]; grad = new double[n]; ograd = new double[n]; h = new double[n]; xtmp = new double[n]; hess = new double[n][n]; gradPF = new Vector[n]; hessPF = new Vector[n][n]; ovmz = new OneVarMinimizer( this ); } public void setAlgorithm( int a ) { algo = a; } public boolean setFunction( String fnExpr ) { fnPF = XToken.compileToPostFix( fnExpr ); if( fnPF == null ) return false; haveFunction = true; return true; } public boolean setGradient( String gradExpr[] ) { int i; for( i = 0; i < n; i++ ) { gradPF[i] = XToken.compileToPostFix( gradExpr[i] ); if( gradPF[i] == null ) return false; } haveGradient = true; return true; } public boolean setHessian( String hessExpr[][] ) { int i, j; for( i = 0; i < n; i++ ) { for( j = 0; j < n; j++ ) { hessPF[i][j] = XToken.compileToPostFix( hessExpr[i][j] ); if( hessPF[i][j] == null ) return false; } } haveHessian = true; return true; } public double evalF( double x[] ) { int i; for( i = 0; i < n; i++ ) X[i].setValue( x[i] ); return XToken.eval( fnPF ); } public double eval( double alpha, Object user_data ) { int i; for( i = 0; i < n; i++ ) xtmp[i] = xpos[i] - alpha * h[i]; return evalF( xtmp ); } public void computeGradient( double x[] ) { if( !haveGradient ) return; int i; for( i = 0; i < n; i++ ) X[i].setValue( x[i] ); for( i = 0; i < n; i++ ) grad[i] = XToken.eval( gradPF[i] ); } public void computeHessian( double x[] ) { if( !haveHessian ) return; int i, j; for( i = 0; i < n; i++ ) X[i].setValue( x[i] ); for( i = 0; i < n; i++ ) for( j = 0; j < n; j++ ) hess[i][j] = XToken.eval( hessPF[i][j] ); } public void setInitialPoint( double x[] ) { int i; for( i = 0; i < n; i++ ) xpos[i] = x[i]; nIter = 0; } public void initCGSD( double x[] ) { setInitialPoint( x ); computeGradient( x ); int i; for( i = 0; i < n; i++ ) h[i] = grad[i]; fnVal = evalF( x ); } public void initNewton( double x[] ) { initCGSD( x ); computeHessian( x ); } public String dblArrayToString( double x[] ) { int i; String txt = new String( "(" ); int n = x.length; for( i = 0; i < n; i++ ) { if( i > 0 ) txt += ","; txt += " " + x[i]; } txt += " )"; return txt; } public double normOne( double x[] ) { int i, dim; double s, d; s = 0.0; dim = x.length; for( i = 0; i < dim; i++ ) s += Math.abs( x[i] ); return s; } public double dotProduct( double x[], double y[] ) { double sum=0.0; int i; for( i = 0; i < n; i++ ) sum += x[i] * y[i]; return sum; } public void debugState() { System.out.println( "" ); System.out.println( "Iteration:" + nIter ); System.out.println( "Position x:" + dblArrayToString( xpos ) ); System.out.println( "Function value f(x):" + fnVal ); System.out.println( "Gradient g:" + dblArrayToString( grad ) ); System.out.println( "Search direction h:" + dblArrayToString( h ) ); } public void iterCGSD( boolean doCG, boolean doPolakRibiere ) { double num, den, beta, alpha; int i; nIter++; /* If h is already too small there is nothing to do. */ if( getHOneNorm() < 1e-10 ) return; /* Go along the search direction ... */ if( !ovmz.minimize( 0.0, 1e-5, this, 1e-7, 1000 ) ) System.out.println( "NO CONVERGENCE" ); alpha = ovmz.xBest; /* ... and find where the function is minimized. */ /* Update new position */ for( i = 0; i < n; i++ ) xpos[i] -= alpha*h[i]; if( doCG ) { /* Conjugate Gradient */ /* Save gradient for Polak-Ribiere */ if( doPolakRibiere ) for( i = 0; i < n; i++ ) ograd[i] = grad[i]; den = dotProduct( grad, grad ); /* Compute new gradient */ computeGradient( xpos ); num = dotProduct( grad, grad ); if( doPolakRibiere ) num -= dotProduct( grad, ograd ); beta = num / den; /* Compute new search direction */ for( i = 0; i < n; i++ ) h[i] = grad[i] + beta * h[i]; } else { /* Steepest Descent */ /* Compute new gradient */ computeGradient( xpos ); /* h is same as g */ for( i = 0; i < n; i++ ) h[i] = grad[i]; } fnVal = evalF( xpos ); } public void iterSD() { iterCGSD( false, false ); } public void iterCGFletcherReeves() { iterCGSD( true, false ); } public void iterCGPolakRibiere() { iterCGSD( true, true ); } public void next() { switch( algo ) { case ALG_SD: iterSD(); break; case ALG_CGFR: iterCGFletcherReeves(); break; case ALG_CGPR: iterCGPolakRibiere(); break; } } public double getHOneNorm() { return normOne( h ); } public double[] getCurrentX() { return xpos; } public double[] getCurrentG() { return grad; } public double[] getCurrentH() { return h; } public double getCurrentFnVal() { return fnVal; } public int countIterations() { return nIter; } public static void main( String args[] ) { int dim = 2; String fExp; String grExp[] = new String[dim]; double x[] = new double[dim]; // RosenBrock x[0] = -1.2; x[1] = 1; fExp = "100*( x2 - x1*x1)^2 + (1-x1)^2"; grExp[0] = "-400*x1*(x2-x1*x1)-2*(1-x1)"; grExp[1] = "200*(x2-x1*x1)"; // Quadratic /* x[0] = 8; x[1] = 0.6; fExp = "x1*x1/100 + x2*x2"; grExp[0] = "x1/50"; grExp[1] = "2*x2"; */ GradientMinimizer gmz = new GradientMinimizer(dim); gmz.setFunction( fExp ); gmz.setGradient( grExp ); gmz.initCGSD( x ); gmz.debugState(); int i; for( i = 0; i< 10000; i++ ) { // gmz.iterCGPolakRibiere(); // gmz.iterCGFletcherReeves(); gmz.iterSD(); gmz.debugState(); if( gmz.getHOneNorm() < 1e-3 ) break; } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -