📄 WeightedFFNeighborhoodInit.java
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/**
 * WeightedFFNeighborhoodInit.java
 *
 * Initializer that uses weighted farthest first traversal to get
 * initial clusters for K-Means
 *
 * Copyright (C) 2004 Sugato Basu, Misha Bilenko
 */

package weka.clusterers.initializers;

import java.io.*;
import java.util.*;

import weka.core.*;
import weka.core.metrics.*;
import weka.clusterers.*;

public class WeightedFFNeighborhoodInit extends MPCKMeansInitializer {

  /** holds the ([instance pair] -> [type of constraint]) mapping, where the
      hashed value stores the type of link but the instance pair does not hold
      the type of constraint - it holds (instanceIdx1, instanceIdx2,
      DONT_CARE_LINK). This is done to make lookup easier in future */
  protected HashMap m_ConstraintsHash;

  /** stores the ([instanceIdx] -> [ArrayList of constraints]) mapping, where
      the arraylist contains the constraints in which instanceIdx is involved.
      Note that the instance pairs stored in the ArrayList have the actual
      link type. */
  protected HashMap m_instanceConstraintHash;

  /** holds the points involved in the constraints */
  protected HashSet m_SeedHash;

  /** distance metric */
  protected LearnableMetric m_metric;

  /** Is the objective function increasing or decreasing? Depends on type
      of metric used: for similarity-based metric, increasing; for
      distance-based - decreasing */
  protected boolean m_objFunDecreasing;

  /** Seedable or not (true by default) */
  protected boolean m_Seedable = true;

  /** Number of clusters in the process */
  protected int m_NumCurrentClusters = 0;

  /** holds the random number generator used in various parts of the code */
  protected Random m_RandomNumberGenerator;

  /** temporary variable holding cluster assignments while iterating */
  protected int[] m_ClusterAssignments;

  /** array holding sum of cluster instances */
  Instance[] m_SumOfClusterInstances;

  /** Instances without labels */
  protected Instances m_Instances;

  /** Instances with labels */
  protected Instances m_TotalTrainWithLabels;

  /** adjacency list for random */
  protected HashSet[] m_AdjacencyList;

  /** neighbor list for active learning: points in each cluster neighborhood */
  protected HashSet[] m_NeighborSets;

  /** holds the global centroid */
  protected Instance m_GlobalCentroid;

  /** holds the default perturbation value for randomPerturbInit */
  protected double m_DefaultPerturb = 0.7;

  protected boolean m_verbose = false;

  /** colors for DFS */
  final int WHITE = 300;
  final int GRAY = 301;
  final int BLACK = 302;

  /** number of neighborhood sets */
  protected int m_numNeighborhoods;

  /** Default constructor */
  public WeightedFFNeighborhoodInit() {
    super();
  }

  /** Initialize with a clusterer */
  public WeightedFFNeighborhoodInit(MPCKMeans clusterer) {
    super(clusterer);
  }

  /** The main method for initializing cluster centroids */
  public Instances initialize() throws Exception {
    System.out.println("Num clusters = " + m_numClusters);
    m_Instances = m_clusterer.getInstances();
    m_TotalTrainWithLabels = m_clusterer.getTotalTrainWithLabels();
    m_ConstraintsHash = m_clusterer.getConstraintsHash();
    m_instanceConstraintHash = m_clusterer.getInstanceConstraintsHash();
    m_SeedHash = m_clusterer.getSeedHash();
    m_Seedable = m_clusterer.getSeedable();
    m_metric = m_clusterer.getMetric();
    m_RandomNumberGenerator = m_clusterer.getRandomNumberGenerator();
    m_objFunDecreasing = m_clusterer.getMetric().isDistanceBased();

    m_NeighborSets = new HashSet[m_Instances.numInstances()];
    m_AdjacencyList = new HashSet[m_Instances.numInstances()];
    m_ClusterAssignments = new int[m_Instances.numInstances()];

    boolean m_isOfflineMetric = m_clusterer.getIsOfflineMetric();
    Instances m_ClusterCentroids = m_clusterer.getClusterCentroids();
    boolean m_useTransitiveConstraints = m_clusterer.getUseTransitiveConstraints();
    boolean m_isSparseInstance =
      (m_Instances.instance(0) instanceof SparseInstance) ? true : false;

    if (m_isSparseInstance) {
      m_SumOfClusterInstances = new SparseInstance[m_Instances.numInstances()];
    } else {
      m_SumOfClusterInstances = new Instance[m_Instances.numInstances()];
    }
    for (int i = 0; i < m_Instances.numInstances(); i++) {
      m_ClusterAssignments[i] = -1;
    }

    if (m_ConstraintsHash != null) {
      Set pointPairs = (Set) m_ConstraintsHash.keySet();
      Iterator pairItr = pointPairs.iterator();
      System.out.println("In non-active init");

      // iterate over the pairs in ConstraintHash
      while (pairItr.hasNext()) {
        InstancePair pair = (InstancePair) pairItr.next();
        int linkType = ((Integer) m_ConstraintsHash.get(pair)).intValue();
        if (m_verbose) System.out.println(pair + ": type = " + linkType);
        if (linkType == InstancePair.MUST_LINK) { // mainly concerned with MUST-LINK
          if (m_AdjacencyList[pair.first] == null) {
            m_AdjacencyList[pair.first] = new HashSet();
          }
          if (!m_AdjacencyList[pair.first].contains(new Integer(pair.second))) {
            m_AdjacencyList[pair.first].add(new Integer(pair.second));
          }
          if (m_AdjacencyList[pair.second] == null) {
            m_AdjacencyList[pair.second] = new HashSet();
          }
          if (!m_AdjacencyList[pair.second].contains(new Integer(pair.first))) {
            m_AdjacencyList[pair.second].add(new Integer(pair.first));
          }
        }
      }

      // DFS for finding connected components, updates required stats
      DFS();
    }

    // print out cluster assignments right here!!
    if (m_ConstraintsHash.size() > 0) {
      if (m_metric instanceof BarHillelMetric) {
        System.out.println("Starting building BarHillel metric ...\n\n");
        ((BarHillelMetric) m_metric).buildAttributeMatrix(m_Instances, m_ClusterAssignments);
        System.out.println("Finished building BarHillel metric!!\n\n");
      } else if (m_metric instanceof XingMetric) {
        ((XingMetric) m_metric).buildAttributeMatrix(m_Instances, m_ConstraintsHash);
      } else if (m_metric instanceof BarHillelMetricMatlab) {
        System.out.println("Starting building BarHillelMatlab metric ...\n\n");
        ((BarHillelMetricMatlab) m_metric).buildAttributeMatrix(m_Instances, m_ClusterAssignments);
        System.out.println("Finished building BarHillelMatlab metric!!\n\n");
      }
    }

    if (!m_Seedable) { // don't perform any seeding, initialize from random
      m_NumCurrentClusters = 0;
      System.out.println("Not performing any seeding!");
      for (int i = 0; i < m_Instances.numInstances(); i++) {
        m_ClusterAssignments[i] = -1;
      }
    }

    // if the required number of clusters has been obtained, wrap up
    if (m_NumCurrentClusters >= m_numClusters) {
      { //if (m_verbose) {
        System.out.println("Got the required number of clusters ...");
        System.out.println("num clusters: " + m_numClusters
                           + ", num current clusters: " + m_NumCurrentClusters);
      }
      int clusterSizes[] = new int[m_NumCurrentClusters];
      for (int i = 0; i < m_NumCurrentClusters; i++) {
        if (m_verbose) {
          System.out.println("Neighbor set: " + i + " has size: " + m_NeighborSets[i].size());
        }
        clusterSizes[i] = -m_NeighborSets[i].size(); // for reverse sort
      }
      int[] indices = Utils.sort(clusterSizes);
      System.out.println("Total neighborhoods: " + m_NumCurrentClusters
                         + "; Sorted neighborhood sizes: ");

      // store number of neighborhoods after DFS
      m_numNeighborhoods = m_NumCurrentClusters;

      for (int i = 0; i < m_NumCurrentClusters; i++) {
        System.out.print(m_NeighborSets[indices[i]].size());
        if (m_TotalTrainWithLabels.classIndex() >= 0) {
          System.out.println("(" + m_TotalTrainWithLabels.instance(
            ((Integer) (m_NeighborSets[indices[i]].iterator().next())).intValue()).classValue()
            + ")\t");
        } else {
          System.out.println();
        }
      }

      Instance[] clusterCentroids = new Instance[m_NumCurrentClusters];

      // Added: Code for better random selection of neighborhoods,
      // using weighted farthest first
      for (int i = 0; i < m_NumCurrentClusters; i++) {
        if (m_isSparseInstance) {
          clusterCentroids[i] = new SparseInstance(m_SumOfClusterInstances[i]);
        } else {
          clusterCentroids[i] = new Instance(m_SumOfClusterInstances[i]);
        }
        clusterCentroids[i].setWeight(m_NeighborSets[i].size()); // setting weight = neighborhood size
        clusterCentroids[i].setDataset(m_Instances);
        if (!m_objFunDecreasing) {
          ClusterUtils.normalize(clusterCentroids[i]);
        } else {
          ClusterUtils.normalizeByWeight(clusterCentroids[i]);
        }
      }

      HashSet selectedNeighborhoods = new HashSet((int) (m_numClusters / 0.75 + 10));
      System.out.println("Initializing " + m_numClusters + " clusters");

      if (m_isOfflineMetric) {
        System.out.println("Offline metric - using random neighborhoods");
        for (int i = 0; i < m_numClusters; i++) {
          int next = m_RandomNumberGenerator.nextInt(m_numNeighborhoods);
          while (selectedNeighborhoods.contains(new Integer(next))) {
            next = m_RandomNumberGenerator.nextInt(m_numNeighborhoods);
          }
          System.out.print("Neighborhood selected: " + next);
          if (m_TotalTrainWithLabels.classIndex() >= 0) {
            System.out.println("(" + m_TotalTrainWithLabels.instance(
              ((Integer) (m_NeighborSets[next].iterator().next())).intValue()).classValue()
              + ")\t");
          } else {
            System.out.println();
          }
          selectedNeighborhoods.add(new Integer(next));
        }
      } else {
        System.out.println("Learnable metric - using weighted FF to select neighborhoods");
        selectedNeighborhoods.add(new Integer(indices[0])); // initializing with largest neighborhood
        System.out.print("First neighborhood selected: " + m_NeighborSets[indices[0]].size());
        if (m_TotalTrainWithLabels.classIndex() >= 0) {
          System.out.println("(" + m_TotalTrainWithLabels.instance(
            ((Integer) (m_NeighborSets[indices[0]].iterator().next())).intValue()).classValue()
            + ")\t");
        } else {
          System.out.println();
        }
        HashSet selectedNeighborhood = new HashSet();
        System.out.println("Initializing rest by weightedFarthestFromSetOfPoints");
        for (int i = 1; i < m_numClusters; i++) {
          int next = (int) weightedFarthestFromSetOfPoints(clusterCentroids,
                                                           selectedNeighborhoods, null);
          selectedNeighborhoods.add(new Integer(next));
          System.out.print("Neighborhood selected: " + m_NeighborSets[next].size());
          if (m_TotalTrainWithLabels.classIndex() >= 0) {
            System.out.println("(" + m_TotalTrainWithLabels.instance(
              ((Integer) (m_NeighborSets[next].iterator().next())).intValue()).classValue()
              + ")\t");
          } else {
            System.out.println();
          }
        }
      }

      // compute centroids of m_numClusters clusters from selectedNeighborhoods
      m_ClusterCentroids = new Instances(m_Instances, m_numClusters);
      Iterator neighborhoodIter = selectedNeighborhoods.iterator();
      int num = 0; // cluster number
      while (neighborhoodIter.hasNext()) {
        int i = ((Integer) neighborhoodIter.next()).intValue();
        if (m_SumOfClusterInstances[i] != null) {
          if (m_verbose) {
            System.out.println("Normalizing instance " + i);
          }
          if (!m_objFunDecreasing) {
            ClusterUtils.normalize(m_SumOfClusterInstances[i]);
          } else {
            ClusterUtils.normalizeByWeight(m_SumOfClusterInstances[i]);
          }
        }
        Iterator iter = m_NeighborSets[i].iterator();
        while (iter.hasNext()) { // assign points of new cluster
          int instNumber = ((Integer) iter.next()).intValue();
          if (m_verbose) {
            System.out.println("Assigning " + instNumber + " to cluster: " + num);
          }
          m_ClusterAssignments[instNumber] = num;
        }
        m_SumOfClusterInstances[num].setDataset(m_Instances);
        m_ClusterCentroids.add(m_SumOfClusterInstances[i]);
        num++;
      }

      for (int j = 0; j < m_NumCurrentClusters; j++) {
        int i = indices[j];
        if (!selectedNeighborhoods.contains(new Integer(i))) { // not assigned as centroid
          Iterator iter = m_NeighborSets[i].iterator();
          while (iter.hasNext()) {
            int instNumber = ((Integer) iter.next()).intValue();
            if (m_verbose) {
              System.out.println("Assigning " + instNumber + " to cluster -1");
            }
            m_ClusterAssignments[instNumber] = -1;
          }
        }
      }

      m_NumCurrentClusters = m_numClusters;

      // adding other inferred ML and CL links
      if (m_useTransitiveConstraints) {
        addMLAndCLTransitiveClosure(indices);
        System.out.println("Adding constraints by transitive closure");
      } else {
        System.out.println("Not adding constraints by transitive closure");
      }
    } else if (m_NumCurrentClusters < m_numClusters) { // make random for rest
      // adding other inferred ML and CL links
      if (m_useTransitiveConstraints) {
        addMLAndCLTransitiveClosure(null);
      }
      System.out.println("Found " + m_NumCurrentClusters + " neighborhoods ...");
      System.out.println("Will have to start " + (m_numClusters - m_NumCurrentClusters)
                         + " clusters at random");

      // compute centroids of m_NumCurrentClusters clusters
      m_ClusterCentroids = new Instances(m_Instances, m_numClusters);
      for (int i = 0; i < m_NumCurrentClusters; i++) {
        if (m_SumOfClusterInstances[i] != null) {
          if (m_verbose) {
            System.out.println("Normalizing cluster center " + i);
          }
          if (!m_objFunDecreasing) {
            ClusterUtils.normalize(m_SumOfClusterInstances[i]);
          } else {
            ClusterUtils.normalizeByWeight(m_SumOfClusterInstances[i]);
          }
        }
        m_SumOfClusterInstances[i].setDataset(m_Instances);
        m_ClusterCentroids.add(m_SumOfClusterInstances[i]);
      }

      // find global centroid
      double[] globalValues = new double[m_Instances.numAttributes()];
      if (m_isSparseInstance) {
        globalValues = ClusterUtils.meanOrMode(m_Instances); // uses fast meanOrMode
      } else {
        for (int j = 0; j < m_Instances.numAttributes(); j++) {
          globalValues[j] = m_Instances.meanOrMode(j); // uses usual meanOrMode
        }
      }
      System.out.println("Done calculating global centroid");

      // global centroid is dense in SPKMeans
      m_GlobalCentroid = new Instance(1.0, globalValues);
      m_GlobalCentroid.setDataset(m_Instances);
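The listing above breaks off inside initialize(); the rest of that method and the helpers DFS(), weightedFarthestFromSetOfPoints() and addMLAndCLTransitiveClosure() are not shown. As a rough, hedged sketch of how the class would be driven: only the constructor and the initialize() signature are taken from the listing, while the prepared MPCKMeans instance and the selection helper below are assumptions and may differ from the actual WekaUT code.

// Hypothetical usage sketch - how the MPCKMeans clusterer is configured
// (data, constraints, metric, number of clusters) is assumed, not shown above.
// MPCKMeans mpckmeans = ...;  // clusterer already set up with instances and constraints
// WeightedFFNeighborhoodInit init = new WeightedFFNeighborhoodInit(mpckmeans);
// Instances initialCentroids = init.initialize();  // neighborhood-seeded centroids

// Illustration only: the general weighted farthest-first idea that
// weightedFarthestFromSetOfPoints (not shown) is assumed to follow - pick the
// unselected neighborhood whose minimum distance to the already-selected ones,
// scaled by its weight (neighborhood size), is largest.
static int weightedFarthestIllustration(double[][] dist, double[] weight,
                                        java.util.Set<Integer> selected) {
  int best = -1;
  double bestScore = -1.0;
  for (int i = 0; i < weight.length; i++) {
    if (selected.contains(i)) continue;
    double minDist = Double.MAX_VALUE;
    for (int j : selected) {             // distance to the closest selected neighborhood
      minDist = Math.min(minDist, dist[i][j]);
    }
    double score = weight[i] * minDist;  // larger neighborhoods are favored
    if (score > bestScore) { bestScore = score; best = i; }
  }
  return best;
}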