⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 gistrainer.java

📁 最大熵模型源代码
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
/////////////////////////////////////////////////////////////////////////////// Copyright (C) 2001 Jason Baldridge and Gann Bierner//// This library is free software; you can redistribute it and/or// modify it under the terms of the GNU Lesser General Public// License as published by the Free Software Foundation; either// version 2.1 of the License, or (at your option) any later version.//// This library is distributed in the hope that it will be useful,// but WITHOUT ANY WARRANTY; without even the implied warranty of// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the// GNU General Public License for more details.//// You should have received a copy of the GNU Lesser General Public// License along with this program; if not, write to the Free Software// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.//////////////////////////////////////////////////////////////////////////////   package opennlp.maxent;import gnu.trove.*;import java.io.*;import java.util.*;import java.util.zip.*;/** * An implementation of Generalized Iterative Scaling.  The reference paper * for this implementation was Adwait Ratnaparkhi's tech report at the * University of Pennsylvania's Institute for Research in Cognitive Science, * and is available at <a href ="ftp://ftp.cis.upenn.edu/pub/ircs/tr/97-08.ps.Z"><code>ftp://ftp.cis.upenn.edu/pub/ircs/tr/97-08.ps.Z</code></a>.  * * @author  Jason Baldridge * @version $Revision: 1.9 $, $Date: 2003/01/08 15:44:47 $ */class GISTrainer {    // This can improve model accuracy, though training will potentially take    // longer and use more memory.  Model size will also be larger.  Initial    // testing indicates improvements for models built on small data sets and    // few outcomes, but performance degradation for those with large data    // sets and lots of outcomes.    
private boolean _simpleSmoothing = false;  // toggled by setSmoothing(boolean)

    // If we are using smoothing, this is used as the "number" of times we
    // want the trainer to imagine that it saw a feature that it actually
    // didn't see.  Defaults to 0.1; can be changed with
    // setSmoothingObservation(double).
    private double _smoothingObservation = 0.1;

    // When true, progress messages about training are sent to STDOUT
    // (set through the GISTrainer(boolean) constructor).
    private boolean printMessages = false;

    private int numTokens;   // # of event tokens (trainModel sets this to contexts.length)
    private int numPreds;    // # of predicates
    private int numOutcomes; // # of outcomes

    // Shared cursor fields.  The Trove callbacks below (updateModifiers,
    // updateParams) have fixed execute(int, double) signatures, so the
    // current event/predicate indices are handed to them through these
    // fields; they must be set before such a callback is invoked.
    private int TID;         // global index variable for Tokens (events); also a loop index in trainModel
    private int PID;         // global index variable for Predicates
    private int OID;         // global index variable for Outcomes

    // a global accumulator for adding up probabilities in an array
    private double SUM;

    // contexts[i] records the predicates (as int ids) seen in event i
    private int[][] contexts;

    // records the outcome (as an int id) of each event; presumably parallel
    // to contexts (it is assigned from DataIndexer.outcomeList) -- confirm
    private int[] outcomes;

    // numTimesEventsSeen[i] records the number of times event i has been
    // seen, paired to int[][] contexts; updateModifiers weights the model
    // distribution by this count
    private int[] numTimesEventsSeen;

    // Stores the String names of the outcomes.  GIS only tracks outcomes as
    // ints, so this array is needed to save the model to disk and thereby
    // let users know what each outcome was in human-understandable terms.
    private String[] outcomeLabels;

    // Stores the String names of the predicates.  GIS only tracks predicates
    // as ints, so this array is needed to save the model to disk and thereby
    // let users know what each predicate was in human-understandable terms.
    private String[] predLabels;

    // Observed expectations, indexed by predicate id and keyed by outcome id.
    // updateParams adds these values to a parameter and subtracts
    // Math.log(modifiers[...]), so they are presumably stored in log space
    // -- confirm where they are filled in.
    private TIntDoubleHashMap[] observedExpects;

    // Current estimated parameter values, indexed by predicate id and keyed
    // by outcome id; updated in place by updateParams.
    private TIntDoubleHashMap[] params;

    // Modifiers of the parameter values, paired to params (predicate id ->
    // outcome id -> value).  updateModifiers accumulates
    // modelDistribution[oid] * numTimesEventsSeen[TID] into them, and
    // updateParams reads their logarithm.
    private TIntDoubleHashMap[] modifiers;

    // a helper array for storing predicate indexes
    private int[] predkeys;

    // GIS constant: the largest number of features (predicates) that fire in
    // any single event; trainModel computes it as the max of contexts[i].length
    private int constant;

    // 1.0 / constant, stored once constant has been determined
    private double constantInverse;

    // the correction parameter of the model
    private double correctionParam;

    // observed expectation of the correction feature
    private double cfObservedExpect;

    // a global accumulator used to compute the amount by which to modify the
    // correction parameter
    private double CFMOD;

    // NOTE(review): used outside this view; presumably a small-value cutoff
    // guarding against near-zero quantities -- confirm at its use site.
    private final double NEAR_ZERO = 0.01;

    // NOTE(review): used outside this view; the name suggests a
    // log-likelihood change threshold for stopping training -- confirm.
    private final double LLThreshold = 0.0001;

    // Stores the output of the current model on a single event during
    // training, indexed by outcome id (see updateModifiers).  This is reset
    // for every event on every iteration.
    double[] modelDistribution;

    // Stores the number of features that fire for each event
    int[] numfeats;

    // initial probability for all outcomes
    double iprob;

    // Trove function that maps any value to 0.0; used to reset every value of
    // a TIntDoubleHashMap (presumably via transformValues -- confirm).
    private TDoubleFunction backToZeros =
        new TDoubleFunction() {
            public double execute(double arg) { return 0.0; }
        };

    // Trove callback that, for the current predicate PID and event TID, sets
    //     modifiers[PID][oid] = arg + modelDistribution[oid] * numTimesEventsSeen[TID]
    // for each visited outcome oid.  arg is the value of the entry being
    // iterated -- presumably modifiers[PID] itself, which makes this an
    // accumulation; confirm at the call site.  Returns true so the Trove
    // iteration continues over all entries.
    private TIntDoubleProcedure updateModifiers =
        new TIntDoubleProcedure() {
            public boolean execute(int oid, double arg) {
                modifiers[PID].put(oid,
                                   arg
                                   + (modelDistribution[oid]
                                      * numTimesEventsSeen[TID]));
                return true;
            }
        };

    // Trove callback that, for the current predicate PID, sets
    //     params[PID][oid] = arg + observedExpects[PID][oid] - Math.log(modifiers[PID][oid])
    // for each visited outcome oid; arg is presumably the current parameter
    // value.  Returns true so the Trove iteration continues over all entries.
    // NOTE(review): a zero modifier makes Math.log return -Infinity, driving
    // the parameter to +Infinity (or NaN); presumably callers guarantee
    // positive modifiers -- confirm.
    private TIntDoubleProcedure updateParams =
        new TIntDoubleProcedure() {
            public boolean execute(int oid, double arg) {
                params[PID].put(oid,
                                arg +(observedExpects[PID].get(oid)
                                      - Math.log(modifiers[PID].get(oid))));
                return true;
            }
        };

    /**
     * Creates a new <code>GISTrainer</code> instance which does
     * not print progress messages about training to STDOUT.
     * Smoothing is off by default (see {@link #setSmoothing(boolean)}).
     */
    GISTrainer() {
        super();
    }

    /**
     * Creates a new <code>GISTrainer</code> instance.
     *
     * @param printMessages sends progress messages about training to
     *                      STDOUT when true; trains silently otherwise.
     */
    GISTrainer(boolean printMessages) {
        this();
        this.printMessages = printMessages;
    }

    /**
     * Sets whether this trainer will use smoothing while training the model.
     * This can improve model accuracy, though training will potentially take
     * longer and use more memory.  Model size will also be larger.
*     * @param smooth true if smoothing is desired, false if not     */    public void setSmoothing (boolean smooth) {	_simpleSmoothing = smooth;    }    /**     * Sets whether this trainer will use smoothing while training the model.     * This can improve model accuracy, though training will potentially take     * longer and use more memory.  Model size will also be larger.     *     * @param timesSeen the "number" of times we want the trainer to imagine     *                  it saw a feature that it actually didn't see     */    public void setSmoothingObservation (double timesSeen) {	_smoothingObservation = timesSeen;    }    /**     * Train a model using the GIS algorithm.     *     * @param eventStream The EventStream holding the data on which this model     *                    will be trained.     * @param iterations  The number of GIS iterations to perform.     * @param cutoff      The number of times a predicate must be seen in order     *                    to be relevant for training.     * @return The newly trained model, which can be used immediately or saved     *         to disk using an opennlp.maxent.io.GISModelWriter object.     */    public GISModel trainModel(EventStream eventStream,                               int iterations,                               int cutoff) {        DataIndexer di = new DataIndexer(eventStream, cutoff);	        /************** Incorporate all of the needed info ******************/        display("Incorporating indexed data for training...  
\n");        contexts = di.contexts;	outcomes = di.outcomeList;        numTimesEventsSeen = di.numTimesEventsSeen;        numTokens = contexts.length;	        //printTable(contexts);        // determine the correction constant and its inverse        constant = contexts[0].length;        for (TID=1; TID<contexts.length; TID++) {            if (contexts[TID].length > constant) {	      constant = contexts[TID].length;            }        }        constantInverse = 1.0/constant;			display("done.\n");

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -