pcsoftkmeans.java
来自「wekaUT是 university texas austin 开发的基于wek」· Java 代码 · 共 1,952 行 · 第 1/5 页
JAVA
1,952 行
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    PCSoftKMeans.java
 *    Copyright (C) 2003 Sugato Basu
 *
 */

package weka.clusterers;

import java.io.*;
import java.util.*;
import weka.core.*;
import weka.core.metrics.*;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Pairwise constrained k means clustering class.
 *
 * Valid options are:<p>
 *
 * -N &lt;number of clusters&gt; <br>
 * Specify the number of clusters to generate. <p>
 *
 * -R &lt;random seed&gt; <br>
 * Specify random number seed <p>
 *
 * -A &lt;algorithm&gt; <br>
 * The algorithm can be "Simple" (simple KMeans) or "Spherical" (spherical KMeans) <p>
 *
 * -M &lt;metric-class&gt; <br>
 * Specifies the name of the distance metric class that should be used <p>
 *
 * .... etc.
 *
 * @author Sugato Basu(sugato@cs.utexas.edu)
 * @see Clusterer
 * @see OptionHandler
 */
public class PCSoftKMeans extends DistributionClusterer implements OptionHandler,SemiSupClusterer,ActiveLearningClusterer {

  /** Name of clusterer */
  String m_name = "PCSoftKMeans";

  /** holds the instances in the clusters */
  protected ArrayList m_Clusters = null;

  /** holds the instance indices in the clusters, mapped to their probabilities */
  protected HashMap[] m_IndexClusters = null;

  /** holds the ([instance pair] -> [type of constraint]) mapping.
Note that the instance pairs stored in the hash always have constraint type
      InstancePair.DONT_CARE_LINK, the actual link type is stored in the hashed value */
  protected HashMap m_ConstraintsHash = null;

  /** adjacency list for neighborhoods */
  protected HashSet[] m_AdjacencyList;

  /** colors required for keeping track of DFS visit */
  final int WHITE = 300;
  final int GRAY = 301;
  final int BLACK = 302;

  /** holds the points involved in the constraints */
  protected HashSet m_SeedHash = null;

  /** weight to be given to each cannot-link constraint */
  protected double m_CannotLinkWeight = 1;

  /** weight to be given to each must-link constraint */
  protected double m_MustLinkWeight = 1;

  /** kappa value for vmf distribution */
  protected double m_Kappa = 2;

  /** max kappa value for vmf distribution
      (NOTE(review): presumably the cap used with similarity-based metrics - confirm) */
  protected double m_MaxKappaSim = 100;
  /** NOTE(review): undocumented in the original; presumably the max kappa
      used with distance-based metrics - confirm */
  protected double m_MaxKappaDist = 10;

  /** the maximum number of cannot-link constraints allowed */
  protected final static int m_MaxConstraintsAllowed = 10000;

  /** verbose? */
  protected boolean m_verbose = false;

  /** distance Metric; a WeightedEuclidean unless another metric is supplied */
  protected Metric m_metric = new WeightedEuclidean();

  /** has the metric been constructed? a fix for multiple buildClusterer's */
  protected boolean m_metricBuilt = false;

  /** indicates whether instances are sparse */
  protected boolean m_isSparseInstance = false;

  /** Is the objective function increasing or decreasing?
Depends on type of metric used: for similarity-based metric - increasing,
      for distance-based - decreasing */
  protected boolean m_objFunDecreasing = true;

  /** Seedable or not (true by default) */
  protected boolean m_Seedable = true;

  /** keep track of the number of iterations completed before convergence */
  protected int m_Iterations = 0;

  /** Define possible algorithms */
  public static final int ALGORITHM_SIMPLE = 1;
  public static final int ALGORITHM_SPHERICAL = 2;
  public static final Tag[] TAGS_ALGORITHM = {
    new Tag(ALGORITHM_SIMPLE, "Simple"),
    new Tag(ALGORITHM_SPHERICAL, "Spherical")
  };

  /** algorithm, ALGORITHM_SIMPLE by default */
  protected int m_Algorithm = ALGORITHM_SIMPLE;

  /**
   * min. absolute difference of objective function values for convergence;
   * a difference less significant than the 3rd place of decimal is ignored
   */
  protected double m_ObjFunConvergenceDifference = 1e-3;

  /** value of objective function */
  protected double m_Objective;

  /**
   * Returns the objective function.
   *
   * @return the current value of m_Objective
   */
  public double objectiveFunction() {
    return m_Objective;
  }

  /**
   * training instances with labels
   */
  protected Instances m_TotalTrainWithLabels;

  /**
   * training instances
   */
  protected Instances m_Instances;

  /** A hash where the instance checksums are hashed */
  protected HashMap m_checksumHash = null;
  protected double[] m_checksumCoeffs = null;

  /** test data -- required to make sure that test points are not selected during active learning */
  protected int m_StartingIndexOfTest = -1;

  /**
   * number of clusters to generate; initialised to 3 and overwritten by
   * buildClusterer(Instances, int)
   */
  protected int m_NumClusters = 3;

  /** Number of clusters in the process */
  protected int m_NumCurrentClusters = 0;

  /**
   * m_FastMode = true  => fast computation of meanOrMode in centroid calculation, useful for high-D data sets
   * m_FastMode = false => usual computation of meanOrMode in centroid calculation
   */
  protected boolean m_FastMode = true;

  /**
   * holds the cluster centroids
   */
  protected Instances m_ClusterCentroids;

  /**
   * holds the global centroids
   */
protected Instance m_GlobalCentroid;

  /**
   * holds the default perturbation value for randomPerturbInit
   */
  protected double m_DefaultPerturb = 0.7;

  /**
   * holds the default merge threshold for matchMergeStep
   */
  protected double m_MergeThreshold = 0.15;

  /**
   * temporary variable holding posterior cluster distribution of
   * points while iterating
   */
  protected double [][] m_ClusterDistribution;

  /**
   * temporary variable holding cluster assignments while iterating
   */
  protected int [] m_ClusterAssignments;

  /**
   * temporary variable holding cluster sums while iterating
   */
  protected Instance [] m_SumOfClusterInstances;

  /**
   * holds the random Seed, useful for randomPerturbInit
   */
  protected int m_RandomSeed = 42;

  /** neighbor list: points in each neighborhood inferred from constraints */
  protected HashSet[] m_NeighborSets;

  /** assigned set for active learning: whether a point has been assigned or not */
  HashSet m_AssignedSet;

  /**
   * Default constructor; every field keeps its declared initial value.
   */
  public PCSoftKMeans() { }

  /**
   * Constructor that installs the metric to be used.
   * The objective function is treated as decreasing exactly when
   * metric.isDistanceBased() returns true.
   *
   * @param metric the metric to use for clustering
   */
  public PCSoftKMeans(Metric metric) {
    m_metric = metric;
    m_objFunDecreasing = metric.isDistanceBased();
  }

  /**
   * We always want to implement SemiSupClusterer from a class
   * extending Clusterer.  We want to be able to return the underlying
   * parent class.
   *
   * @return parent Clusterer class (this object)
   */
  public Clusterer getThisClusterer() {
    return this;
  }

  /**
   * Generates a clusterer.
Instances in data have to be either all * sparse or all non-sparse * * @param data set of instances serving as training data * @exception Exception if the clusterer has not been * generated successfully */ public void buildClusterer(Instances data) throws Exception { System.out.println("Must link weight: " + m_MustLinkWeight); System.out.println("Cannot link weight: " + m_CannotLinkWeight); setInstances(data); // Don't rebuild the metric if it was already trained if (!m_metricBuilt) { m_metric.buildMetric(data.numAttributes()); } m_ClusterCentroids = new Instances(m_Instances, m_NumClusters); m_ClusterAssignments = new int [m_Instances.numInstances()]; if (m_Instances.checkForNominalAttributes() && m_Instances.checkForStringAttributes()) { throw new UnsupportedAttributeTypeException("Cannot handle nominal attributes\n"); } System.out.println("Initializing clustering ..."); nonActivePairwiseInit(); System.out.println("Done initializing clustering ..."); if (m_Seedable) { // System.out.println("Initial assignments of seed points:"); // printIndexClusters(); } if (m_verbose) { for (int i=0; i<m_NumClusters; i++) { System.out.println("Centroid " + i + ": " + m_ClusterCentroids.instance(i)); } } runEM(); } /** * Clusters unlabeledData and labeledData (with labels removed), * using labeledData as seeds * * @param labeledData labeled instances to be used as seeds * @param unlabeledData unlabeled instances * @param classIndex attribute index in labeledData which holds class info * @param numClusters number of clusters * @param startingIndexOfTest from where test data starts in * unlabeledData, useful if clustering is transductive * @exception Exception if something goes wrong. */ public void buildClusterer(Instances labeledData, Instances unlabeledData, int classIndex, int numClusters, int startingIndexOfTest) throws Exception { // !!!! Dummy function, for compatibility with interface !!!! 
throw new Exception("Not implemented for PCSoftKMeans"); } /** * Cluster given instances to form the specified number of clusters. * * @param data instances to be clustered * @param num_clusters number of clusters to create * @exception Exception if something goes wrong. */ public void buildClusterer(Instances data, int num_clusters) throws Exception { m_NumClusters = num_clusters; if (m_Algorithm == ALGORITHM_SPHERICAL && m_metric instanceof WeightedDotP) { ((WeightedDotP)m_metric).setLengthNormalized(false); // since instances and clusters are already normalized, we don't need to normalize again while computing similarity - saves time } if (data.instance(0) instanceof SparseInstance) { m_isSparseInstance = true; } buildClusterer(data); } /** * Clusters unlabeledData and labeledData (with labels removed), * using labeledData as seeds * * @param labeledTrainPairs labeled instances to be used as seeds * @param unlabeledData unlabeled training (+ test for transductive) instances * @param labeledTrain labeled training instances * @param startingIndexOfTest starting index of test set in unlabeled data * @exception Exception if something goes wrong. 
*/ public void buildClusterer(ArrayList labeledPair, Instances unlabeledData, Instances labeledTrain, int startingIndexOfTest) throws Exception { int classIndex = labeledTrain.numAttributes(); // assuming that the last attribute is always the class m_TotalTrainWithLabels = labeledTrain; if (labeledPair != null) { m_SeedHash = new HashSet((int) (unlabeledData.numInstances()/0.75 + 10)) ; m_ConstraintsHash = new HashMap(m_MaxConstraintsAllowed); for (int i=0; i<labeledPair.size(); i++) { InstancePair pair = (InstancePair) labeledPair.get(i); Integer firstInt = new Integer(pair.first); Integer secondInt = new Integer(pair.second); // for first point if(!m_SeedHash.contains(firstInt)) { // add instances with constraints to seedHash m_SeedHash.add(firstInt); } // for second point if(!m_SeedHash.contains(secondInt)) { m_SeedHash.add(secondInt); } if (pair.first >= pair.second) { throw new Exception("Ordering reversed - something wrong!!"); } else { InstancePair newPair = new InstancePair(pair.first, pair.second, InstancePair.DONT_CARE_LINK); m_ConstraintsHash.put(newPair, new Integer(pair.linkType)); // WLOG first < second } } } // normalize all data for SPKMeans if (m_Algorithm == ALGORITHM_SPHERICAL) { for (int i=0; i<unlabeledData.numInstances(); i++) { normalize(unlabeledData.instance(i)); } } m_StartingIndexOfTest = startingIndexOfTest; if (m_verbose) { System.out.println("Starting index of test: " + m_StartingIndexOfTest); } // learn metric using labeled data, then cluster both the labeled and unlabeled data m_metric.buildMetric(unlabeledData.numAttributes()); m_metricBuilt = true; buildClusterer(unlabeledData, labeledTrain.numClasses()); } /** * Clusters unlabeledData and labeledData (with labels removed), * using labeledData as seeds -- NOT USED FOR PCSoftKMeans!!!
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?