SeededKMeans.java

来自「wekaUT」— University of Texas at Austin 开发的基于 Weka 的半监督聚类工具包 · Java 代码 · 共 2,032 行 · 第 1/5 页

JAVA
2,032
字号
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    SeededKMeans.java
 *    Copyright (C) 2002 Sugato Basu
 */

package weka.clusterers;

import java.io.*;
import java.util.*;
import weka.core.*;
import weka.core.metrics.*;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Seeded k means clustering class.
 *
 * Valid options are:<p>
 *
 * -N <number of clusters> <br>
 * Specify the number of clusters to generate. <p>
 *
 * -R <random seed> <br>
 * Specify random number seed <p>
 *
 * -S <seeding method> <br>
 * The seeding method can be "seeded" (seeded KMeans) or "constrained" (constrained KMeans)
 *
 * -A <algorithm> <br>
 * The algorithm can be "simple" (simple KMeans) or "spherical" (spherical KMeans)
 *
 * -M <metric-class> <br>
 * Specifies the name of the distance metric class that should be used
 *
 * @author Sugato Basu(sugato@cs.utexas.edu)
 * @see Clusterer
 * @see OptionHandler
 */
public class SeededKMeans extends Clusterer
  implements OptionHandler, SemiSupClusterer, ActiveLearningClusterer {

  /** Name of clusterer */
  String m_name = "SeededKMeans";

  /** holds the clusters */
  protected ArrayList m_FinalClusters = null;

  /** holds the instance indices in the clusters */
  protected ArrayList m_IndexClusters = null;

  /** holds the ([seed instance] -> [clusterLabel of seed instance]) mapping */
  protected HashMap m_SeedHash = null;

  /** distance Metric */
  protected Metric m_metric = new WeightedDotP();

  /** has the metric been constructed?  a fix for multiple buildClusterer's */
  protected boolean m_metricBuilt = false;

  /** starting index of test data in unlabeledData if transductive clustering */
  protected int m_StartingIndexOfTest = -1;

  /** indicates whether instances are sparse */
  protected boolean isSparseInstance = false;

  /** Is the objective function increasing or decreasing?  Depends on type
   * of metric used:  for similarity-based metric, increasing, for distance-based - decreasing */
  protected boolean m_objFunDecreasing = false;

  /** Name of metric */
  protected String m_metricName = "WeightedDotP"; // fixed: was new String("WeightedDotP"), a redundant allocation

  /** Points that are to be skipped in the clustering process
   * because they are collapsed to zero */
  protected HashSet m_skipHash = new HashSet();

  /** Index of the current element in the E-step */
  protected int m_currIdx = 0;

  /** keep track of the number of iterations completed before convergence */
  protected int m_Iterations = 0;

  /* Define possible seeding methods */
  public static final int SEEDING_CONSTRAINED = 1;
  public static final int SEEDING_SEEDED = 2;
  public static final Tag [] TAGS_SEEDING = {
    new Tag(SEEDING_CONSTRAINED, "Constrained seeding"),
    new Tag(SEEDING_SEEDED, "Initial seeding only")
      };

  /** seeding method, by default seeded */
  protected int m_SeedingMethod = SEEDING_SEEDED;

  /** Define possible algorithms */
  public static final int ALGORITHM_SIMPLE = 1;
  public static final int ALGORITHM_SPHERICAL = 2;
  public static final Tag[] TAGS_ALGORITHM = {
    new Tag(ALGORITHM_SIMPLE, "Simple K-Means"),
    new Tag(ALGORITHM_SPHERICAL, "Spherical K-Means")
      };

  /** algorithm, by default spherical */
  protected int m_Algorithm = ALGORITHM_SPHERICAL;

  /** min difference of objective function values for convergence */
  protected double m_ObjFunConvergenceDifference = 1e-5;

  /** value of objective function */
  protected double m_Objective = Integer.MAX_VALUE;

  /** returns objective function */
  public double objectiveFunction() {
    return m_Objective;
  }

  /** Verbose? */
  protected boolean m_Verbose = false;

  /**
   * training instances with labels
   */
  protected Instances m_TotalTrainWithLabels;

  /**
   * training instances
   */
  protected Instances m_Instances;

  /**
   * number of clusters to generate, default is 3
   */
  protected int m_NumClusters = 3;

  /**
   * m_FastMode = true  => fast computation of meanOrMode in centroid calculation, useful for high-D data sets
   * m_FastMode = false => usual computation of meanOrMode in centroid calculation
   */
  protected boolean m_FastMode = true;

  /**
   * holds the cluster centroids
   */
  protected Instances m_ClusterCentroids;

  /**
   * holds the global centroids
   */
  protected Instance m_GlobalCentroid;

  /**
   * holds the default perturbation value for randomPerturbInit
   */
  protected double m_DefaultPerturb = 0.7;

  /** weight of the concentration */
  protected double m_Concentration = 10.0;

  /** number of extra phase1 runs */
  protected double m_ExtraPhase1RunFraction = 50;

  /**
   * temporary variable holding cluster assignments while iterating
   */
  protected int [] m_ClusterAssignments;

  /**
   * holds the random Seed, useful for randomPerturbInit
   */
  protected int m_randomSeed = 1;

  /** semisupervision */
  protected boolean m_Seedable = true;

  /* Constructor */
  public SeededKMeans() {
  }

  /* Constructor */
  public SeededKMeans(Metric metric) {
    m_metric = metric;
    m_metricName = m_metric.getClass().getName();
    m_objFunDecreasing = metric.isDistanceBased();
  }

  /**
   * We always want to implement SemiSupClusterer from a class extending Clusterer.
   * We want to be able to return the underlying parent class.
   * @return parent Clusterer class
   */
  public Clusterer getThisClusterer() {
    return this;
  }

  /**
   * Cluster given instances to form the specified number of clusters.
   *
   * @param data instances to be clustered
   * @param num_clusters number of clusters to create
   * @exception Exception if something goes wrong.
*/  public void buildClusterer(Instances data, int num_clusters) throws Exception {    setNumClusters(num_clusters);    if (m_Algorithm == ALGORITHM_SPHERICAL && m_metric instanceof WeightedDotP) {      ((WeightedDotP)m_metric).setLengthNormalized(false); // since instances and clusters are already normalized, we don't need to normalize again while computing similarity - saves time    }    if (data.instance(0) instanceof SparseInstance) {      isSparseInstance = true;    }        buildClusterer(data);  }  /**   * Clusters unlabeledData and labeledData (with labels removed),   * using labeledData as seeds   *   * @param labeledData labeled instances to be used as seeds   * @param unlabeledData unlabeled instances   * @param classIndex attribute index in labeledData which holds class info   * @param numClusters number of clusters   * @param startingIndexOfTest from where test data starts in unlabeledData, useful if clustering is transductive   * @exception Exception if something goes wrong.  */  public void buildClusterer(Instances labeledData, Instances unlabeledData, int classIndex, int numClusters, int startingIndexOfTest) throws Exception {    m_StartingIndexOfTest = startingIndexOfTest;    buildClusterer(labeledData, unlabeledData, classIndex, numClusters);  }  /**   * Clusters unlabeledData and labeledData (with labels removed),   * using labeledData as seeds   *   * @param labeledData labeled instances to be used as seeds   * @param unlabeledData unlabeled instances   * @param classIndex attribute index in labeledData which holds class info   * @param numClusters number of clusters   * @param startingIndexOfTest from where test data starts in unlabeledData, useful if clustering is transductive   * @exception Exception if something goes wrong.  
*/  public void buildClusterer(Instances labeledData, Instances unlabeledData, int classIndex, Instances totalTrainWithLabels, int startingIndexOfTest) throws Exception {    m_StartingIndexOfTest = startingIndexOfTest;    m_TotalTrainWithLabels = totalTrainWithLabels;    buildClusterer(labeledData, unlabeledData, classIndex, totalTrainWithLabels.numClasses());  }  /**   * Clusters unlabeledData and labeledData (with labels removed),   * using labeledData as seeds   *   * @param labeledData labeled instances to be used as seeds   * @param unlabeledData unlabeled instances   * @param classIndex attribute index in labeledData which holds class info   * @param numClusters number of clusters   * @exception Exception if something goes wrong.  */  public void buildClusterer(Instances labeledData, Instances unlabeledData, int classIndex, int numClusters) throws Exception {    if (m_Algorithm == ALGORITHM_SPHERICAL) {      if (labeledData != null) {	for (int i=0; i<labeledData.numInstances(); i++) {	  normalize(labeledData.instance(i));	}      }      for (int i=0; i<unlabeledData.numInstances(); i++) {	normalize(unlabeledData.instance(i));      }    }    Instances clusterData = new Instances(unlabeledData, 0);;        if (getSeedable()) {    // remove labels of labeledData before putting in seedHash      clusterData = new Instances(labeledData);      System.out.println("Numattributes: " + clusterData.numAttributes());      clusterData.deleteClassAttribute();      // create seedHash from labeledData      Seeder seeder = new Seeder(clusterData, labeledData);      setSeedHash(seeder.getAllSeeds());    }    // add unlabeled data to labeled data (labels removed), not the    // other way around, so that the labels in the hash table entries    // and m_TotalTrainWithLabels are consistent    for (int i=0; i<unlabeledData.numInstances(); i++) {      clusterData.add(unlabeledData.instance(i));    }            System.out.println("combinedData has size: " + clusterData.numInstances() + 
"\n");    // learn metric using labeled data, then  cluster both the labeled and unlabeled data    if (labeledData != null) {      m_metric.buildMetric(labeledData);    }    else {      m_metric.buildMetric(unlabeledData.numAttributes());    }    m_metricBuilt = true;    buildClusterer(clusterData, numClusters);  }  /**   * Reset all values that have been learned   */  public void resetClusterer()  throws Exception{    if (m_metric instanceof LearnableMetric)      ((LearnableMetric)m_metric).resetMetric();    m_SeedHash = null;    m_ClusterCentroids = null;  }    /**   * We can have clusterers that don't utilize seeding   */  public boolean seedable() {    return m_Seedable;  }  /** Initializes the cluster centroids - initial M step   */  protected void initializeClusterer() {    Random random = new Random(m_randomSeed);    boolean globalCentroidComputed = false;    if (m_Verbose) {      //      System.out.println("SeedHash is: " + m_SeedHash);    }            System.out.println("Initializing ");    // makes initial cluster assignments    for (int i = 0; i < m_Instances.numInstances(); i++) {      Instance inst = m_Instances.instance(i);            if (m_SeedHash != null && m_SeedHash.containsKey(inst)) {	m_ClusterAssignments[i] = ((Integer) m_SeedHash.get(inst)).intValue();	if (m_ClusterAssignments[i] < 0) {	  m_ClusterAssignments[i] = -1; // For randomPerturbInit	  if (m_Verbose) {	    System.out.println("Invalid cluster specification for seed instance " + i + ": " + inst + ", making random initial assignment");	  }	}	else {	  if (m_Verbose) {	    System.out.println("Seed instance " + i + ": " + inst + " assigned to cluster: " + m_ClusterAssignments[i]);	  }	}      }      else {	m_ClusterAssignments[i] = -1; // For randomPerturbInit      }    }    Instances [] tempI = new Instances[m_NumClusters];    m_ClusterCentroids = new Instances(m_Instances, m_NumClusters);    boolean [] clusterSeeded = new boolean[m_NumClusters];        for (int i = 0; i < m_NumClusters; 
i++) {
      tempI[i] = new Instances(m_Instances, 0); // tempI[i] stores the cluster instances for cluster i
      clusterSeeded[i] = false; // initialize all clusters to be unseeded
    }

    for (int i = 0; i < m_Instances.numInstances(); i++) {
      if (m_ClusterAssignments[i] >= 0) { // seeded cluster
        clusterSeeded[m_ClusterAssignments[i]] = true;
        tempI[m_ClusterAssignments[i]].add(m_Instances.instance(i));
      }
    }

    // Calculates initial cluster centroids
    for (int i = 0; i < m_NumClusters; i++) {
      double [] values = new double[m_Instances.numAttributes()];
      if (clusterSeeded[i] == true) {
        if (m_FastMode && isSparseInstance) {
          values = meanOrMode(tempI[i]); // uses fast meanOrMode
        }
        else {
          for (int j = 0; j < m_Instances.numAttributes(); j++) {
            values[j] = tempI[i].meanOrMode(j); // uses usual meanOrMode
          }
        }
      }
      else {
        // finds global centroid if has not been already computed
        if (!globalCentroidComputed) {
          double [] globalValues = new double[m_Instances.numAttributes()];
          if (m_FastMode && isSparseInstance) {
            globalValues = meanOrMode(m_Instances); // uses fast meanOrMode

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?