multiplecorefclusterer.java
来自「mallet是自然语言处理、机器学习领域的一个开源项目。」· Java 代码 · 共 635 行 · 第 1/2 页
JAVA
635 行
/* Copyright (C) 2002 Dept. of Computer Science, Univ. of Massachusetts, Amherst This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This program toolkit free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. For more details see the GNU General Public License and the file README-LEGAL. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *//** @author Aron Culotta */package edu.umass.cs.mallet.projects.seg_plus_coref.coreference;import edu.umass.cs.mallet.projects.seg_plus_coref.clustering.*;import edu.umass.cs.mallet.projects.seg_plus_coref.graphs.*;import salvo.jesus.graph.*;import salvo.jesus.graph.VertexImpl;import edu.umass.cs.mallet.base.types.*;import edu.umass.cs.mallet.base.classify.*;import edu.umass.cs.mallet.base.pipe.*;import edu.umass.cs.mallet.base.pipe.iterator.*;import edu.umass.cs.mallet.base.util.*;import java.util.*;import java.util.logging.*;import java.lang.*;import java.io.*;/** Clusters multiple objects simultaneously (e.g. Authors, venues, * and papers). Trains a separate MaxEnt classifier to weight edges * between objects of the same type. Partitions resulting graph such * that constraints between objects are respected (e.g. 
clustering two
 * authors requires their venues be in the same cluster).
 */
public class MultipleCorefClusterer extends CorefClusterAdv {

  private static Logger logger = MalletLogger.getLogger (MultipleCorefClusterer.class.getName());

  /** Classifiers which determine the edge weights between two nodes
   * of the same type; classifiers[i] is trained on the i-th training
   * InstanceList passed to {@link #train}. */
  MaxEnt[] classifiers;

  /** Pipes to extract features from pairs of nodes */
  Pipe[] pipes;

  /** Maps a citation type (PAPER, VENUE or AUTHOR) to the index of the
   * training InstanceList that holds pairs of that type. Rebuilt by
   * {@link #setIndices}. */
  HashMap type2index;

  /** Maps a paperVertex to corresponding venueVertex. Enforces that
   * all papers in a cluster (paperVertex) have venues in the same
   * cluster (venueVertex) */
  HashMap paperVertex2VenueVertex;

  /** Constants for the types of citation nodes */
  public static final String PAPER = "paper";
  public static final String VENUE = "venue";
  public static final String AUTHOR = "author";

  /** Creates an untrained clusterer.
   * @param _pipes pipes used to extract features from pairs of nodes;
   *   the array is stored as-is (not copied) */
  public MultipleCorefClusterer (Pipe[] _pipes) {
    this.pipes = _pipes;
    type2index = new HashMap();
    paperVertex2VenueVertex = new HashMap ();
  }

  /** Trains one classifier per InstanceList and records which index
   * holds which citation type.
   *
   * @param ilists training data, one InstanceList per citation type;
   *   the order fixed here must be reused for testing and clustering
   * @throws IllegalArgumentException if ilists is null or empty, or a
   *   list holds an unknown citation type */
  public void train (InstanceList[] ilists) {
    if (ilists == null || ilists.length == 0)
      throw new IllegalArgumentException ("Need at least one training InstanceList");
    setIndices (ilists);
    classifiers = new MaxEnt[ilists.length];
    for (int i=0; i < ilists.length; i++) {
      logger.info("Training Coreference Classifiers["+i+"]: ");
      classifiers[i] = trainClassifier (ilists[i]);
    }
    // meClassifier is not declared in the visible part of this file
    // (presumably inherited from CorefClusterAdv); it gets the first classifier.
    this.meClassifier = classifiers[0];
  }

  /** Sets the mapping from citation type to index, replacing any mapping
   * left over from an earlier call.
   *
   * @param ilists one InstanceList per citation type; ilists[i] maps to i
   * @throws IllegalArgumentException if a list holds an unknown citation type */
  public void setIndices (InstanceList[] ilists) {
    // Start from an empty map: stale entries from a previous training run
    // would make typeOrdersMatch accept/reject the wrong orderings and make
    // splitClusteringByType allocate slots for types no longer trained.
    type2index.clear ();
    for (int i=0; i < ilists.length; i++)
      setIndex (getType (ilists[i]), i);
  }

