
📄 multivariateestimator.java

An open-source Java toolkit for natural language processing. LingPipe already offers a rich set of features.
/*
 * LingPipe v. 3.5
 * Copyright (C) 2003-2008 Alias-i
 *
 * This program is licensed under the Alias-i Royalty Free License
 * Version 1 WITHOUT ANY WARRANTY, without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the Alias-i
 * Royalty Free License Version 1 for more details.
 *
 * You should have received a copy of the Alias-i Royalty Free License
 * Version 1 along with this program; if not, visit
 * http://alias-i.com/lingpipe/licenses/lingpipe-license-1.txt or contact
 * Alias-i, Inc. at 181 North 11th Street, Suite 401, Brooklyn, NY 11211,
 * +1 (718) 290-9170.
 */

package com.aliasi.stats;

import com.aliasi.util.AbstractExternalizable;

import java.util.ArrayList;
import java.util.HashMap;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;

/**
 * A <code>MultivariateEstimator</code> provides a maximum likelihood
 * estimator of a multivariate distribution based on training samples.
 * Training is carried out by incrementing outcomes through {@link
 * #train(String,long)}.  At any point, the distribution provides a
 * maximum likelihood estimator.
 *
 * <P>Simple additive smoothing can be achieved through the API by
 * initially incrementing counts for all possible outcomes by one.
 *
 * <h3>Compilation and Serialization</h3>
 *
 * <p>Serialization simply stores the current multivariate estimator
 * and reconstructs it exactly as is under deserialization (that is,
 * the class of the deserialized object is
 * <code>MultivariateEstimator</code>).  Compilation stores a more
 * efficient and compact version of the estimator, which deserializes
 * to a <code>MultivariateDistribution</code> rather than a
 * <code>MultivariateEstimator</code>.
 *
 * @author  Bob Carpenter
 * @version 3.1.3
 * @since   LingPipe2.0
 */
public class MultivariateEstimator
    extends MultivariateDistribution
    implements Serializable {

    static final long serialVersionUID = 1171641384366463097L;

    final HashMap mLabelToIndex;
    final ArrayList mIndexToLabel;
    final ArrayList mIndexToCount;
    long mTotalCount = 0l;
    int mNextIndex = 0;

    /**
     * Construct a multivariate estimator with no known outcomes
     * or counts.
     */
    public MultivariateEstimator() {
        this(new HashMap(), new ArrayList(), new ArrayList());
    }

    private MultivariateEstimator(HashMap labelToIndex,
                                  ArrayList indexToLabel,
                                  ArrayList indexToCount) {
        mLabelToIndex = labelToIndex;
        mIndexToLabel = indexToLabel;
        mIndexToCount = indexToCount;
    }

    static void checkLongAddInRange(long a, long b) {
        if (Long.MAX_VALUE - b < a) {
            String msg = "Long addition overflow."
                + " a=" + a
                + " b=" + b;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Resets the count for the specified outcome label to zero.
     * Calling this method will also decrement the total count
     * for this estimator.
     *
     * @param outcomeLabel Label of outcome that is reset.
     * @throws IllegalArgumentException If the outcome label is not
     * known.
     */
    public void resetCount(String outcomeLabel) {
        Integer index = (Integer) mLabelToIndex.get(outcomeLabel);
        if (index == null) {
            String msg = "May only reset known outcomes."
                + " Found outcome=" + outcomeLabel;
            throw new IllegalArgumentException(msg);
        }
        long currentCount
            = ((Long) mIndexToCount.get(index.intValue())).longValue();
        mTotalCount -= currentCount;
        mIndexToCount.set(index.intValue(),
                          new Long(0));
    }

    /**
     * Increment counts in this estimator for the specified
     * outcome by the specified increment.
     *
     * @param outcomeLabel Label of sample outcome.
     * @param increment Amount to increment count for outcome.
     * @throws IllegalArgumentException If the result would
     * be a count higher than the maximum long value or if the
     * increment is less than one.
     */
    public void train(String outcomeLabel, long increment) {
        if (increment < 1) {
            String msg = "Increment must be positive."
                + " Found increment=" + increment;
            throw new IllegalArgumentException(msg);
        }
        mTotalCount += increment;
        Integer indexInteger
            = (Integer) mLabelToIndex.get(outcomeLabel);
        if (indexInteger == null) {
            // First time this label is seen: assign it the next index.
            int index = mNextIndex++;
            mLabelToIndex.put(outcomeLabel,new Integer(index));
            mIndexToLabel.add(index,outcomeLabel);
            mIndexToCount.add(index,new Long(increment));
            return;
        }
        int index = indexInteger.intValue();
        // Read the stored count as a long; intValue() would silently
        // truncate counts above Integer.MAX_VALUE.
        long currentCount = ((Long) mIndexToCount.get(index)).longValue();
        checkLongAddInRange(currentCount,increment);
        mIndexToCount.set(index,new Long(currentCount + increment));
    }

    /**
     * Return the outcome for the specified label, or <code>-1</code>
     * if the label is not known.
     *
     * @param outcomeLabel Label whose outcome is returned.
     * @return The outcome for the specified label.
     */
    public long outcome(String outcomeLabel) {
        Integer outcome = (Integer) mLabelToIndex.get(outcomeLabel);
        return outcome == null ? -1l : outcome.longValue();
    }

    /**
     * Return the label for the specified outcome.
     *
     * @param outcome Outcome whose label is returned.
     * @return The label for the specified outcome.
     * @throws IllegalArgumentException If the outcome is not between
     * zero and the maximum outcome inclusive.
     */
    public String label(long outcome) {
        if (outcome < 0l || outcome >= mNextIndex) {
            String msg = "Outcome must be between 0 and max."
                + " Max outcome=" + maxOutcome()
                + " Argument outcome=" + outcome;
            throw new IllegalArgumentException(msg);
        }
        return mIndexToLabel.get((int)outcome).toString();
    }

    /**
     * Returns the number of dimensions for this multivariate
     * estimator.
     *
     * @return The number of dimensions for this multivariate
     * estimator.
     */
    public int numDimensions() {
        return mIndexToLabel.size();
    }

    /**
     * Returns the multivariate probability estimate for the specified
     * outcome.
     *
     * @param outcome The outcome whose probability is returned.
     * @return The probability of the specified outcome.
     */
    public double probability(long outcome) {
        if (outcome < minOutcome() || outcome > maxOutcome())
            return 0.0;
        return ((double) getCount(outcome))
            / (double) trainingSampleCount();
    }

    /**
     * Returns the count in this estimator for the specified outcome.
     *
     * @param outcome The outcome whose count is returned.
     * @return The count of the specified outcome.
     * @throws IllegalArgumentException If the outcome is not between
     * zero and the maximum outcome inclusive.
     */
    public long getCount(long outcome) {
        checkOutcome(outcome);
        Long count = (Long) mIndexToCount.get((int)outcome);
        return count == null ? 0l : count.longValue();
    }

    /**
     * Returns the count for the specified outcome.
     *
     * @param outcomeLabel Label of specified outcome.
     * @return Count of outcome in this estimator.
     * @throws IllegalArgumentException If the outcome label is not
     * known.
     */
    public long getCount(String outcomeLabel) {
        Integer index = (Integer) mLabelToIndex.get(outcomeLabel);
        if (index == null) {
            String msg = "May only count known outcomes by label."
                + " Found outcome=" + outcomeLabel;
            throw new IllegalArgumentException(msg);
        }
        return getCount(index.longValue());
    }

    /**
     * Returns the total count of training samples.
     *
     * @return The total count for this estimator.
     */
    public long trainingSampleCount() {
        return mTotalCount;
    }

    /**
     * Writes a constant version of this estimator to the specified
     * object output.  The distribution read back in will be an
     * instance of {@link MultivariateConstant} with the same
     * distribution as the estimated distribution.
     *
     * @param objOut The object output to which this estimator is
     * compiled.
     * @throws IOException If there is an I/O exception writing to the
     * output.
     */
    public void compileTo(ObjectOutput objOut) throws IOException {
        objOut.writeObject(new Externalizer(this));
    }

    static class Externalizer extends AbstractExternalizable {
        private static final long serialVersionUID = 2913496935213914118L;
        final MultivariateEstimator mEstimator;
        public Externalizer() { mEstimator = null; }
        public Externalizer(MultivariateEstimator estimator) {
            mEstimator = estimator;
        }
        public void writeExternal(ObjectOutput out)
            throws IOException {
            String[] labels = new String[mEstimator.mIndexToLabel.size()];
            mEstimator.mIndexToLabel.toArray(labels);
            out.writeObject(labels);
            Long[] counts = new Long[mEstimator.mIndexToCount.size()];
            mEstimator.mIndexToCount.toArray(counts);
            // Store count/total ratios rather than raw counts.
            double totalCount = mEstimator.mTotalCount;
            double[] ratios = new double[counts.length];
            for (int i = 0; i < ratios.length; ++i)
                ratios[i] = counts[i].doubleValue()/totalCount;
            out.writeObject(ratios);
        }
        public Object read(ObjectInput in)
            throws ClassNotFoundException, IOException {
            String[] labels = (String[]) in.readObject();
            double[] ratios = (double[]) in.readObject();
            return new MultivariateConstant(ratios,labels);
        }
    }
}
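
The class comment above describes two ideas: a maximum likelihood estimate computed as count / total, and additive smoothing obtained by training every possible outcome once before real data arrives. The following is a minimal stand-alone sketch of both, using only the Java standard library. It is not LingPipe code: the class name CountEstimatorSketch and its methods are hypothetical, and it uses Math.addExact for overflow checking where the listing above uses its own checkLongAddInRange.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical illustration only; not part of LingPipe.
class CountEstimatorSketch {
    // Per-label counts, kept in insertion order.
    private final Map<String, Long> counts = new LinkedHashMap<>();
    private long total = 0L;

    // Adds `increment` observations of `label`.
    void train(String label, long increment) {
        if (increment < 1) {
            throw new IllegalArgumentException(
                "Increment must be positive. Found increment=" + increment);
        }
        // Math.addExact throws ArithmeticException on long overflow.
        // Checking the total first means a failed call changes nothing.
        long updatedTotal = Math.addExact(total, increment);
        long current = counts.getOrDefault(label, 0L);
        // current <= total, so this sum cannot overflow once the total fits.
        counts.put(label, current + increment);
        total = updatedTotal;
    }

    // Maximum likelihood estimate: count(label) / total, or 0.0 if unknown.
    double probability(String label) {
        Long count = counts.get(label);
        if (count == null || total == 0L) {
            return 0.0;
        }
        return count.doubleValue() / (double) total;
    }

    public static void main(String[] args) {
        CountEstimatorSketch est = new CountEstimatorSketch();
        // Additive (add-one) smoothing: seed every possible outcome with 1.
        for (String label : new String[] {"a", "b"}) {
            est.train(label, 1);
        }
        est.train("a", 2);  // two real observations of "a"
        System.out.println(est.probability("a"));  // prints 0.75
        System.out.println(est.probability("b"));  // prints 0.25
    }
}
```

With the add-one seeding, "b" keeps a nonzero probability even though it never appears in the real training data, which is the effect the class comment attributes to simple additive smoothing.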
