⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 limitedmemorybfgs.java

📁 这是一个matlab的java实现。里面有许多内容。请大家慢慢捉摸。
💻 JAVA
字号:
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).   http://www.cs.umass.edu/~mccallum/mallet   This software is provided under the terms of the Common Public License,   version 1.0, as published by http://www.opensource.org.  For further   information, see the file `LICENSE' included with this distribution. *//**    @author Aron Culotta <a href="mailto:culotta@cs.umass.edu">culotta@cs.umass.edu</a> */package edu.umass.cs.mallet.base.minimize;import edu.umass.cs.mallet.base.minimize.LineMinimizer;import edu.umass.cs.mallet.base.minimize.Minimizable;import edu.umass.cs.mallet.base.minimize.BackTrackLineSearch;import edu.umass.cs.mallet.base.types.Matrix;import edu.umass.cs.mallet.base.util.MalletLogger;import java.util.logging.*;import java.util.Vector;import java.util.Collection;import java.util.LinkedList;public class LimitedMemoryBFGS implements Minimizer.ByGradient{	final double initialStepSize = 0.01;	final int maxIterations = 300;	// xxx need a more principled stopping point//	final double tolerance = .001;//	final double tolerance = .00001;	final double tolerance = .0001;	final double gradientTolerance = .001;	//	final double eps = 1.0e-10;	final double eps = 1.0e-5;	// The number of corrections used in BFGS update	// ideally 3 <= m <= 7. Larger m means more cpu / memory	final int m = 4;	private static Logger logger = MalletLogger.getLogger(LimitedMemoryBFGS.class.getName());		/**	 * Line search function	 */	LineMinimizer lineMinimizer = new BackTrackLineSearch();	// State of search	Matrix g, oldg, direction, parameters, oldParameters;	// s is list of previous coefficient values	// y is list of previous gradient values	// rho is an intermediate value in calculating new direction	LinkedList s, y, rho;	Vector temp;	double step = 1.0;	int iterations;	public double getGradient()	{		return g.twoNorm();	}	public boolean minimize (Minimizable.ByGradient minable)	{		return minimize (minable, Integer.MAX_VALUE);	}	public boolean minimize (Minimizable.ByGradient minable, int numIterations)	{		double initialCost = minable.getCost();		logger.fine("Entering L-BFGS.minimize(). Initial Cost="+initialCost);				//first time through		if(g==null) {			logger.fine("First time through");			iterations = 0;			s = new LinkedList();			y = new LinkedList();			rho = new LinkedList();			temp = new Vector(m);			oldParameters = minable.getNewMatrix();			g = minable.getNewMatrix();			parameters = minable.getNewMatrix();			minable.getParameters (oldParameters);			oldg = (Matrix) oldParameters.cloneMatrix();			minable.getCostGradient(oldg);			direction = (Matrix) oldg.cloneMatrix(); 			if (direction.absNormalize() == 0) {				logger.info("L-BFGS initial gradient is zero; saying converged");				return true;			}			for(int i=0; i<m; i++)				temp.add(new Double(0.0));			direction.timesEquals(-1.0 * direction.twoNorm());//			direction.timesEquals(-1.0);			// make initial jump			logger.fine ("enter initial line search");			int unsuccessfulLineSearch = (int) lineMinimizer.minimize(minable, direction, step);			logger.fine ("exit initial line search");			minable.getParameters (parameters);			//minimize sets new gradient g						minable.getCostGradient(g);			if (unsuccessfulLineSearch == 1) {				logger.warning ("Line Search exited abnormally. Giving up L-BFGS.");				g = null;				return true;			}		}		for(int iterationCount = 0;	iterationCount < numIterations;	iterationCount++)	{			double cost = minable.getCost();			logger.fine("In L-BFGS, iteration="+iterationCount									+", cost="+cost);			// this returns 0 for component-wise -inf + inf 			oldParameters.equalsPlus(-1.0, parameters); // parameters - oldParameters			push(s, oldParameters);			oldg.equalsPlus(-1.0, g); // g - oldg			push(y, oldg); 			// si * yi			// xxx Q: What will this do when some of the parameters are -Infinity?			// xxx A: gamma and sy = NaN when a parameter is -Inf			// xxx made changes in DenseMatrix to enforce Inf - Inf = 0			double sy = ((Matrix)s.getLast()).dotProduct((Matrix)y.getLast());			//double sy = pdiff.dotProduct(gdiff);			// if didn't jump anywhere last time, sy == 0			// so, make it something small and positive to avoid div by 0//			if (sy <= 0 )//				sy =  eps; 			assert (sy > 0) : "sy: "+sy;			double gamma = sy /			    ( (Matrix) y.getLast()).dotProduct((Matrix)y.getLast());			//double gamma = sy /	gdiff.dotProduct(gdiff);											assert (gamma >= 0) : "Gamma:"+gamma;			push(rho, 1.0/sy);			direction.set(g);			//direction.timesEquals(-1.0);			for(int i = s.size() - 1; i >= 0; i--) {				double rhoTemp =  ((Double)rho.get(i)).doubleValue();				double factor = ((Matrix)s.get(i)).dotProduct(direction);				temp.set(i, new Double(rhoTemp * factor));				Matrix yTemp = (Matrix)y.get(i);				direction.plusEquals(yTemp, (-1.0 * rhoTemp * factor));				//assert (g.dotProduct(direction) > 0) : "alpha i="+i;			}			direction.timesEquals(gamma);			//assert (g.dotProduct(direction) > 0) : "gamma="+gamma;			for(int i =0; i < s.size(); i++) {				double tmp = (((Double)rho.get(i)).doubleValue()) *										 ((Matrix)y.get(i)).dotProduct(direction);				Matrix sTemp = (Matrix)s.get(i);				direction.plusEquals(sTemp,((Double)temp.get(i)).doubleValue() - tmp);				//assert (g.dotProduct(direction) > 0) : "beta i="+i;			}			oldg.set(g);			oldParameters.set(parameters);			direction.timesEquals(-1);			logger.info ("enter line search");			int unsuccessfulLineSearch = (int) lineMinimizer.minimize(minable, direction, step);			logger.info ("exit line search");			if (unsuccessfulLineSearch == 1) {				logger.warning ("Line Search exited abnormally. Giving up L-BFGS.");				return true;			}			//direction.timesEquals(-1);			minable.getParameters (parameters);			minable.getCostGradient(g);			double newCost = minable.getCost();			// Test for terminations			if(2.0*Math.abs(newCost-cost) <= tolerance*(Math.abs(newCost)+Math.abs(cost) + eps)){				logger.info("Exiting L-BFGS on termination #1:\ncost difference below tolerance");				return true;			}//			Matrix xi = minable.getNewMatrix();//			double xiInfinityNorm = xi.infinityNorm();//			// xxx AKM: Here "xi" is always a fresh, zero-filled matrix, so the following would have always been true!//			// Culotta was: if(xiInfinityNorm < tolerance) {//			if(false && xiInfinityNorm < tolerance) {//				logger.info("Exiting L-BFGS on termination #2: \ninfinityNorm="+xiInfinityNorm+" < tolerance="+tolerance);//				return true;//			}			double gg = g.twoNorm();			if(gg < gradientTolerance) {				logger.info("Exiting L-BFGS on termination #2: \ngradient="+gg+" < "+gradientTolerance);				return true;			}																	 			if(gg == 0.0) {				logger.info("Exiting L-BFGS on termination #3: \ngradient==0.0");				return true;			}// xxx seems to give same stopping pt as termination #1//			if((Math.abs(newCost - cost) / newCost) < tolerance) {//				logger.info("Exiting L-BFGS on termination #4:\nnormalized cost difference below tolerance");//				return true;//			}			logger.fine("Gradient = "+gg);			iterations++;			if (iterations > maxIterations) {				System.err.println("Too many iterations in L-BFGS.java. Continuing with current parameters.");				return true;        //throw new IllegalStateException ("Too many iterations.");			}		}		return false;	}		/**	 * Pushes a new object onto the queue l	 * @param l linked list queue of Matrix obj's	 * @param toadd matrix to push onto queue	 */	private void push(LinkedList l, Matrix toadd) {		assert(l.size() <= m);		if(l.size() == m) {			// remove oldest matrix and add newset to end of list.			// to make this more efficient, actually overwrite			// memory of oldest matrix			// this overwrites the oldest matrix			Matrix last = (Matrix) l.get(0);			last.set(toadd);			Object ptr = last;			// this readjusts the pointers in the list			for(int i=0; i<l.size()-1; i++) 				l.set(i, (Matrix)l.get(i+1));						l.set(m-1, ptr);		}// xxx this way was highly inefficient//		if(l.size() == m) { //pop old matrix and add new//			l.removeFirst(); //			l.addLast(toadd.cloneMatrix());//		}		else 			l.addLast(toadd.cloneMatrix());				// print out for debug	}	  /**	 * Pushes a new object onto the queue l	 * @param l linked list queue of Double obj's	 * @param toadd double value to push onto queue	 */	private void push(LinkedList l, double toadd) {		assert(l.size() <= m);		if(l.size() == m) { //pop old double and add new			l.removeFirst(); 			l.addLast(new Double(toadd));		}		else 			l.addLast(new Double(toadd));	}	}		

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -