corefclusteradv.java

来自「mallet是自然语言处理、机器学习领域的一个开源项目。」· Java 代码 · 共 1,911 行 · 第 1/5 页

JAVA
1,911
字号
				PseudoVertex v22 = (PseudoVertex)i2.next();				Set s22 = v22.getCluster();				s22.addAll(s1);				s22.addAll(s2);	    }		}		return newScore;	}	public double computeInitialTreeObjScore (Collection pvertices) {		Collection c1 = (Collection)new ArrayList();		for (Iterator ii = pvertices.iterator(); ii.hasNext(); ) {	    PseudoVertex ppv = (PseudoVertex)ii.next();	    Collection c2 = (Collection)new ArrayList();	    c2.add((Citation)ppv.getObject());	    c1.add(c2);		}		return treeModel.computeTreeObjFn(c1);	}		public Collection absoluteCluster (InstanceList ilist, List mentions) {		List mCopy = new ArrayList();		boolean newCluster;		java.util.Random r = new java.util.Random();		int numTries = 0;		// get pseudo edges and sort em		HashMap vsToPvs = new HashMap();		// vsToPvs set destructively		Collection pvertices = createPseudoVertices (ilist, mentions, vsToPvs);		for (Iterator i1 = pvertices.iterator(); i1.hasNext(); ) {	    PseudoVertex v = (PseudoVertex)i1.next();	    mCopy.add(v);		}		List pedges = createPseudoEdges (ilist, vsToPvs);		Collections.sort(pedges,new PseudoEdgeComparator());		double initialObjVal = computeInitialObjFnVal (pedges);		System.out.println("initial obj fn: " + initialObjVal);		double objFnVal = initialObjVal;		double prevVal = -10000000000.0;		int i = 0;		while (true) {	    prevVal = objFnVal;	    PseudoEdge pedge = (PseudoEdge)pedges.get(i); // choosePseudoEdge (pedges, r);	    PseudoVertex v1 = (PseudoVertex)pedge.getV1();	    PseudoVertex v2 = (PseudoVertex)pedge.getV2();	    Set s1 = new LinkedHashSet();	    Set s2 = new LinkedHashSet();	    // make copies of the sets	    for (Iterator i1 = v1.getCluster().iterator(); i1.hasNext(); ) {				s1.add(i1.next());	    }	    for (Iterator i1 = v2.getCluster().iterator(); i1.hasNext(); ) {				s2.add(i1.next());	    }	    // this is the case for when the edge is now irrelevent through	    // transitive closure - just remove it and start at beginning	    if (s1.contains(v2) || s2.contains(v1)) {				//pedges.remove(i);				continue;	    }	    double[] n = new double[]{0.0};	    objFnVal = updateScore (objFnVal, n , v1, v2, s1, s2, false);	    i++;	    if (objFnVal <= prevVal) {				numTries++;				objFnVal = prevVal; // reset it, since we don't commit to this	    }	    else {				numTries = 0;	    }	    if (numTries > MAX_REDUCTIONS)				break;	    //System.out.println("ObjFnVal: " + objFnVal);		}		// build a proper graph from the edges since		// the evaluation code relies on this structure		this.wgraph = buildGraphFromPseudoEdges (pedges, mentions);		Collection citClustering = new ArrayList();		for (Iterator i1 = pvertices.iterator(); i1.hasNext(); ) {	    PseudoVertex v1 = (PseudoVertex)i1.next();	    Collection cluster = v1.getCluster();	    newCluster = true;	    if (citClustering.size() == 0)				newCluster = true;	    for (Iterator i2 = citClustering.iterator(); i2.hasNext(); ) {				Collection c2 = (Collection)i2.next();				if (c2.containsAll(cluster)) {					newCluster = false;					break;				}	    }	    if (newCluster) {				citClustering.add(cluster);	    }		}		Collection realClustering = new ArrayList();		for (Iterator i1 = citClustering.iterator(); i1.hasNext(); ) {	    Collection s1 = (Collection)i1.next();	    Collection n1 = new ArrayList();	    for (Iterator i2 = s1.iterator(); i2.hasNext(); ) {				n1.add((Citation)((PseudoVertex)i2.next()).getObject());	    }	    realClustering.add (n1);		}		return (Collection)realClustering;	}	protected Collection getPseudoClustering (Collection pvertices) {		Collection citClustering = new ArrayList();		boolean newCluster;		for (Iterator i1 = pvertices.iterator(); i1.hasNext(); ) {	    PseudoVertex v1 = (PseudoVertex)i1.next();	    Collection cluster = v1.getCluster();	    newCluster = true;	    if (citClustering.size() == 0)				newCluster = true;	    for (Iterator i2 = citClustering.iterator(); i2.hasNext(); ) {				Collection c2 = (Collection)i2.next();				if (c2.containsAll(cluster)) {					newCluster = false;					break;				}	    }	    if (newCluster) {				citClustering.add(cluster);	    }		}		return citClustering;	}	public Collection typicalClusterAdv (InstanceList ilist, List mentions) {		List pedgesC = null;		Collection pvertices = null;		for (int j=0; j < MAX_ITERS; j++) {	    List mCopy = new ArrayList();	    java.util.Random r = new java.util.Random();	    int numTries = 0;	    // get pseudo edges and sort em	    HashMap vsToPvs = new HashMap();	    // vsToPvs set destructively	    pvertices = createPseudoVertices (ilist, mentions, vsToPvs);	    for (Iterator i1 = pvertices.iterator(); i1.hasNext(); ) {				PseudoVertex v = (PseudoVertex)i1.next();				mCopy.add(v);	    }	    List pedges = createPseudoEdges (ilist, vsToPvs);	    pedgesC = new ArrayList();	    for (Iterator it = pedges.iterator(); it.hasNext();) {				pedgesC.add((PseudoEdge)it.next());	    }	    // build the actual graph for scoring puposees	    this.wgraph = buildGraphFromPseudoEdges (pedgesC, mentions);	    Collections.sort(pedges,new PseudoEdgeComparator());	    for (int k=0; k < 50; k++) {				PseudoEdge e1 = (PseudoEdge)pedges.get(k);	    }	    double initialObjVal = computeInitialObjFnVal (pedges);	    double initialTreeObjVal;	    if (treeModel != null)				initialTreeObjVal = computeInitialTreeObjScore (pvertices);	    else				initialTreeObjVal = 0.0;				    double[] treeObjVal;	    treeObjVal = new double[]{initialTreeObjVal};	    double objFnVal = initialObjVal;	    double prevVal = -10000000000.0;	    int i = 0;	    int numClusters = pvertices.size();	    while (true) {				prevVal = objFnVal;				int choice = r.nextInt(rBeamSize); // selection size				if (choice > pedges.size())					break;				PseudoEdge pedge = (PseudoEdge)pedges.get(choice); // choosePseudoEdge (pedges, r);				PseudoVertex v1 = (PseudoVertex)pedge.getV1();				PseudoVertex v2 = (PseudoVertex)pedge.getV2();				Set s1 = new LinkedHashSet();				Set s2 = new LinkedHashSet();				// make copies of the sets of vertices represented by each pseudovertex				for (Iterator i1 = v1.getCluster().iterator(); i1.hasNext(); ) {					s1.add(i1.next());				}				for (Iterator i1 = v2.getCluster().iterator(); i1.hasNext(); ) {					s2.add(i1.next());				}				// this is the case for when the edge is now irrelevent through				// transitive closure - just remove it and start at beginning				if (s1.contains(v2) || s2.contains(v1)) {					pedges.remove(choice);					continue;				}				numClusters--;				if (trueNumStop) {					objFnVal = updateScore (objFnVal, treeObjVal, v1, v2, s1, s2, true);					//Collection cl1 = (Collection)getClusteringFromPseudo(pvertices);					//System.out.println("+++++++++++++++++++++++++++");					//double treeV = treeModel.computeTreeObjFn(cl1, false);					//System.out.println("+++++++++++++++++++++++++++");					//System.out.println("treeV: " + treeV + " updated: " + treeObjVal[0]);				}				else					objFnVal = updateScore (objFnVal, treeObjVal, v1, v2, s1, s2, false);				if (!trueNumStop && objFnVal <= prevVal) {					numTries++;					objFnVal = prevVal; // reset it, since we don't commit to this				}				else {					//System.out.println(objFnVal + "," + treeObjVal[0]);					pedges.remove(choice);					numTries = 0;				}				if (trueNumStop && numClusters <= keyPartitioning.size()) {					Collection cl1 = (Collection)getClusteringFromPseudo(pvertices);					PairEvaluate pairEval = new PairEvaluate (keyPartitioning, cl1);					pairEval.evaluate();					double curAgree = evaluatePartitioningAgree (cl1, this.wgraph);					double curDisAgree = evaluatePartitioningDisAgree (cl1, this.wgraph);					/*						double treeV = treeModel.computeTreeObjFn(cl1, false);						if (Math.abs(treeV - treeObjVal[0]) > 0.01)						System.out.println("Tree values don't match: " + treeV + ":" + treeObjVal[0]);*/					int singles = numSingletons(cl1);					System.out.println(objFnVal + "," + treeObjVal[0] + "," + curAgree														 + "," + curDisAgree + "," + singles + "," + pairEval.getF1());					break;				}				else if (numTries > MAX_REDUCTIONS) {					Collection cl1 = (Collection)getClusteringFromPseudo(pvertices);					PairEvaluate pairEval = new PairEvaluate (keyPartitioning, cl1);						pairEval.evaluate();					System.out.println(objFnVal + "," + treeObjVal[0] + "," + pairEval.getF1());					break;				}				//System.out.println("ObjFnVal: " + objFnVal);	    }		}		// build a proper graph from the edges since		// the evaluation code relies on this structure			return (Collection)getClusteringFromPseudo(pvertices);	}	protected int numSingletons (Collection clustering) {		int total = 0;		for (Iterator it = clustering.iterator(); it.hasNext(); ) {	    if (((Collection)it.next()).size() == 1)				total++;		}		return total;	}	protected Collection getClusteringFromPseudo (Collection pvertices) {		Collection citClustering = getPseudoClustering (pvertices);		Collection realClustering = new ArrayList();		for (Iterator i1 = citClustering.iterator(); i1.hasNext(); ) {	    Collection s1 = (Collection)i1.next();	    Collection n1 = new ArrayList();	    for (Iterator i2 = s1.iterator(); i2.hasNext(); ) {				n1.add((Citation)((PseudoVertex)i2.next()).getObject());	    }	    realClustering.add (n1);		}		return (Collection)realClustering;	}	protected WeightedGraph buildGraphFromPseudoEdges (List pedges, List mentions) {		HashMap alreadyAdded = new HashMap();		WeightedGraph w = (WeightedGraph)new WeightedGraphImpl();		for (Iterator it = pedges.iterator(); it.hasNext(); ) {	    constructEdgesFromPseudoEdges (w, (PseudoEdge)it.next(), alreadyAdded);		}		addVerticesToGraph (w, mentions, alreadyAdded);		return w;	}	public Collection typicalClusterPartition (WeightedGraph graph) {		/*			Iterator i0 = ((Set)graph.getVertexSet()).iterator();			while (i0.hasNext()) {			VertexImpl v = (VertexImpl)i0.next();			System.out.println("Map: " + v.getObject() + " -> " +			((Citation)((Node)v.getObject()).getObject()).getBaseString() );			}*/		//System.out.println("Top Graph: " + graph);		while (true) {	    double bestEdgeVal = -100000000;	    WeightedEdge bestEdge = null;	    //System.out.println("Top Graph: " + graph);	    Set edgeSet = graph.getEdgeSet();	    Iterator i1 = edgeSet.iterator();	    // get highest edge value in this loop	    while (i1.hasNext()) {				WeightedEdge e1 = (WeightedEdge)i1.next();				if (e1.getWeight() > bestEdgeVal) {					bestEdgeVal = e1.getWeight();					bestEdge = e1;				}	    }			System.err.println ("bestEdgeVal: " + bestEdgeVal + " threshold: " + threshold);	    if (bestEdgeVal < threshold)				break;	    else {				if (bestEdge != null) {					VertexImpl v1 = (VertexImpl)bestEdge.getVertexA();					VertexImpl v2 = (VertexImpl)bestEdge.getVertexB();					/*						System.out.println("Best edge val: " + bestEdgeVal);						System.out.println("Merging vertices: " + v1.getObject()						+ " and " + v2.getObject()); */					mergeVertices(graph, v1, v2);				}	    }		}		System.out.println("Final graph now has " + graph.getVertexSet().size() + " nodes");		return getCollectionOfOriginalObjects ((Collection)graph.getVertexSet());	}	public Collection partitionGraph (WeightedGraph origGraph) {		java.util.Random rand = new java.util.Random();		double bestCost = -100000000000.0;		double curCost = bestCost;		Collection bestPartitioning = null;		// evalFreq is the frequency with which evaluations occur		// in the early stages, it is silly to keep doing a complete		// evaluation of the objective fn				for (int i=0; i < MAX_ITERS; i++) {	    //System.out.println("Iteration " + i);	    double cost = -100000000.0;	    double bCost = cost;	    int evalFreq = 10;	    // this is a counter that will increment each time	    // a new edge is tried and the result is a graph	    // with a reduced total objective value	    int numReductions = 0;			double treeCost = 0.0;	    int iter = 0;	    Collection curPartitioning = null;	    Collection localBestPartitioning = null;	    WeightedGraph graph = copyGraph(origGraph);	    WeightedGraph graph1 = copyGraph(graph);	    while (true) {				Collection c0 = (Collection)graph.getEdgeSet();				List sortedEdges = Collections.list(Collections.enumeration(c0));				System.out.println("Size of sorted edges: " + sortedEdges.size());				if (sortedEdges.size() > 0) {					EdgeComparator comp = new EdgeComparator();					Collections.sort(sortedEdges, comp);					double minVal = ((WeightedEdge)sortedEdges.get(sortedEdges.size()-1)).getWeight();					double totalVal = 0.0;					Iterator il = (Iterator)sortedEdges.iterator();					while (il.hasNext()) {						totalVal += ((WeightedEdge)il.next()).getWeight();					}					totalVal += sortedEdges.size()*(-minVal);

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?