  /** Adds to "type2index" the mapping type -&gt; i.
   * @throws IllegalArgumentException if type is not PAPER, VENUE or AUTHOR */
  private void setIndex (String type, int i) {
    if (!PAPER.equals (type) && !VENUE.equals (type) && !AUTHOR.equals (type))
      throw new IllegalArgumentException ("Unknown citation type: " + type);
    type2index.put (type, new Integer (i));
  }

  /** Citation type in
this ilist */ private String getType (InstanceList ilist) { NodePair mentionPair = (NodePair)ilist.getInstance(0).getSource(); Citation c = (Citation)mentionPair.getObject1(); return getType (c); } /** Citation type in VertexImpl */ private String getType (Collection c) { Iterator liter = c.iterator(); String type = null; while (liter.hasNext()) { String currType = getType ((Citation)liter.next()); if (type != null && !type.equals(currType)) throw new IllegalArgumentException ("SERIOUS ERROR: Cluster has nodes of type " + type + " AND type " + currType); type = currType; } return type; } /** Citation type in VertexImpl */ private String getType (VertexImpl v) { Object o = v.getObject(); List l = null; if (!(o instanceof List)) { l = new ArrayList (); l.add (o); } else l = (List) o; return getType(l); } /** Returns the type of citation of c */ private String getType (Citation c) { if (c instanceof PaperCitation) return this.PAPER; else if (c instanceof VenueCitation) return this.VENUE; else if (c instanceof AuthorCitation) return this.AUTHOR; else throw new IllegalArgumentException ("Unknown citation type: " + c.getClass().getName()); } public void testClassifiers (InstanceList[] ilists) { if (!typeOrdersMatch (ilists)) throw new IllegalArgumentException ("ilists types in testing not in same order as in training"); for (int i=0; i < ilists.length; i++) testClassifier (ilists[i], this.classifiers[i]); } /** True if the citation type in ilists[i] equals the citation type * seen in ilist[i] during training */ private boolean typeOrdersMatch (InstanceList[] ilists) { for (int i=0; i < ilists.length; i++) { Integer t = (Integer) type2index.get (getType (ilists[i])); if (t == null || !t.equals (new Integer (i))) return false; } return true; } /** Returns a list of collections representing the clustering of "ilists" */ public Collection[] clusterMentions (InstanceList[] ilists, List[] mentions, int optimalBest, boolean stochastic) { if (!typeOrdersMatch (ilists)) throw new 
IllegalArgumentException ("ilists types in clustering not in same order as in training"); if (classifiers == null) throw new IllegalStateException ("Must train classifiers before clustering"); if (optimalBest > 0) { throw new UnsupportedOperationException ("Not yet implemented for nBest clustering"); } else { if (fullPartition) { wgraph = createMultipleTypeGraph (ilists, mentions); logger.info ("Created Multi-Graph with " + wgraph.getVerticesCount() + " vertices and " + wgraph.getEdgesCount() + " edges"); if (type2index.get (this.VENUE) != null && type2index.get (this.PAPER) != null) this.paperVertex2VenueVertex = getPaper2VenueHash (wgraph); Collection clustering = partitionGraph (wgraph); return splitClusteringByType (clustering); } else if (stochastic) { throw new UnsupportedOperationException ("Not yet implemented for stochastic clustering"); } else { wgraph = createMultipleTypeGraph (ilists, mentions); logger.info ("Created Multi-Graph with " + wgraph.getVerticesCount() + " vertices and " + wgraph.getEdgesCount() + " edges"); if (type2index.get (this.VENUE) != null && type2index.get(this.PAPER) != null) this.paperVertex2VenueVertex = getPaper2VenueHash (wgraph); Collection clustering = typicalClusterPartition (wgraph); logger.info ("Resulting clustering of all types has " + clustering.size() + " clusters"); Collection[] ret = splitClusteringByType (clustering); for (int cint=0; cint < ret.length; cint++) logger.info ("clustering of type " + cint + " has " + ret[cint].size() + " clusters."); return ret; } } } private Collection[] splitClusteringByType (Collection clustering) { ArrayList[] ret = new ArrayList[type2index.size()]; for (int i=0; i < ret.length; i++) ret[i] = new ArrayList(); Iterator iter = clustering.iterator(); System.err.println ("Cluster types: " + type2index); while (iter.hasNext()) { Collection c = (Collection) iter.next(); String type = getType (c); int index = ((Integer)type2index.get(type)).intValue(); if (index >= type2index.size()) throw 
new IllegalArgumentException ("index " + index + " greater than number of citation types"); ret[index].add (c); } return ret; } private List getVenueVertices (WeightedGraph graph) { Iterator iter = graph.getVerticesIterator (); ArrayList venueVertices = new ArrayList (); while (iter.hasNext ()) { VertexImpl v = (VertexImpl)iter.next(); if (getType(v).equals(this.VENUE)) venueVertices.add (v); } return venueVertices; } /** Map each paper vertex to its corresponding venue vertex (if one exists) */ private HashMap getPaper2VenueHash (WeightedGraph graph) { logger.info ("creating paperVertex2VenueVertex hash..."); HashMap hash = new HashMap (); Iterator iter = graph.getVerticesIterator (); ArrayList paperVertices = new ArrayList (); ArrayList venueVertices = new ArrayList (); while (iter.hasNext ()) { VertexImpl v = (VertexImpl)iter.next(); if (getType(v).equals(this.PAPER)) paperVertices.add (v); else if (getType(v).equals(this.VENUE)) venueVertices.add (v); } logger.info ("found " + paperVertices.size() + " paper vertices and " + venueVertices.size() + " venue vertices"); for (int i=0; i < paperVertices.size(); i++) { VertexImpl v = (VertexImpl)paperVertices.get (i); Object o = v.getObject(); List l = null; if (!(o instanceof List)) { l = new ArrayList (); l.add (o); } else l = (List) o; Iterator liter = l.iterator(); VertexImpl venueVertex = null; while (liter.hasNext()) { Citation c = (Citation) liter.next(); VertexImpl currVenue = findVenueVertexForPaperCitation (c, venueVertices); logger.info ("Venue for citation " + c + "\nis\n" + currVenue); if (venueVertex != null && !currVenue.equals(venueVertex)) throw new IllegalArgumentException ("Coreferent papers have NON-coreferent venues in cluster " + v); venueVertex = currVenue; } if (venueVertex == null) logger.warning ("Can't find venue vertex for citation " + v); else hash.put (v, venueVertex); } return hash; } /** Given PaperCitation c, find the VenueCitation vertex of c's venue */ private VertexImpl 
findVenueVertexForPaperCitation (Citation c, List venueVertices) { if (!getType(c).equals(this.PAPER)) throw new IllegalArgumentException ("Citation has type " + getType(c) + ", not " + this.PAPER); String venueID = c.getField (Citation.venueID); if (venueID == "") // no venue for this paper return null; VertexImpl ret = null; for (int i=0; i < venueVertices.size(); i++) { VertexImpl v = (VertexImpl)venueVertices.get(i); Object o = v.getObject(); List l = null; if (o instanceof Citation) { l = new ArrayList(); l.add ((Citation) o); } else l = (List) o; Iterator iter = l.iterator(); while (iter.hasNext()) { Citation vc = (Citation) iter.next(); String currVenueID = vc.getField (Citation.venueID); if (currVenueID == "") throw new IllegalArgumentException ("VenueCitation has no venueID: " + vc); if (currVenueID.equals(venueID)) { // found it logger.info ("Found venue id " + venueID); l.remove (v); ret = v; return v; } } } if (ret == null) throw new IllegalArgumentException ("Can't find venue vertex for citation " + c); return null; }
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?