multiplecorefclusterer.java

来自「mallet是自然语言处理、机器学习领域的一个开源项目。」· Java 代码 · 共 635 行 · 第 1/2 页

JAVA
635
字号
 /* Copyright (C) 2002 Dept. of Computer Science, Univ. of Massachusetts, Amherst   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).   http://www.cs.umass.edu/~mccallum/mallet   This program toolkit free software; you can redistribute it and/or   modify it under the terms of the GNU General Public License as   published by the Free Software Foundation; either version 2 of the   License, or (at your option) any later version.   This program is distributed in the hope that it will be useful, but   WITHOUT ANY WARRANTY; without even the implied warranty of   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  For more   details see the GNU General Public License and the file README-LEGAL.   You should have received a copy of the GNU General Public License   along with this program; if not, write to the Free Software   Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA   02111-1307, USA. *//**	 @author Aron Culotta */package edu.umass.cs.mallet.projects.seg_plus_coref.coreference;import edu.umass.cs.mallet.projects.seg_plus_coref.clustering.*;import edu.umass.cs.mallet.projects.seg_plus_coref.graphs.*;import salvo.jesus.graph.*;import salvo.jesus.graph.VertexImpl;import edu.umass.cs.mallet.base.types.*;import edu.umass.cs.mallet.base.classify.*;import edu.umass.cs.mallet.base.pipe.*;import edu.umass.cs.mallet.base.pipe.iterator.*;import edu.umass.cs.mallet.base.util.*;import java.util.*;import java.util.logging.*;import java.lang.*;import java.io.*;/** Clusters multiple objects simultaneously (e.g. Authors, venues, * and papers). Trains a separate MaxEnt classifier to weight edges * between objects of the same type. Partitions resulting graph such * that constraints between objects are respected (e.g. clustering two * authors requires their venues be in the same cluster)*/public class MultipleCorefClusterer extends CorefClusterAdv {	private static Logger logger = MalletLogger.getLogger (MultipleCorefClusterer.class.getName());	/** Classifiers which determine the edge weights between two nodes	 * of the same type */	MaxEnt[] classifiers;	/** Pipes to extract features from pairs of nodes */	Pipe[] pipes;	/** Maps a citation type to which index into the instanceList has	 * this type*/	HashMap type2index;	/** Maps a paperVertex to corresponding venueVertex. Enforces that	 * all papers in a cluster (paperVertex) have venues in the same	 * cluster (venueVertex) */	HashMap paperVertex2VenueVertex;		/** Constants for the types of citation nodes */	public static final String PAPER = "paper";	public static final String VENUE = "venue";	public static final String AUTHOR = "author";			public MultipleCorefClusterer (Pipe[] _pipes) {		this.pipes = _pipes;		type2index = new HashMap();		paperVertex2VenueVertex = new HashMap ();	}	/** Train the underlying classifiers with "ilists" as	 * trainingData */	public void train (InstanceList[] ilists) {		setIndices (ilists);		classifiers = new MaxEnt[ilists.length];		for (int i=0; i < ilists.length; i++) {			logger.info("Training Coreference Classifiers["+i+"]: ");			classifiers[i] = trainClassifier (ilists[i]);		}		this.meClassifier = classifiers[0];	}	/** Sets the mapping from citation type to index */	public void setIndices (InstanceList[] ilists) {		for (int i=0; i < ilists.length; i++) 			setIndex (getType (ilists[i]), i);	}	/** Adds to "type2index" hash the mapping type->i */	private void setIndex (String type, int i) {		if (type.equals (this.PAPER))						type2index.put (this.PAPER, new Integer (i));		else if (type.equals (this.VENUE))						type2index.put (this.VENUE, new Integer (i));		else if (type.equals (this.AUTHOR))						type2index.put (this.AUTHOR, new Integer (i));		else			throw new IllegalArgumentException ("Unknown citation type: " + type);			}	/** Citation type in this ilist */	private String getType (InstanceList ilist) {			NodePair mentionPair = (NodePair)ilist.getInstance(0).getSource();						Citation c = (Citation)mentionPair.getObject1();			return getType (c);	}	 	/** Citation type in VertexImpl */	private String getType (Collection c) {		Iterator liter = c.iterator();		String type = null;		while (liter.hasNext()) {			String currType = getType ((Citation)liter.next());			if (type != null && !type.equals(currType))				throw new IllegalArgumentException ("SERIOUS ERROR: Cluster has nodes of type " +																						type + " AND type " + currType);			type = currType;		}		return type;	}	/** Citation type in VertexImpl */	private String getType (VertexImpl v) {		Object o = v.getObject();		List l = null;		if (!(o instanceof List)) {			l = new ArrayList ();			l.add (o);		}		else			l = (List) o;		return getType(l);	}		/** Returns the type of citation of c */	private String getType (Citation c) {		if (c instanceof PaperCitation)			return this.PAPER;		else if (c instanceof VenueCitation)			return this.VENUE;		else if (c instanceof AuthorCitation)			return this.AUTHOR;		else			throw new IllegalArgumentException ("Unknown citation type: " + c.getClass().getName());	}		public void testClassifiers (InstanceList[] ilists) {		if (!typeOrdersMatch (ilists))			throw new IllegalArgumentException ("ilists types in testing  not in same order as in training");		for (int i=0; i < ilists.length; i++)			testClassifier (ilists[i], this.classifiers[i]);	}	/** True if the citation type in ilists[i] equals the citation type	 * seen in ilist[i] during training */	private boolean typeOrdersMatch (InstanceList[] ilists) {		for (int i=0; i < ilists.length; i++) {			Integer t = (Integer) type2index.get (getType (ilists[i]));			if (t == null || !t.equals (new Integer (i)))				return false;		}		return true;	}	/** Returns a list of collections representing the clustering of "ilists" */	public Collection[] clusterMentions (InstanceList[] ilists, List[] mentions,																			 int optimalBest, boolean stochastic) {		if (!typeOrdersMatch (ilists))			throw new IllegalArgumentException ("ilists types in clustering not in same order as in training");					if (classifiers == null)			throw new IllegalStateException ("Must train classifiers before clustering");		if (optimalBest > 0) {			throw new UnsupportedOperationException ("Not yet implemented for nBest clustering");		}		else {			if (fullPartition) {				wgraph = createMultipleTypeGraph (ilists, mentions);				logger.info ("Created Multi-Graph with " + wgraph.getVerticesCount() + " vertices and " +										 wgraph.getEdgesCount() + " edges");				if (type2index.get (this.VENUE) != null &&						type2index.get (this.PAPER) != null)					this.paperVertex2VenueVertex = getPaper2VenueHash (wgraph);				Collection clustering = partitionGraph (wgraph);				return splitClusteringByType (clustering);			}			else if (stochastic) {				throw new UnsupportedOperationException ("Not yet implemented for stochastic clustering");			}			else {				wgraph = createMultipleTypeGraph (ilists, mentions);				logger.info ("Created Multi-Graph with " + wgraph.getVerticesCount() + " vertices and " + wgraph.getEdgesCount() + " edges");				if (type2index.get (this.VENUE) != null && type2index.get(this.PAPER) != null)					this.paperVertex2VenueVertex = getPaper2VenueHash (wgraph);				Collection clustering = typicalClusterPartition (wgraph);				logger.info ("Resulting clustering of all types has " + clustering.size() + " clusters");				Collection[] ret =  splitClusteringByType (clustering);				for (int cint=0; cint < ret.length; cint++) 					logger.info ("clustering of type " + cint + " has " + ret[cint].size() + " clusters.");				return ret;			}		}	}	private Collection[] splitClusteringByType (Collection clustering) {		ArrayList[] ret = new ArrayList[type2index.size()];		for (int i=0; i < ret.length; i++)			ret[i] = new ArrayList();		Iterator iter = clustering.iterator();		System.err.println ("Cluster types: " + type2index);		while (iter.hasNext()) { 			Collection c = (Collection) iter.next();			String type = getType (c);			int index = ((Integer)type2index.get(type)).intValue();			if (index >= type2index.size())				throw new IllegalArgumentException ("index " + index + " greater than number of citation types");			ret[index].add (c);		}		return ret;	}	private List getVenueVertices (WeightedGraph graph) {		Iterator iter = graph.getVerticesIterator ();		ArrayList venueVertices = new ArrayList ();		while (iter.hasNext ()) {			VertexImpl v = (VertexImpl)iter.next();			if (getType(v).equals(this.VENUE))				venueVertices.add (v);		}		return venueVertices;	}		/** Map each paper vertex to its corresponding venue vertex (if one exists) */	private HashMap getPaper2VenueHash (WeightedGraph graph) {		logger.info ("creating paperVertex2VenueVertex hash...");		HashMap hash = new HashMap ();		Iterator iter = graph.getVerticesIterator ();		ArrayList paperVertices = new ArrayList ();		ArrayList venueVertices = new ArrayList ();		while (iter.hasNext ()) {			VertexImpl v = (VertexImpl)iter.next();			if (getType(v).equals(this.PAPER))				paperVertices.add (v);			else if (getType(v).equals(this.VENUE))				venueVertices.add (v);		}		logger.info ("found " + paperVertices.size() + " paper vertices and " +								 venueVertices.size() + " venue vertices"); 		for (int i=0; i < paperVertices.size(); i++) {			VertexImpl v = (VertexImpl)paperVertices.get (i);			Object o = v.getObject();			List l = null;			if (!(o instanceof List)) {				l = new ArrayList ();				l.add (o);			}			else				l = (List) o;			Iterator liter = l.iterator();			VertexImpl venueVertex = null;						while (liter.hasNext()) {				Citation c = (Citation) liter.next();				VertexImpl currVenue = findVenueVertexForPaperCitation (c, venueVertices);				logger.info ("Venue for citation " + c + "\nis\n" + currVenue);				if (venueVertex != null && !currVenue.equals(venueVertex))					throw new IllegalArgumentException ("Coreferent papers have NON-coreferent venues in cluster " + v);				venueVertex = currVenue;			}			if (venueVertex == null)				logger.warning ("Can't find venue vertex for citation " + v);			else				hash.put (v, venueVertex);		}		return hash;	}	/** Given PaperCitation c, find the VenueCitation vertex of c's venue */	private VertexImpl findVenueVertexForPaperCitation (Citation c, List venueVertices) {		if (!getType(c).equals(this.PAPER))			throw new IllegalArgumentException ("Citation has type " + getType(c) + ", not " + this.PAPER);		String venueID = c.getField (Citation.venueID);		if  (venueID == "") // no venue for this paper			return null;		VertexImpl ret = null;		for (int i=0; i < venueVertices.size(); i++) {			VertexImpl v = (VertexImpl)venueVertices.get(i);			Object o = v.getObject();			List l = null;			if (o instanceof Citation) {				l = new ArrayList();				l.add ((Citation) o);			}			else 				l = (List) o;			Iterator iter = l.iterator();			while (iter.hasNext()) {				Citation vc = (Citation) iter.next();				String currVenueID = vc.getField (Citation.venueID);				if (currVenueID == "")					throw new IllegalArgumentException ("VenueCitation has no venueID: " + vc);				if (currVenueID.equals(venueID)) { // found it					logger.info ("Found venue id " + venueID);					l.remove (v);					ret = v;					return v;				}			}					}		if (ret == null)			throw new IllegalArgumentException ("Can't find venue vertex for citation " + c);		return null;	}	

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?