gradientascent.java
来自「mallet是自然语言处理、机器学习领域的一个开源项目。」· Java 代码 · 共 119 行
JAVA
119 行
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

/**
   @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
 */

package edu.umass.cs.mallet.base.maximize;

import edu.umass.cs.mallet.base.util.MalletLogger;
import edu.umass.cs.mallet.base.types.MatrixOps;
import java.util.logging.*;

/**
 * Gradient ascent: maximizes a {@link Maximizable.ByGradient} by repeatedly
 * running a line search along the current gradient direction.
 *
 * <p>Each iteration:
 * <ol>
 *   <li>clamps the gradient's 2-norm to at most {@link #getStpmax()};</li>
 *   <li>hands it to the configured {@link LineMaximizer.ByGradient}, starting
 *       from the step size returned by the previous line search;</li>
 *   <li>stops with "converged" when the relative change in value satisfies
 *       {@code 2|f_new - f_old| <= tolerance * (|f_new| + |f_old| + eps)}.</li>
 * </ol>
 *
 * <p>NOTE(review): the current step size is kept between calls to
 * {@code maximize}; it is only reset by {@link #setInitialStepSize(double)}.
 */
public class GradientAscent implements Maximizer.ByGradient
{
  /** Default initial line-search step size. */
  static final double initialStepSize = 0.2;

  private static final Logger logger = MalletLogger.getLogger(GradientAscent.class.getName());

  // "eps" is a small number to rectify the special case of converging
  // to exactly zero function value
  final double eps = 1.0e-10;

  /** Relative tolerance on the change in value used for the convergence test. */
  double tolerance = 0.001;

  /** Iteration budget used by {@link #maximize(Maximizable.ByGradient)}. */
  int maxIterations = 200;

  /** Line search used to choose how far to move along the gradient. */
  LineMaximizer.ByGradient lineMaximizer = new BackTrackLineSearch();

  /** Upper bound on the 2-norm of the search direction passed to the line search. */
  double stpmax = 100;

  /** Initial step size as configured; reported by {@link #getInitialStepSize()}. */
  private double initialStep = initialStepSize;

  /** Step size passed to the line search; replaced by its result every iteration. */
  double step = initialStepSize;

  /** Only reported (as "scaled") in the per-iteration log message. */
  private double maxStep = 1.0;

  public GradientAscent () { }

  /** Returns the line maximizer used along each gradient direction. */
  public LineMaximizer.ByGradient getLineMaximizer ()
  {
    return lineMaximizer;
  }

  /**
   * Sets the line maximizer used along each gradient direction.
   *
   * @param lineMaximizer the line search to use; must not be null
   * @throws IllegalArgumentException if {@code lineMaximizer} is null
   */
  public void setLineMaximizer (LineMaximizer.ByGradient lineMaximizer)
  {
    if (lineMaximizer == null)
      throw new IllegalArgumentException ("lineMaximizer must not be null");
    this.lineMaximizer = lineMaximizer;
  }

  /**
   * Returns the initial step size most recently set with
   * {@link #setInitialStepSize(double)}, or the default {@code 0.2} if it was
   * never set.
   */
  public double getInitialStepSize ()
  {
    return initialStep;
  }

  /**
   * Sets the initial step size and resets the current line-search step to it.
   *
   * @param initialStepSize step size handed to the next line search
   */
  public void setInitialStepSize (double initialStepSize)
  {
    this.initialStep = initialStepSize;
    this.step = initialStepSize;
  }

  /** Returns the maximum 2-norm allowed for the search direction. */
  public double getStpmax ()
  {
    return stpmax;
  }

  /**
   * Sets the maximum 2-norm allowed for the search direction. Gradients with a
   * larger norm are rescaled down to this norm before the line search.
   */
  public void setStpmax (double stpmax)
  {
    this.stpmax = stpmax;
  }

  /**
   * Maximizes {@code maxable} using the default iteration budget.
   *
   * @param maxable objective to maximize; its parameters are modified in place
   *        by the line search
   * @return true if the convergence test was met, false if the budget ran out
   */
  public boolean maximize (Maximizable.ByGradient maxable)
  {
    return maximize (maxable, maxIterations);
  }

  /**
   * Maximizes {@code maxable} for at most {@code numIterations} iterations.
   *
   * @param maxable objective to maximize; its parameters are modified in place
   *        by the line search
   * @param numIterations maximum number of line searches; if zero or negative,
   *        no iterations are run and false is returned
   * @return true if the relative value change fell below the tolerance,
   *         false if the iteration budget was exhausted first
   */
  public boolean maximize (Maximizable.ByGradient maxable, int numIterations)
  {
    double fp = maxable.getValue ();
    double[] xi = new double [maxable.getNumParameters()];
    maxable.getValueGradient(xi);

    for (int iterations = 0; iterations < numIterations; iterations++) {
      logger.info ("At iteration "+iterations+", cost = "+fp+", scaled = "+maxStep+" step = "+step+", gradient infty-norm = "+MatrixOps.infinityNorm (xi));

      // Ensure step not too large: clamp the direction's 2-norm to stpmax.
      double sum = MatrixOps.twoNorm (xi);
      if (sum > stpmax) {
        logger.info ("*** Step 2-norm "+sum+" greater than max "+stpmax+" Scaling...");
        MatrixOps.timesEquals (xi,stpmax/sum);
      }

      // The line search moves maxable's parameters and returns the step it used,
      // which seeds the next iteration's search.
      step = lineMaximizer.maximize (maxable, xi, step);
      double fret = maxable.getValue ();

      // Relative-change convergence test; eps keeps it meaningful when the
      // value converges to exactly zero.
      if (2.0*Math.abs(fret-fp) <= tolerance*(Math.abs(fret)+Math.abs(fp)+eps)) {
        logger.info ("Gradient Ascent: Value difference "+Math.abs(fret-fp)+" below " +
                     "tolerance; saying converged.");
        return true;
      }
      fp = fret;
      maxable.getValueGradient(xi);
    }
    return false;
  }

  /**
   * Sets the value reported as "scaled" in the per-iteration log message.
   * In this class the value is only logged; it does not limit the step.
   * Use {@link #setStpmax(double)} to bound the search direction.
   */
  public void setMaxStepSize (double v)
  {
    maxStep = v;
  }
}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?