📄 cgminimizer.java
字号:
package edu.stanford.nlp.optimization;/** * Conjugate-gradient implementation based on the code in Numerical * Recipes in C. (See p. 423 and others.) As of now, it requires a * differentiable function (DiffFunction) as input. Equality * constraints are supported; inequality constraints may soon be * added. * * The basic way to use the minimizer is with a null constructor, then * the simple minimize method: * * <p><code>Minimizer cgm = new CGMinimizer();</code> * <br><code>DiffFunction df = new SomeDiffFunction();</code> * <br><code>double tol = 1e-4;</code> * <br><code>double[] initial = getInitialGuess();</code> * <br><code>double[] minimum = cgm.minimize(df,tol,initial);</code> * * @author <a href="mailto:klein@cs.stanford.edu">Dan Klein</a> * @version 1.0 * @since 1.0 */public class CGMinimizer implements Minimizer { private Function monitor; // = null; private static final int numToPrint = 5; private static final boolean simpleGD = false; private static final boolean checkSimpleGDConvergence = true; private static final boolean verbose = false; private boolean silent; private static final int ITMAX = 500; // overridden in dbrent(); made bigger private static final double EPS = 1.0e-30; private static final double TOL = 2.0e-15; private static final int resetFrequency = 10; double[] copyArray(double[] a) { double[] result = new double[a.length]; for(int i=0;i<a.length;i++) result[i]=a[i]; return result; } private static String arrayToString(double[] x) { return arrayToString(x, x.length); } private static String arrayToString(double[] x, int num) { StringBuffer sb = new StringBuffer("("); if (num > x.length) { num = x.length; } for (int j=0; j<num; j++) { sb.append(x[j]); if (j != x.length-1) { sb.append(", "); } } if (num < x.length) { sb.append("..."); } sb.append(")"); return sb.toString(); } private double fabs(double x) { if (x<0) return -1.0*x; return x; } private double fmax(double x, double y) { if (x<y) return y; return x; } private double fmin(double x, double y) { if (x>y) return y; return x; } private double sign(double x, double y) { if (y>=0) return fabs(x); return -1.0*fabs(x); } private double arrayMax(double[] x) { double max = Double.NEGATIVE_INFINITY; for (int i=0; i<x.length; i++) { if (max < x[i]) max = x[i]; } return max; } private int arrayArgMax(double[] x) { double max = Double.NEGATIVE_INFINITY; int index = -1; for (int i=0; i<x.length; i++) { if (max < x[i]) { max = x[i]; index = i; } } return index; } private double arrayMin(double[] x) { double min = Double.POSITIVE_INFINITY; for (int i=0; i<x.length; i++) { if (min > x[i]) min = x[i]; } return min; } private int arrayArgMin(double[] x) { double min = Double.POSITIVE_INFINITY; int index = -1; for (int i=0; i<x.length; i++) { if (min > x[i]) { min = x[i]; index = i; } } return index; } class OneDimDiffFunction { DiffFunction function; double[] initial; double[] direction; double[] tempVector; double[] vectorOf(double x) { for (int j=0; j<initial.length; j++) { tempVector[j] = initial[j]+x*direction[j]; } //System.err.println("Tmp "+arrayToString(tempVector,10)); //System.err.println("Dir "+arrayToString(direction,10)); return tempVector; } double valueAt(double x) { double v = function.valueAt(vectorOf(x)); return v; } double derivativeAt(double x) { double[] g = function.derivativeAt(vectorOf(x)); double d = 0.0; for (int j=0; j<g.length; j++) { d += g[j]*direction[j]; } return d; } double[] copyArray(double[] a) { double[] result = new double[a.length]; for(int i=0;i<a.length;i++) result[i]=a[i]; return result; } OneDimDiffFunction(DiffFunction function, double[] initial, double[] direction) { this.function = function; this.initial = copyArray(initial); this.direction = copyArray(direction); this.tempVector = new double[function.domainDimension()]; } } Triple mnbrak(Triple abc, OneDimDiffFunction function) { // constants double GOLD = 1.618034; double GLIMIT = 100; double TINY = 1.0e-20; // inputs double ax = abc.a; double fa = function.valueAt(ax); double bx = abc.b; double fb = function.valueAt(bx); double cx = abc.c; double fc = 0.0; if (fb > fa) { // swap double temp = fa; fa = fb; fb = temp; temp = ax; ax = bx; bx = temp; } // guess cx cx = bx+GOLD*(bx-ax); fc = function.valueAt(cx); // loop until we get a bracket while (fb > fc) { double r = (bx-ax)*(fb-fc); double q = (bx-cx)*(fb-fa); double u = bx-((bx-cx)*q-(bx-ax)*r)/(2.0*sign(fmax(fabs(q-r),TINY),q-r)); double fu = 0.0; double ulim = bx+GLIMIT*(cx-bx); if ((bx-u)*(u-cx)>0.0) { fu = function.valueAt(u); if (fu < fc) { //Ax = new Double(bx); //Bx = new Double(u); //Cx = new Double(cx); //System.err.println("\nReturning3: a="+bx+" ("+fb+") b="+u+"("+fu+") c="+cx+" ("+fc+")"); return new Triple(bx,u,cx); } else if (fu > fb) { //Cx = new Double(u); //Ax = new Double(ax); //Bx = new Double(bx); //System.err.println("\nReturning2: a="+ax+" ("+fa+") b="+bx+"("+fb+") c="+u+" ("+fu+")"); return new Triple(ax,bx,u); } u = cx+GOLD*(cx-bx); fu = function.valueAt(u); } else if ((cx-u)*(u-ulim) > 0.0) { fu = function.valueAt(u); if (fu < fc) { bx = cx; cx = u; u = cx+GOLD*(cx-bx); fb = fc; fc = fu; fu = function.valueAt(u); } } else if ((u-ulim)*(ulim-cx) >= 0.0) { u = ulim; fu = function.valueAt(u); } else { u = cx+GOLD*(cx-bx); fu = function.valueAt(u); } ax = bx; bx = cx; cx = u; fa = fb; fb = fc; fc = fu; } //System.err.println("\nReturning: a="+ax+" ("+fa+") b="+bx+"("+fb+") c="+cx+" ("+fc+")"); return new Triple(ax,bx,cx); } double dbrent(OneDimDiffFunction function, double ax, double bx, double cx) { // constants boolean dbVerbose = false; int ITMAX = 100; double TOL = 1.0e-4; double ZEPS = 1.0e-20;//1.0e-10; boolean ok1,ok2; double a,b,d=0.0,d1,d2,du,dv,dw,dx,e=0.0; double fu,fv,fw,fx,olde,tol1,tol2,u,u1,u2,v,w,x,xm; a = (ax < cx ? ax : cx); b = (ax > cx ? ax : cx); x=bx; v=bx; w=bx; fx = function.valueAt(x); fv=fx; fw=fx; dx = function.derivativeAt(x); dv=dx; dw=dx; for (int iteration=0; iteration<ITMAX; iteration++) { //System.err.println("dbrent "+iteration+" x "+x+" fx "+fx); xm = 0.5*(a+b); tol1 = TOL*fabs(x);//+ZEPS; tol2 = 2.0*tol1; if (fabs(x-xm) <= (tol2-0.5*(b-a))) { if (dbVerbose) System.err.println("dbrent returning because min is cornered "+a+" ("+function.valueAt(a)+") ~ "+x+" ("+fx+") "+b+" ("+function.valueAt(b)+")"); return x; } if (fabs(e) > tol1) { d1 = 2.0*(b-a); d2 = d1; if (dw != dx) d1 = (w-x)*dx/(dx-dw); if (dv != dx) d2 = (v-x)*dx/(dx-dv); u1 = x+d1; u2 = x+d2;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -