SeededKMeans.java
From wekaUT (a Weka-based toolkit developed at the University of Texas at Austin) · Java code · 2,032 lines total · page 1 of 5
JAVA
2,032 行
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * SeededKMeans.java * Copyright (C) 2002 Sugato Basu * */package weka.clusterers;import java.io.*;import java.util.*;import weka.core.*;import weka.core.metrics.*;import weka.filters.Filter;import weka.filters.unsupervised.attribute.Remove;/** * Seeded k means clustering class. * * Valid options are:<p> * * -N <number of clusters> <br> * Specify the number of clusters to generate. 
<p>
 *
 * -R <random seed> <br>
 * Specify random number seed <p>
 *
 * -S <seeding method> <br>
 * The seeding method can be "seeded" (seeded KMeans) or "constrained" (constrained KMeans)
 *
 * -A <algorithm> <br>
 * The algorithm can be "simple" (simple KMeans) or "spherical" (spherical KMeans)
 *
 * -M <metric-class> <br>
 * Specifies the name of the distance metric class that should be used
 *
 * @author Sugato Basu(sugato@cs.utexas.edu)
 * @see Clusterer
 * @see OptionHandler
 */
public class SeededKMeans extends Clusterer
  implements OptionHandler, SemiSupClusterer, ActiveLearningClusterer {

  /** Name of clusterer */
  String m_name = "SeededKMeans";

  /** holds the clusters */
  protected ArrayList m_FinalClusters = null;

  /** holds the instance indices in the clusters */
  protected ArrayList m_IndexClusters = null;

  /** holds the ([seed instance] -> [clusterLabel of seed instance]) mapping */
  protected HashMap m_SeedHash = null;

  /** distance metric used to compare instances */
  protected Metric m_metric = new WeightedDotP();

  /** has the metric been constructed? a fix for multiple buildClusterer's */
  protected boolean m_metricBuilt = false;

  /** starting index of test data in unlabeledData if transductive clustering */
  protected int m_StartingIndexOfTest = -1;

  /** indicates whether instances are sparse */
  protected boolean isSparseInstance = false;

  /** Is the objective function increasing or decreasing?  Depends on type
   *  of metric used: for similarity-based metric, increasing; for
   *  distance-based, decreasing. */
  protected boolean m_objFunDecreasing = false;

  /** Name of metric */
  protected String m_metricName = new String("WeightedDotP");

  /** Points that are to be skipped in the clustering process
   *  because they are collapsed to zero */
  protected HashSet m_skipHash = new HashSet();

  /** Index of the current element in the E-step */
  protected int m_currIdx = 0;

  /** keep track of the number of iterations completed before convergence */
  protected int m_Iterations = 0;

  /* Define possible seeding methods */
  public static final int SEEDING_CONSTRAINED = 1;
  public static final int SEEDING_SEEDED = 2;
  public static final Tag [] TAGS_SEEDING = {
    new Tag(SEEDING_CONSTRAINED, "Constrained seeding"),
    new Tag(SEEDING_SEEDED, "Initial seeding only")
  };

  /** seeding method, by default seeded */
  protected int m_SeedingMethod = SEEDING_SEEDED;

  /** Define possible algorithms */
  public static final int ALGORITHM_SIMPLE = 1;
  public static final int ALGORITHM_SPHERICAL = 2;
  public static final Tag[] TAGS_ALGORITHM = {
    new Tag(ALGORITHM_SIMPLE, "Simple K-Means"),
    new Tag(ALGORITHM_SPHERICAL, "Spherical K-Means")
  };

  /** algorithm, by default spherical */
  protected int m_Algorithm = ALGORITHM_SPHERICAL;

  /** min difference of objective function values for convergence */
  protected double m_ObjFunConvergenceDifference = 1e-5;

  /** value of objective function */
  // NOTE(review): a double seeded with Integer.MAX_VALUE — used as "very
  // large", not as Double.MAX_VALUE; confirm before changing.
  protected double m_Objective = Integer.MAX_VALUE;

  /** returns objective function */
  public double objectiveFunction() {
    return m_Objective;
  }

  /** Verbose? */
  protected boolean m_Verbose = false;

  /** training instances with labels */
  protected Instances m_TotalTrainWithLabels;

  /** training instances */
  protected Instances m_Instances;

  /** number of clusters to generate, default is 3 */
  protected int m_NumClusters = 3;

  /** m_FastMode = true => fast computation of meanOrMode in centroid
   *  calculation, useful for high-D data sets;
   *  m_FastMode = false => usual computation of meanOrMode */
  protected boolean m_FastMode = true;

  /** holds the cluster centroids */
  protected Instances m_ClusterCentroids;

  /** holds the global centroid */
  protected Instance m_GlobalCentroid;

  /** holds the default perturbation value for randomPerturbInit */
  protected double m_DefaultPerturb = 0.7;

  /** weight of the concentration */
  protected double m_Concentration = 10.0;

  /** number of extra phase1 runs */
  protected double m_ExtraPhase1RunFraction = 50;

  /** temporary variable holding cluster assignments while iterating */
  protected int [] m_ClusterAssignments;

  /** holds the random Seed, useful for randomPerturbInit */
  protected int m_randomSeed = 1;

  /** semisupervision enabled? */
  protected boolean m_Seedable = true;

  /** Default constructor */
  public SeededKMeans() {
  }

  /** Constructor that takes the distance metric to cluster with;
   *  the objective-function direction follows the metric type. */
  public SeededKMeans(Metric metric) {
    m_metric = metric;
    m_metricName = m_metric.getClass().getName();
    m_objFunDecreasing = metric.isDistanceBased();
  }

  /**
   * We always want to implement SemiSupClusterer from a class extending
   * Clusterer.  We want to be able to return the underlying parent class.
   * @return parent Clusterer class
   */
  public Clusterer getThisClusterer() {
    return this;
  }

  /**
   * Cluster given instances to form the specified number of clusters.
   *
   * @param data instances to be clustered
   * @param num_clusters number of clusters to create
   * @exception Exception if something goes wrong.
*/ public void buildClusterer(Instances data, int num_clusters) throws Exception { setNumClusters(num_clusters); if (m_Algorithm == ALGORITHM_SPHERICAL && m_metric instanceof WeightedDotP) { ((WeightedDotP)m_metric).setLengthNormalized(false); // since instances and clusters are already normalized, we don't need to normalize again while computing similarity - saves time } if (data.instance(0) instanceof SparseInstance) { isSparseInstance = true; } buildClusterer(data); } /** * Clusters unlabeledData and labeledData (with labels removed), * using labeledData as seeds * * @param labeledData labeled instances to be used as seeds * @param unlabeledData unlabeled instances * @param classIndex attribute index in labeledData which holds class info * @param numClusters number of clusters * @param startingIndexOfTest from where test data starts in unlabeledData, useful if clustering is transductive * @exception Exception if something goes wrong. */ public void buildClusterer(Instances labeledData, Instances unlabeledData, int classIndex, int numClusters, int startingIndexOfTest) throws Exception { m_StartingIndexOfTest = startingIndexOfTest; buildClusterer(labeledData, unlabeledData, classIndex, numClusters); } /** * Clusters unlabeledData and labeledData (with labels removed), * using labeledData as seeds * * @param labeledData labeled instances to be used as seeds * @param unlabeledData unlabeled instances * @param classIndex attribute index in labeledData which holds class info * @param numClusters number of clusters * @param startingIndexOfTest from where test data starts in unlabeledData, useful if clustering is transductive * @exception Exception if something goes wrong. 
*/
  public void buildClusterer(Instances labeledData, Instances unlabeledData,
                             int classIndex, Instances totalTrainWithLabels,
                             int startingIndexOfTest) throws Exception {
    m_StartingIndexOfTest = startingIndexOfTest;
    m_TotalTrainWithLabels = totalTrainWithLabels;
    // number of clusters = number of classes in the fully-labeled training set
    buildClusterer(labeledData, unlabeledData, classIndex, totalTrainWithLabels.numClasses());
  }

  /**
   * Clusters unlabeledData and labeledData (with labels removed),
   * using labeledData as seeds
   *
   * @param labeledData labeled instances to be used as seeds
   * @param unlabeledData unlabeled instances
   * @param classIndex attribute index in labeledData which holds class info
   * @param numClusters number of clusters
   * @exception Exception if something goes wrong.
   */
  public void buildClusterer(Instances labeledData, Instances unlabeledData,
                             int classIndex, int numClusters) throws Exception {
    // Spherical k-means works with unit-length vectors: normalize all
    // instances (labeled and unlabeled) up front.
    if (m_Algorithm == ALGORITHM_SPHERICAL) {
      if (labeledData != null) {
        for (int i=0; i<labeledData.numInstances(); i++) {
          normalize(labeledData.instance(i));
        }
      }
      for (int i=0; i<unlabeledData.numInstances(); i++) {
        normalize(unlabeledData.instance(i));
      }
    }
    // NOTE(review): stray empty statement after the ';' below, and
    // labeledData is dereferenced without a null check inside the seedable
    // branch — confirm callers never pass null with seeding enabled.
    Instances clusterData = new Instances(unlabeledData, 0);;
    if (getSeedable()) {
      // remove labels of labeledData before putting in seedHash
      clusterData = new Instances(labeledData);
      System.out.println("Numattributes: " + clusterData.numAttributes());
      clusterData.deleteClassAttribute();
      // create seedHash from labeledData
      Seeder seeder = new Seeder(clusterData, labeledData);
      setSeedHash(seeder.getAllSeeds());
    }
    // add unlabeled data to labeled data (labels removed), not the
    // other way around, so that the labels in the hash table entries
    // and m_TotalTrainWithLabels are consistent
    for (int i=0; i<unlabeledData.numInstances(); i++) {
      clusterData.add(unlabeledData.instance(i));
    }
    System.out.println("combinedData has size: " + clusterData.numInstances() + "\n");
    // learn metric using labeled data, then cluster both the labeled and unlabeled data
    if (labeledData != null) {
m_metric.buildMetric(labeledData);
    } else {
      // no labeled data: build an untrained metric of the right dimensionality
      m_metric.buildMetric(unlabeledData.numAttributes());
    }
    m_metricBuilt = true;
    buildClusterer(clusterData, numClusters);
  }

  /**
   * Reset all values that have been learned
   */
  public void resetClusterer() throws Exception{
    if (m_metric instanceof LearnableMetric)
      ((LearnableMetric)m_metric).resetMetric();
    m_SeedHash = null;
    m_ClusterCentroids = null;
  }

  /**
   * We can have clusterers that don't utilize seeding
   */
  public boolean seedable() {
    return m_Seedable;
  }

  /** Initializes the cluster centroids - initial M step.
   *  Seed instances are assigned their labeled cluster; all other
   *  instances get -1 (to be handled by random-perturbation init).
   *  NOTE(review): method continues beyond this excerpt (page 1/5). */
  protected void initializeClusterer() {
    Random random = new Random(m_randomSeed);
    boolean globalCentroidComputed = false;
    if (m_Verbose) {
      // System.out.println("SeedHash is: " + m_SeedHash);
    }
    System.out.println("Initializing ");
    // makes initial cluster assignments
    for (int i = 0; i < m_Instances.numInstances(); i++) {
      Instance inst = m_Instances.instance(i);
      if (m_SeedHash != null && m_SeedHash.containsKey(inst)) {
        m_ClusterAssignments[i] = ((Integer) m_SeedHash.get(inst)).intValue();
        if (m_ClusterAssignments[i] < 0) {
          m_ClusterAssignments[i] = -1; // For randomPerturbInit
          if (m_Verbose) {
            System.out.println("Invalid cluster specification for seed instance " + i
                               + ": " + inst + ", making random initial assignment");
          }
        } else {
          if (m_Verbose) {
            System.out.println("Seed instance " + i + ": " + inst
                               + " assigned to cluster: " + m_ClusterAssignments[i]);
          }
        }
      } else {
        m_ClusterAssignments[i] = -1; // For randomPerturbInit
      }
    }
    Instances [] tempI = new Instances[m_NumClusters];
    m_ClusterCentroids = new Instances(m_Instances, m_NumClusters);
    boolean [] clusterSeeded = new boolean[m_NumClusters];
    for (int i = 0; i < m_NumClusters; i++) {
      tempI[i] = new Instances(m_Instances, 0); // tempI[i] stores the cluster instances for cluster i
      clusterSeeded[i] = false; // initialize all clusters to be unseeded
    }
    // group seed instances by their assigned cluster
    for (int i = 0; i < m_Instances.numInstances(); i++) {
      if (m_ClusterAssignments[i] >= 0) { // seeded cluster
        clusterSeeded[m_ClusterAssignments[i]] = true;
        tempI[m_ClusterAssignments[i]].add(m_Instances.instance(i));
      }
    }
    // Calculates initial cluster centroids
    for (int i = 0; i < m_NumClusters; i++) {
      double [] values = new double[m_Instances.numAttributes()];
      if (clusterSeeded[i] == true) {
        // seeded cluster: centroid = mean/mode of its seed instances
        if (m_FastMode && isSparseInstance) {
          values = meanOrMode(tempI[i]); // uses fast meanOrMode
        } else {
          for (int j = 0; j < m_Instances.numAttributes(); j++) {
            values[j] = tempI[i].meanOrMode(j); // uses usual meanOrMode
          }
        }
      } else {
        // finds global centroid if has not been already computed
        if (!globalCentroidComputed) {
          double [] globalValues = new double[m_Instances.numAttributes()];
          if (m_FastMode && isSparseInstance) {
            globalValues = meanOrMode(m_Instances); // uses fast meanOrMode
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?