📄 multinomial.java
字号:
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */package edu.umass.cs.mallet.base.types;import edu.umass.cs.mallet.base.types.Alphabet;import edu.umass.cs.mallet.base.types.FeatureSequence;import edu.umass.cs.mallet.base.types.FeatureVector;import edu.umass.cs.mallet.base.util.Random;import java.io.Serializable;import java.io.ObjectOutputStream;import java.io.IOException;import java.io.ObjectInputStream;/** * A probability distribution over a set of features represented as a {@link edu.umass.cs.mallet.base.types.FeatureVector}. * The values associated with each element in the Multinomial/FeaturVector are probabilities * and should sum to 1. * Features are indexed using feature indices - the index into the underlying Alphabet - * rather than using locations the way FeatureVectors do. * <p> * {@link edu.umass.cs.mallet.base.types.Multinomial.Estimator} provides a subhierachy * of ways to generate an estimate of the probability distribution from counts associated * with the features. * * @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a> */public class Multinomial extends FeatureVector{ // protected Multinomial () { } // "size" is the number of entries in "probabilities" that have valid values in them; // note that the dictionary (and thus the resulting multinomial) may be bigger than size // if the dictionary is shared with multiple estimators, and the dictionary grew // due to another estimator. private static double[] getValues (double[] probabilities, Alphabet dictionary, int size, boolean copy, boolean checkSum) { double[] values; assert (dictionary == null || dictionary.size() >= size); // No, not necessarily true; see comment above. //assert (dictionary == null || dictionary.size() == size); //assert (probabilities.length == size); // xxx Consider always copying, so that we are assured that we // always have a real probability distribution. if (copy) { values = new double[dictionary==null ? size : dictionary.size()]; System.arraycopy (probabilities, 0, values, 0, size); } else { assert (dictionary == null || dictionary.size() == probabilities.length); values = probabilities; } if (checkSum) { // Check that we have a true probability distribution double sum = 0; for (int i = 0; i < values.length; i++) sum += values[i]; if (Math.abs (sum - 1.0) > 0.9999){ throw new IllegalArgumentException ("Probabilities sum to " + sum + ", not to one."); } } return values; } protected Multinomial (double[] probabilities, Alphabet dictionary, int size, boolean copy, boolean checkSum) { super (dictionary, getValues(probabilities, dictionary, size, copy, checkSum)); } public Multinomial (double[] probabilities, Alphabet dictionary) { this (probabilities, dictionary, dictionary.size(), true, true); } public Multinomial (double[] probabilities, int size) { this (probabilities, null, size, true, true); } public Multinomial (double[] probabilities) { this (probabilities, null, probabilities.length, true, true); } public int size () { return values.length; } public double probability (int featureIndex) { return values[featureIndex]; } public double probability (Object key) { if (dictionary == null) throw new IllegalStateException ("This Multinomial has no dictionary."); return probability (dictionary.lookupIndex (key)); } public double logProbability (int featureIndex) { return Math.log(values[featureIndex]); } public double logProbability (Object key) { if (dictionary == null) throw new IllegalStateException ("This Multinomial has no dictionary."); return logProbability (dictionary.lookupIndex (key)); } public Alphabet getAlphabet () { return dictionary; } public void addProbabilitiesTo (double[] vector) { for (int i = 0; i < values.length; i++) vector[i] += values[i]; } public int randomIndex (Random r) { double f = r.nextUniform(); double sum = 0; int i; for (i = 0; i < values.length; i++) { sum += values[i]; //System.out.print (" sum="+sum); if (sum >= f) break; } //if (sum < f) throw new IllegalStateException //System.out.println ("i = "+i+", f = "+f+", sum = "+sum); assert (sum >= f); return i; } public Object randomObject (Random r) { if (dictionary == null) throw new IllegalStateException ("This Multinomial has no dictionary."); return dictionary.lookupObject (randomIndex (r)); } public FeatureSequence randomFeatureSequence (Random r, int length) { if (! (dictionary instanceof Alphabet)) throw new UnsupportedOperationException ("Multinomial's dictionary much be a Alphabet"); FeatureSequence fs = new FeatureSequence ((Alphabet)dictionary, length); while (length-- > 0) fs.add (randomIndex (r)); return fs; } // "size" is the number of 1.0-weight features in the feature vector public FeatureVector randomFeatureVector (Random r, int size) { return new FeatureVector (randomFeatureSequence (r, size)); } /** A Multinomial in which the values associated with each feature index fi is * Math.log(probability[fi]) instead of probability[fi]. * Logs are used for numerical stability. */ public static class Logged extends Multinomial { public Logged (double[] probabilities, Alphabet dictionary, int size, boolean areLoggedAlready) { super (probabilities, dictionary, size, true, !areLoggedAlready); assert (dictionary == null || dictionary.size() == size); if (!areLoggedAlready) for (int i = 0; i < size; i++) values[i] = Math.log (values[i]); } public Logged (double[] probabilities, Alphabet dictionary, boolean areLoggedAlready) { this (probabilities, dictionary, (dictionary == null ? probabilities.length : dictionary.size()), areLoggedAlready); } public Logged (double[] probabilities, Alphabet dictionary, int size) { this (probabilities, dictionary, size, false); } public Logged (double[] probabilities, Alphabet dictionary) { this (probabilities, dictionary, dictionary.size(), false); } public Logged (Multinomial m) { this (m.values, m.dictionary, false); } public Logged (double[] probabilities) { this (probabilities, null, false); } public double probability (int featureIndex) { return Math.exp (values[featureIndex]); } public double logProbability (int featureIndex) { return values[featureIndex]; } public void addProbabilities (double[] vector) { throw new UnsupportedOperationException ("Not implemented."); } public void addLogProbabilities (double[] vector) { for (int i = 0; i < values.length; i++) vector[i] += values[i]; // if vector is longer than values, act as if values // were extended with values of minus infinity. for (int i=values.length; i<vector.length; i++){ vector[i] = Double.NEGATIVE_INFINITY; } } } // xxx Make this inherit from something like AugmentableDenseFeatureVector /** * A hierarchy of classes used to produce estimates of probabilities, in * the form of a Multinomial, from counts associated with the elements * of an Alphabet. * * Estimator itself contains the machinery for associating and manipulating * counts with elements of an Alphabet, including behaving sanely if the * Alphabet changes size between calls. It does not contain any means * of generating probability estimates; various means of estimating are * provided by subclasses. */ public static abstract class Estimator implements Cloneable, Serializable { Alphabet dictionary; double counts[]; int size; // The number of valid entries in counts[] static final int minCapacity = 16; protected Estimator (double counts[], int size, Alphabet dictionary) { this.counts = counts; this.size = size; this.dictionary = dictionary; } public Estimator (double counts[], Alphabet dictionary) { this (counts, dictionary.size(), dictionary); } public Estimator () { this (new double[minCapacity], 0, null); } public Estimator (int size) { this (new double[size > minCapacity ? size : minCapacity], size, null); } public Estimator (Alphabet dictionary) { this(new double[dictionary.size()], dictionary.size(), dictionary); } public void setAlphabet (Alphabet d) { this.size = d.size(); this.counts = new double[size]; this.dictionary = d; } public int size () { return (dictionary == null ? size : dictionary.size()); } protected void ensureCapacity (int index) { //assert (dictionary == null); // Size is fixed if dictionary present? if (index > size)
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -