
📄 WeightedFFNeighborhoodInit.java

📁 wekaUT is a collection of semi-supervised learning classifiers built on Weka, developed at the University of Texas at Austin
💻 JAVA
📖 Page 1 of 2
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/**
 *    WeightedFFNeighborhoodInit.java
 *
 *    Initializer that uses weighted farthest first traversal to get
 *    initial clusters for K-Means
 *
 *    Copyright (C) 2004 Sugato Basu, Misha Bilenko
 *
 */

package weka.clusterers.initializers;

import java.io.*;
import java.util.*;
import weka.core.*;
import weka.core.metrics.*;
import weka.clusterers.*;

public class WeightedFFNeighborhoodInit extends MPCKMeansInitializer {

  /** holds the ([instance pair] -> [type of constraint]) mapping,
      where the hashed value stores the type of link but the instance
      pair does not hold the type of constraint - it holds (instanceIdx1,
      instanceIdx2, DONT_CARE_LINK). This is done to make lookup easier
      in future
  */
  protected HashMap m_ConstraintsHash;

  /** stores the ([instanceIdx] -> [ArrayList of constraints])
      mapping, where the arraylist contains the constraints in which
      instanceIdx is involved. Note that the instance pairs stored in
      the Arraylist have the actual link type.
  */
  protected HashMap m_instanceConstraintHash;

  /** holds the points involved in the constraints */
  protected HashSet m_SeedHash;

  /** distance Metric */
  protected LearnableMetric m_metric;

  /** Is the objective function increasing or decreasing? Depends on type
   *  of metric used: for similarity-based metric, increasing, for distance-based - decreasing */
  protected boolean m_objFunDecreasing;

  /** Seedable or not (true by default) */
  protected boolean m_Seedable = true;

  /** Number of clusters in the process */
  protected int m_NumCurrentClusters = 0;

  /**
   * holds the random number generator used in various parts of the code
   */
  protected Random m_RandomNumberGenerator;

  /** temporary variable holding cluster assignments while iterating */
  protected int[] m_ClusterAssignments;

  /** array holding sum of cluster instances */
  Instance[] m_SumOfClusterInstances;

  /** Instances without labels */
  protected Instances m_Instances;

  /** Instances with labels */
  protected Instances m_TotalTrainWithLabels;

  /** adjacency list for random */
  protected HashSet[] m_AdjacencyList;

  /** neighbor list for active learning: points in each cluster neighborhood */
  protected HashSet[] m_NeighborSets;

  /**
   * holds the global centroids
   */
  protected Instance m_GlobalCentroid;

  /**
   * holds the default perturbation value for randomPerturbInit
   */
  protected double m_DefaultPerturb = 0.7;

  protected boolean m_verbose = false;

  /** colors for DFS */
  final int WHITE = 300;
  final int GRAY = 301;
  final int BLACK = 302;

  /** number of neighborhood sets */
  protected int m_numNeighborhoods;

  /** Default constructors */
  public WeightedFFNeighborhoodInit() {
    super();
  }

  /** Initialize with a clusterer */
  public WeightedFFNeighborhoodInit(MPCKMeans clusterer) {
    super(clusterer);
  }

  /** The main method for initializing cluster centroids */
  public Instances initialize() throws Exception {
    System.out.println("Num clusters = " + m_numClusters);
    m_Instances = m_clusterer.getInstances();
    m_TotalTrainWithLabels = m_clusterer.getTotalTrainWithLabels();
    m_ConstraintsHash = m_clusterer.getConstraintsHash();
    m_instanceConstraintHash = m_clusterer.getInstanceConstraintsHash();
    m_SeedHash = m_clusterer.getSeedHash();
    m_Seedable = m_clusterer.getSeedable();
    m_metric = m_clusterer.getMetric();
    m_RandomNumberGenerator = m_clusterer.getRandomNumberGenerator();
    m_objFunDecreasing = m_clusterer.getMetric().isDistanceBased();

    m_NeighborSets = new HashSet[m_Instances.numInstances()];
    m_AdjacencyList = new HashSet[m_Instances.numInstances()];
    m_ClusterAssignments = new int[m_Instances.numInstances()];
    boolean m_isOfflineMetric = m_clusterer.getIsOfflineMetric();
    Instances m_ClusterCentroids = m_clusterer.getClusterCentroids();
    boolean m_useTransitiveConstraints = m_clusterer.getUseTransitiveConstraints();
    boolean m_isSparseInstance =
      (m_Instances.instance(0) instanceof SparseInstance) ? true : false;

    if (m_isSparseInstance) {
      m_SumOfClusterInstances = new SparseInstance[m_Instances.numInstances()];
    } else {
      m_SumOfClusterInstances = new Instance[m_Instances.numInstances()];
    }

    for (int i = 0; i < m_Instances.numInstances(); i++) {
      m_ClusterAssignments[i] = -1;
    }

    if (m_ConstraintsHash != null) {
      Set pointPairs = (Set) m_ConstraintsHash.keySet();
      Iterator pairItr = pointPairs.iterator();
      System.out.println("In non-active init");

      // iterate over the pairs in ConstraintHash
      while (pairItr.hasNext()) {
        InstancePair pair = (InstancePair) pairItr.next();
        int linkType = ((Integer) m_ConstraintsHash.get(pair)).intValue();
        if (m_verbose)
          System.out.println(pair + ": type = " + linkType);
        if (linkType == InstancePair.MUST_LINK) { // mainly concerned with MUST-LINK
          if (m_AdjacencyList[pair.first] == null) {
            m_AdjacencyList[pair.first] = new HashSet();
          }
          if (!m_AdjacencyList[pair.first].contains(new Integer(pair.second))) {
            m_AdjacencyList[pair.first].add(new Integer(pair.second));
          }

          if (m_AdjacencyList[pair.second] == null) {
            m_AdjacencyList[pair.second] = new HashSet();
          }
          if (!m_AdjacencyList[pair.second].contains(new Integer(pair.first))) {
            m_AdjacencyList[pair.second].add(new Integer(pair.first));
          }
        }
      }

      // DFS for finding connected components, updates required stats
      DFS();
    }

    // print out cluster assignments right here!!
    if (m_ConstraintsHash.size() > 0) {
      if (m_metric instanceof BarHillelMetric) {
        System.out.println("Starting building BarHillel metric ...\n\n");
        ((BarHillelMetric) m_metric).buildAttributeMatrix(m_Instances, m_ClusterAssignments);
        System.out.println("Finished building BarHillel metric!!\n\n");
      } else if (m_metric instanceof XingMetric) {
        ((XingMetric) m_metric).buildAttributeMatrix(m_Instances, m_ConstraintsHash);
      } else if (m_metric instanceof BarHillelMetricMatlab) {
        System.out.println("Starting building BarHillelMatlab metric ...\n\n");
        ((BarHillelMetricMatlab) m_metric).buildAttributeMatrix(m_Instances, m_ClusterAssignments);
        System.out.println("Finished building BarHillelMatlab metric!!\n\n");
      }
    }

    if (!m_Seedable) { // don't perform any seeding, initialize from random
      m_NumCurrentClusters = 0;
      System.out.println("Not performing any seeding!");
      for (int i = 0; i < m_Instances.numInstances(); i++) {
        m_ClusterAssignments[i] = -1;
      }
    }

    // if the required number of clusters has been obtained, wrap up
    if (m_NumCurrentClusters >= m_numClusters) {
      { // if (m_verbose) {
        System.out.println("Got the required number of clusters ...");
        System.out.println("num clusters: " + m_numClusters + ", num current clusters: " + m_NumCurrentClusters);
      }
      int clusterSizes[] = new int[m_NumCurrentClusters];
      for (int i = 0; i < m_NumCurrentClusters; i++) {
        if (m_verbose) {
          System.out.println("Neighbor set: " + i + " has size: " + m_NeighborSets[i].size());
        }
        clusterSizes[i] = -m_NeighborSets[i].size(); // for reverse sort
      }
      int[] indices = Utils.sort(clusterSizes);
      System.out.println("Total neighborhoods:  " + m_NumCurrentClusters + ";  Sorted neighborhood sizes:  ");

      // store number of neighborhoods after DFS
      m_numNeighborhoods = m_NumCurrentClusters;
      for (int i = 0; i < m_NumCurrentClusters; i++) {
        System.out.print(m_NeighborSets[indices[i]].size());
        if (m_TotalTrainWithLabels.classIndex() >= 0) {
          System.out.println("(" + m_TotalTrainWithLabels.instance(((Integer) (m_NeighborSets[indices[i]].iterator().next())).intValue()).classValue() + ")\t");
        } else {
          System.out.println();
        }
      }

      Instance[] clusterCentroids = new Instance[m_NumCurrentClusters];

      // Added: Code for better random selection of neighborhoods, using weighted farthest first
      for (int i = 0; i < m_NumCurrentClusters; i++) {
        if (m_isSparseInstance) {
          clusterCentroids[i] = new SparseInstance(m_SumOfClusterInstances[i]);
        } else {
          clusterCentroids[i] = new Instance(m_SumOfClusterInstances[i]);
        }
        clusterCentroids[i].setWeight(m_NeighborSets[i].size()); // setting weight = neighborhood size
        clusterCentroids[i].setDataset(m_Instances);
        if (!m_objFunDecreasing) {
          ClusterUtils.normalize(clusterCentroids[i]);
        } else {
          ClusterUtils.normalizeByWeight(clusterCentroids[i]);
        }
      }

      HashSet selectedNeighborhoods = new HashSet((int) (m_numClusters / 0.75 + 10));
      System.out.println("Initializing " + m_numClusters + " clusters");
      if (m_isOfflineMetric) {
        System.out.println("Offline metric - using random neighborhoods");
        for (int i = 0; i < m_numClusters; i++) {
          int next = m_RandomNumberGenerator.nextInt(m_numNeighborhoods);
          while (selectedNeighborhoods.contains(new Integer(next))) {
            next = m_RandomNumberGenerator.nextInt(m_numNeighborhoods);
          }
          System.out.print("Neighborhood selected:  " + next);
          if (m_TotalTrainWithLabels.classIndex() >= 0) {
            System.out.println("(" + m_TotalTrainWithLabels.instance(((Integer) (m_NeighborSets[next].iterator().next())).intValue()).classValue() + ")\t");
          } else {
            System.out.println();
          }
          selectedNeighborhoods.add(new Integer(next));
        }
      } else {
        System.out.println("Learnable metric - using weighted FF to select neighborhoods");
        selectedNeighborhoods.add(new Integer(indices[0])); // initializing with largest neighborhood
        System.out.print("First neighborhood selected:  " + m_NeighborSets[indices[0]].size());
        if (m_TotalTrainWithLabels.classIndex() >= 0) {
          System.out.println("(" + m_TotalTrainWithLabels.instance(((Integer) (m_NeighborSets[indices[0]].iterator().next())).intValue()).classValue() + ")\t");
        } else {
          System.out.println();
        }
        HashSet selectedNeighborhood = new HashSet();
        System.out.println("Initializing rest by weightedFarthestFromSetOfPoints");
        for (int i = 1; i < m_numClusters; i++) {
          int next = (int) weightedFarthestFromSetOfPoints(clusterCentroids, selectedNeighborhoods, null);
          selectedNeighborhoods.add(new Integer(next));
          System.out.print("Neighborhood selected:  " + m_NeighborSets[next].size());
          if (m_TotalTrainWithLabels.classIndex() >= 0) {
            System.out.println("(" + m_TotalTrainWithLabels.instance(((Integer) (m_NeighborSets[next].iterator().next())).intValue()).classValue() + ")\t");
          } else {
            System.out.println();
          }
        }
      }

      // compute centroids of m_numClusters clusters from selectedNeighborhoods
      m_ClusterCentroids = new Instances(m_Instances, m_numClusters);
      Iterator neighborhoodIter = selectedNeighborhoods.iterator();
      int num = 0; // cluster number
      while (neighborhoodIter.hasNext()) {
        int i = ((Integer) neighborhoodIter.next()).intValue();
        if (m_SumOfClusterInstances[i] != null) {
          if (m_verbose) {
            System.out.println("Normalizing instance " + i);
          }
          if (!m_objFunDecreasing) {
            ClusterUtils.normalize(m_SumOfClusterInstances[i]);
          } else {
            ClusterUtils.normalizeByWeight(m_SumOfClusterInstances[i]);
          }
        }
        Iterator iter = m_NeighborSets[i].iterator();
        while (iter.hasNext()) { // assign points of new cluster
          int instNumber = ((Integer) iter.next()).intValue();
          if (m_verbose) {
            System.out.println("Assigning " + instNumber + " to cluster: " + num);
          }
          m_ClusterAssignments[instNumber] = num;
        }
        m_SumOfClusterInstances[num].setDataset(m_Instances);
        m_ClusterCentroids.add(m_SumOfClusterInstances[i]);
        num++;
      }

      for (int j = 0; j < m_NumCurrentClusters; j++) {
        int i = indices[j];
        if (!selectedNeighborhoods.contains(new Integer(i))) { // not assigned as centroid
          Iterator iter = m_NeighborSets[i].iterator();
          while (iter.hasNext()) {
            int instNumber = ((Integer) iter.next()).intValue();
            if (m_verbose) {
              System.out.println("Assigning " + instNumber + " to cluster -1");
            }
            m_ClusterAssignments[instNumber] = -1;
          }
        }
      }
      m_NumCurrentClusters = m_numClusters;

      // adding other inferred ML and CL links
      if (m_useTransitiveConstraints) {
        addMLAndCLTransitiveClosure(indices);
        System.out.println("Adding constraints by transitive closure");
      } else {
        System.out.println("Not adding constraints by transitive closure");
      }
    } else if (m_NumCurrentClusters < m_numClusters) {
      // make random for rest

      // adding other inferred ML and CL links
      if (m_useTransitiveConstraints) {
        addMLAndCLTransitiveClosure(null);
      }

      System.out.println("Found " + m_NumCurrentClusters + " neighborhoods ...");
      System.out.println("Will have to start " + (m_numClusters - m_NumCurrentClusters) + " clusters at random");

      // compute centroids of m_NumCurrentClusters clusters
      m_ClusterCentroids = new Instances(m_Instances, m_numClusters);
      for (int i = 0; i < m_NumCurrentClusters; i++) {
        if (m_SumOfClusterInstances[i] != null) {
          if (m_verbose) {
            System.out.println("Normalizing cluster center " + i);
          }
          if (!m_objFunDecreasing) {
            ClusterUtils.normalize(m_SumOfClusterInstances[i]);
          } else {
            ClusterUtils.normalizeByWeight(m_SumOfClusterInstances[i]);
          }
        }
        m_SumOfClusterInstances[i].setDataset(m_Instances);
        m_ClusterCentroids.add(m_SumOfClusterInstances[i]);
      }

      // find global centroid
      double[] globalValues = new double[m_Instances.numAttributes()];
      if (m_isSparseInstance) {
        globalValues = ClusterUtils.meanOrMode(m_Instances); // uses fast meanOrMode
      } else {
        for (int j = 0; j < m_Instances.numAttributes(); j++) {
          globalValues[j] = m_Instances.meanOrMode(j); // uses usual meanOrMode
        }
      }

      System.out.println("Done calculating global centroid");

      // global centroid is dense in SPKMeans
      m_GlobalCentroid = new Instance(1.0, globalValues);
      m_GlobalCentroid.setDataset(m_Instances);
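
For orientation, below is a minimal usage sketch of this initializer, based only on what the listing above shows: the WeightedFFNeighborhoodInit(MPCKMeans) constructor and the initialize() method, which returns the seed centroids. The wrapper class InitDemo is hypothetical, the package location of MPCKMeans is inferred from the imports above, and all configuration of the MPCKMeans object (data, metric, constraints, number of clusters) is elided because those setters are not part of this file.

import weka.core.Instances;
import weka.clusterers.MPCKMeans;
import weka.clusterers.initializers.WeightedFFNeighborhoodInit;

public class InitDemo {
  public static void main(String[] args) throws Exception {
    // The data, number of clusters, LearnableMetric and must-link/cannot-link
    // constraints all live on the MPCKMeans clusterer, not on the initializer
    // (see the m_clusterer getters used in initialize() above).
    MPCKMeans clusterer = new MPCKMeans();
    // ... clusterer setup elided; consult the wekaUT MPCKMeans API ...

    // Constructor and initialize() are the ones shown in the listing.
    WeightedFFNeighborhoodInit init = new WeightedFFNeighborhoodInit(clusterer);
    Instances seedCentroids = init.initialize(); // weighted farthest-first seeding
    System.out.println(seedCentroids);
  }
}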
