📄 multinomial.java
字号:
size = index;
			if (counts.length <= index) {
				int newLength = ((counts.length < minCapacity) ? minCapacity : counts.length);
				while (newLength <= index)
					newLength *= 2;
				double[] newCounts = new double[newLength];
				System.arraycopy (counts, 0, newCounts, 0, counts.length);
				this.counts = newCounts;
			}
		}

		/**
		 * Sets every entry of the counts array to zero.
		 */
		// xxx Note that this does not reset the "size"!
		public void reset ()
		{
			for (int i = 0; i < counts.length; i++)
				counts[i] = 0;
		}

		/**
		 * Replaces the counts array with the given array. The array is
		 * stored by reference, not copied.
		 */
		// xxx Remove this method?
		private void setCounts (double counts[])
		{
			assert (dictionary == null || counts.length <= size());
			// xxx Copy instead?
			// xxx Set size() to match counts.length?
			this.counts = counts;
		}

		/**
		 * Adds <code>count</code> to the entry at <code>index</code>,
		 * growing the counts array first if necessary, and raises
		 * <code>size</code> to <code>index + 1</code> if it was smaller.
		 */
		public void increment (int index, double count)
		{
			ensureCapacity (index);
			counts[index] += count;
			if (size < index + 1)
				size = index + 1;
		}

		/**
		 * Adds <code>count</code> to the entry whose index the dictionary
		 * reports for <code>key</code>.
		 */
		public void increment (String key, double count)
		{
			increment (dictionary.lookupIndex (key), count);
		}

		// xxx Add "public void increment (Object key, double count)", or is it too dangerous?

		/**
		 * Adds <code>scale</code> to the count of the feature found at each
		 * position of the sequence; a feature that occurs at several
		 * positions is incremented once per occurrence.
		 *
		 * @throws IllegalArgumentException if the sequence's alphabet is not
		 *         the same object as this estimator's dictionary
		 */
		public void increment (FeatureSequence fs, double scale)
		{
			if (fs.getAlphabet() != dictionary)
				throw new IllegalArgumentException ("Vocabularies don't match.");
			for (int fsi = 0; fsi < fs.size(); fsi++)
				increment (fs.getIndexAtPosition(fsi), scale);
		}

		/** Equivalent to <code>increment (fs, 1.0)</code>. */
		public void increment (FeatureSequence fs)
		{
			increment (fs, 1.0);
		}

		/**
		 * Adds, for every location of the vector, <code>scale</code> times
		 * the value stored at that location to the count of the feature at
		 * that location.
		 *
		 * @throws IllegalArgumentException if the vector's alphabet is not
		 *         the same object as this estimator's dictionary
		 */
		public void increment (FeatureVector fv, double scale)
		{
			if (fv.getAlphabet() != dictionary)
				throw new IllegalArgumentException ("Vocabularies don't match.");
			for (int fvi = 0; fvi < fv.numLocations(); fvi++) {
				// Weight by the stored feature value. Adding a bare "scale"
				// here would count every present feature exactly once, no
				// matter how large its value in the vector is.
				increment (fv.indexAtLocation(fvi), scale * fv.valueAtLocation(fvi));
			}
		}

		/** Equivalent to <code>increment (fv, 1.0)</code>. */
		public void increment (FeatureVector fv)
		{
			increment (fv, 1.0);
		}

		/** Returns the raw accumulated count stored at <code>index</code>. */
		public double getCount (int index)
		{
			return counts[index];
		}

		/**
		 * Returns a copy of this estimator that owns its own counts array,
		 * so incrementing the copy does not change this estimator (and vice
		 * versa). The dictionary reference is shared with the copy.
		 *
		 * @return the copy, or <code>null</code> if cloning is not supported
		 */
		public Object clone ()
		{
			try {
				Estimator copy = (Estimator) super.clone ();
				// Object.clone() copies only the array reference; give the
				// copy a private array so the two estimators stop aliasing.
				if (counts != null)
					copy.counts = (double[]) counts.clone ();
				return copy;
			} catch (CloneNotSupportedException e) {
				return null;
			}
		}

		/**
		 * Prints a header line followed by the first <code>size</code>
		 * counts, one per line, to standard output.
		 */
		public void print ()
		{
			//if (counts != null) throw new IllegalStateException ("Foo");
			System.out.println ("Multinomial.Estimator");
			for (int i = 0; i < size; i++)
				System.out.println ("counts["+i+"] = " + counts[i]);
		}

		/** Builds a Multinomial from the counts accumulated so far. */
		public abstract Multinomial estimate ();

		// Serialization
		// serialVersionUID is overridden to prevent innocuous changes in this
		// class from making the serialization mechanism think the external
		// format has changed.

		private static final long serialVersionUID = 1;
		private static final int CURRENT_SERIAL_VERSION = 1;

		/** Writes the format version, dictionary, counts array and size, in that order. */
		private void writeObject(ObjectOutputStream out) throws IOException
		{
			out.writeInt(CURRENT_SERIAL_VERSION);
			out.writeObject(dictionary);
			out.writeObject(counts);
			out.writeInt(size);
		}

		/**
		 * Reads the fields back in the order written by writeObject.
		 *
		 * @throws ClassNotFoundException if the stored format version differs
		 *         from CURRENT_SERIAL_VERSION
		 */
		private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException
		{
			int version = in.readInt();
			if (version != CURRENT_SERIAL_VERSION)
				throw new ClassNotFoundException("Mismatched Multinomial.Estimator versions: wanted " +
				                                 CURRENT_SERIAL_VERSION + ", got " +
				                                 version);
			dictionary = (Alphabet) in.readObject();
			counts = (double []) in.readObject();
			size = in.readInt();
		}

	}  // class Estimator

	/**
	 * An Estimator in which probability estimates in a Multinomial
	 * are generated by adding a constant m (specified at construction time)
	 * to each count before dividing by the total of the m-biased counts.
	 */
	public static class MEstimator extends Estimator
	{
		double m;

		public MEstimator (Alphabet dictionary, double m)
		{
			super (dictionary);
			this.m = m;
		}

		public MEstimator (int size, double m)
		{
			super(size);
			this.m = m;
		}

		public MEstimator (double m)
		{
			super();
			this.m = m;
		}

		/**
		 * Computes <code>(counts[i] + m) / sum</code> for every entry, where
		 * <code>sum</code> is the total of all m-biased counts. The number of
		 * entries is the dictionary's size when a dictionary is present and
		 * <code>size</code> otherwise. If the m-biased counts sum to zero the
		 * division produces NaN entries.
		 */
		public Multinomial estimate ()
		{
			double[] pr = new double[dictionary==null ? size : dictionary.size()];
			if (dictionary != null){
				ensureCapacity(dictionary.size() -1 );   //side effect: updates size member
			}
			double sum = 0;
			for (int i = 0; i < pr.length; i++) {
				pr[i] = counts[i] + m;
				sum += pr[i];
			}
			for (int i = 0; i < pr.length; i++)
				pr[i] /= sum;
			// NOTE(review): the meaning of the two boolean flags is defined by
			// the Multinomial constructor, which is not visible here.
			return new Multinomial (pr, dictionary, size, false, false);
		}

		private static final long serialVersionUID = 1;
		private static final int CURRENT_SERIAL_VERSION = 1;

		/** Writes the format version followed by m. */
		private void writeObject(ObjectOutputStream out) throws IOException
		{
			out.writeInt(CURRENT_SERIAL_VERSION);
			out.writeDouble(m);
		}

		/**
		 * Reads the format version and m.
		 *
		 * @throws ClassNotFoundException if the stored format version differs
		 *         from CURRENT_SERIAL_VERSION
		 */
		private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException
		{
			int version = in.readInt();
			if (version != CURRENT_SERIAL_VERSION)
				throw new ClassNotFoundException("Mismatched Multinomial.MEstimator versions: wanted " +
				                                 CURRENT_SERIAL_VERSION + ", got " +
				                                 version);
			m = in.readDouble();
		}

	}  // end MEstimator

	/**
	 * An MEstimator with m set to 0. The probability estimates in the Multinomial
	 * are generated by dividing each count by the sum of all counts.
	 */
	public static class MLEstimator extends MEstimator
	{
		public MLEstimator ()
		{
			super (0);
		}

		public MLEstimator (int size)
		{
			super (size, 0);
		}

		public MLEstimator (Alphabet dictionary)
		{
			super (dictionary, 0);
		}

		private static final long serialVersionUID = 1;
		private static final int CURRENT_SERIAL_VERSION = 1;

		/** Writes only the format version; this class adds no state of its own. */
		private void writeObject(ObjectOutputStream out) throws IOException
		{
			out.writeInt(CURRENT_SERIAL_VERSION);
		}

		/**
		 * Reads and checks the format version.
		 *
		 * @throws ClassNotFoundException if the stored format version differs
		 *         from CURRENT_SERIAL_VERSION
		 */
		private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException
		{
			int version = in.readInt();
			if (version != CURRENT_SERIAL_VERSION)
				throw new ClassNotFoundException("Mismatched Multinomial.MLEstimator versions: wanted " +
				                                 CURRENT_SERIAL_VERSION + ", got " +
				                                 version);
		}

	}  // class MLEstimator

	/**
	 * An MEstimator with m set to 1. The probability estimates in the Multinomial
	 * are generated by adding 1 to each count and then dividing each
	 * 1-biased count by the sum of all 1-biased counts.
	 */
	public static class LaplaceEstimator extends MEstimator
	{
		public LaplaceEstimator ()
		{
			super (1);
		}

		public LaplaceEstimator (int size)
		{
			super (size, 1);
		}

		public LaplaceEstimator (Alphabet dictionary)
		{
			super (dictionary, 1);
		}

		private static final long serialVersionUID = 1;
		private static final int CURRENT_SERIAL_VERSION = 1;

		/** Writes only the format version; this class adds no state of its own. */
		private void writeObject(ObjectOutputStream out) throws IOException
		{
			out.writeInt(CURRENT_SERIAL_VERSION);
		}

		/**
		 * Reads and checks the format version.
		 *
		 * @throws ClassNotFoundException if the stored format version differs
		 *         from CURRENT_SERIAL_VERSION
		 */
		private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException
		{
			int version = in.readInt();
			if (version != CURRENT_SERIAL_VERSION)
				throw new ClassNotFoundException("Mismatched Multinomial.LaplaceEstimator versions: wanted " +
				                                 CURRENT_SERIAL_VERSION + ", got " +
				                                 version);
		}

	}  // class Multinomial.LaplaceEstimator

	//todo: Lazy, lazy lazy.  Make this serializable, too.
	/**
	 * Unimplemented, but the MEstimators are.
	 * <code>estimate()</code> currently returns <code>null</code>, and the
	 * serialization hooks below write only a version number, so
	 * <code>prior</code> is not stored in the stream.
	 */
	public static class MAPEstimator extends Estimator
	{
		Dirichlet prior;

		public MAPEstimator (Dirichlet d)
		{
			super (d.size());
			prior = d;
		}

		/** Not implemented yet; always returns <code>null</code>. */
		public Multinomial estimate ()
		{
			// xxx unfinished.
			return null;
		}

		private static final long serialVersionUID = 1;
		private static final int CURRENT_SERIAL_VERSION = 1;

		/** Writes only the format version; <code>prior</code> is not written. */
		private void writeObject(ObjectOutputStream out) throws IOException
		{
			out.writeInt(CURRENT_SERIAL_VERSION);
		}

		/**
		 * Reads and checks the format version; <code>prior</code> is not restored.
		 *
		 * @throws ClassNotFoundException if the stored format version differs
		 *         from CURRENT_SERIAL_VERSION
		 */
		private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException
		{
			int version = in.readInt();
			if (version != CURRENT_SERIAL_VERSION)
				throw new ClassNotFoundException("Mismatched Multinomial.MAPEstimator versions: wanted " +
				                                 CURRENT_SERIAL_VERSION + ", got " +
				                                 version);
		}

	}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -