📄 simplekmeans.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * SimpleKMeans.java * Copyright (C) 2000 University of Waikato, Hamilton, New Zealand * */package weka.clusterers;import weka.classifiers.rules.DecisionTableHashKey;import weka.core.Attribute;import weka.core.Capabilities;import weka.core.Instance;import weka.core.Instances;import weka.core.Option;import weka.core.RevisionUtils;import weka.core.Utils;import weka.core.WeightedInstancesHandler;import weka.core.Capabilities.Capability;import weka.filters.Filter;import weka.filters.unsupervised.attribute.ReplaceMissingValues;import java.util.Enumeration;import java.util.HashMap;import java.util.Random;import java.util.Vector;/** <!-- globalinfo-start --> * Cluster data using the k means algorithm * <p/> <!-- globalinfo-end --> * <!-- options-start --> * Valid options are: <p/> * * <pre> -N <num> * number of clusters. * (default 2).</pre> * * <pre> -V * Display std. deviations for centroids. * </pre> * * <pre> -M * Replace missing values with mean/mode. * </pre> * * <pre> -S <num> * Random number seed. * (default 10)</pre> * <!-- options-end --> * * @author Mark Hall (mhall@cs.waikato.ac.nz) * @author Eibe Frank (eibe@cs.waikato.ac.nz) * @version $Revision: 1.39 $ * @see RandomizableClusterer */public class SimpleKMeans extends RandomizableClusterer implements NumberOfClustersRequestable, WeightedInstancesHandler { /** for serialization */ static final long serialVersionUID = -3235809600124455376L; /** * replace missing values in training instances */ private ReplaceMissingValues m_ReplaceMissingFilter; /** * number of clusters to generate */ private int m_NumClusters = 2; /** * holds the cluster centroids */ private Instances m_ClusterCentroids; /** * Holds the standard deviations of the numeric attributes in each cluster */ private Instances m_ClusterStdDevs; /** * For each cluster, holds the frequency counts for the values of each * nominal attribute */ private int [][][] m_ClusterNominalCounts; private int[][] m_ClusterMissingCounts; /** * Stats on the full data set for comparison purposes */ private double[] m_FullMeansOrModes; private double[] m_FullStdDevs; private int[][] m_FullNominalCounts; private int[] m_FullMissingCounts; /** * Display standard deviations for numeric atts */ private boolean m_displayStdDevs; /** * Replace missing values globally? */ private boolean m_dontReplaceMissing = false; /** * The number of instances in each cluster */ private int [] m_ClusterSizes; /** * attribute min values */ private double [] m_Min; /** * attribute max values */ private double [] m_Max; /** * Keep track of the number of iterations completed before convergence */ private int m_Iterations = 0; /** * Holds the squared errors for all clusters */ private double [] m_squaredErrors; /** * the default constructor */ public SimpleKMeans() { super(); m_SeedDefault = 10; setSeed(m_SeedDefault); } /** * Returns a string describing this clusterer * @return a description of the evaluator suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Cluster data using the k means algorithm"; } /** * Returns default capabilities of the clusterer. * * @return the capabilities of this clusterer */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); // attributes result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); return result; } /** * Generates a clusterer. Has to initialize all fields of the clusterer * that are not being set via options. * * @param data set of instances serving as training data * @throws Exception if the clusterer has not been * generated successfully */ public void buildClusterer(Instances data) throws Exception { // can clusterer handle the data? getCapabilities().testWithFail(data); m_Iterations = 0; m_ReplaceMissingFilter = new ReplaceMissingValues(); Instances instances = new Instances(data); instances.setClassIndex(-1); if (!m_dontReplaceMissing) { m_ReplaceMissingFilter.setInputFormat(instances); instances = Filter.useFilter(instances, m_ReplaceMissingFilter); } m_FullMeansOrModes = new double[instances.numAttributes()]; m_FullMissingCounts = new int[instances.numAttributes()]; if (m_displayStdDevs) { m_FullStdDevs = new double[instances.numAttributes()]; } m_FullNominalCounts = new int[instances.numAttributes()][0]; for (int i = 0; i < instances.numAttributes(); i++) { m_FullMissingCounts[i] = instances.attributeStats(i).missingCount; m_FullMeansOrModes[i] = instances.meanOrMode(i); if (instances.attribute(i).isNumeric()) { if (m_displayStdDevs) { m_FullStdDevs[i] = Math.sqrt(instances.variance(i)); } if (m_FullMissingCounts[i] == instances.numInstances()) { m_FullMeansOrModes[i] = Double.NaN; // mark missing as mean } } else { m_FullNominalCounts[i] = instances.attributeStats(i).nominalCounts; if (m_FullMissingCounts[i] > m_FullNominalCounts[i][Utils.maxIndex(m_FullNominalCounts[i])]) { m_FullMeansOrModes[i] = -1; // mark missing as most common value } } } m_Min = new double [instances.numAttributes()]; m_Max = new double [instances.numAttributes()]; for (int i = 0; i < instances.numAttributes(); i++) { m_Min[i] = m_Max[i] = Double.NaN; } m_ClusterCentroids = new Instances(instances, m_NumClusters); int[] clusterAssignments = new int [instances.numInstances()]; for (int i = 0; i < instances.numInstances(); i++) { updateMinMax(instances.instance(i)); } Random RandomO = new Random(getSeed()); int instIndex; HashMap initC = new HashMap(); DecisionTableHashKey hk = null; for (int j = instances.numInstances() - 1; j >= 0; j--) { instIndex = RandomO.nextInt(j+1); hk = new DecisionTableHashKey(instances.instance(instIndex), instances.numAttributes(), true); if (!initC.containsKey(hk)) { m_ClusterCentroids.add(instances.instance(instIndex)); initC.put(hk, null); } instances.swap(j, instIndex); if (m_ClusterCentroids.numInstances() == m_NumClusters) { break; } } m_NumClusters = m_ClusterCentroids.numInstances(); int i; boolean converged = false; int emptyClusterCount; Instances [] tempI = new Instances[m_NumClusters]; m_squaredErrors = new double [m_NumClusters]; m_ClusterNominalCounts = new int [m_NumClusters][instances.numAttributes()][0]; m_ClusterMissingCounts = new int[m_NumClusters][instances.numAttributes()]; while (!converged) { emptyClusterCount = 0; m_Iterations++; converged = true; for (i = 0; i < instances.numInstances(); i++) { Instance toCluster = instances.instance(i); int newC = clusterProcessedInstance(toCluster, true); if (newC != clusterAssignments[i]) { converged = false; } clusterAssignments[i] = newC; } // update centroids m_ClusterCentroids = new Instances(instances, m_NumClusters); for (i = 0; i < m_NumClusters; i++) { tempI[i] = new Instances(instances, 0); } for (i = 0; i < instances.numInstances(); i++) { tempI[clusterAssignments[i]].add(instances.instance(i)); } for (i = 0; i < m_NumClusters; i++) { double [] vals = new double[instances.numAttributes()]; if (tempI[i].numInstances() == 0) { // empty cluster emptyClusterCount++; } else { for (int j = 0; j < instances.numAttributes(); j++) { vals[j] = tempI[i].meanOrMode(j); m_ClusterMissingCounts[i][j] = tempI[i].attributeStats(j).missingCount; m_ClusterNominalCounts[i][j] = tempI[i].attributeStats(j).nominalCounts; if (tempI[i].attribute(j).isNominal()) { if (m_ClusterMissingCounts[i][j] > m_ClusterNominalCounts[i][j][Utils.maxIndex(m_ClusterNominalCounts[i][j])]) { vals[j] = Instance.missingValue(); // mark mode as missing } } else { if (m_ClusterMissingCounts[i][j] == tempI[i].numInstances()) { vals[j] = Instance.missingValue(); // mark mean as missing } } } m_ClusterCentroids.add(new Instance(1.0, vals)); } } if (emptyClusterCount > 0) { m_NumClusters -= emptyClusterCount; if (converged) { Instances[] t = new Instances[m_NumClusters]; int index = 0; for (int k = 0; k < tempI.length; k++) { if (tempI[k].numInstances() > 0) { t[index++] = tempI[k]; } } tempI = t; } else { tempI = new Instances[m_NumClusters]; } } if (!converged) { m_squaredErrors = new double [m_NumClusters]; m_ClusterNominalCounts = new int [m_NumClusters][instances.numAttributes()][0]; } } if (m_displayStdDevs) { m_ClusterStdDevs = new Instances(instances, m_NumClusters); } m_ClusterSizes = new int [m_NumClusters]; for (i = 0; i < m_NumClusters; i++) { if (m_displayStdDevs) { double [] vals2 = new double[instances.numAttributes()]; for (int j = 0; j < instances.numAttributes(); j++) { if (instances.attribute(j).isNumeric()) { vals2[j] = Math.sqrt(tempI[i].variance(j)); } else { vals2[j] = Instance.missingValue(); } } m_ClusterStdDevs.add(new Instance(1.0, vals2)); } m_ClusterSizes[i] = tempI[i].numInstances(); } } /** * clusters an instance that has been through the filters * * @param instance the instance to assign a cluster to * @param updateErrors if true, update the within clusters sum of errors * @return a cluster number */ private int clusterProcessedInstance(Instance instance, boolean updateErrors) { double minDist = Integer.MAX_VALUE; int bestCluster = 0; for (int i = 0; i < m_NumClusters; i++) { double dist = distance(instance, m_ClusterCentroids.instance(i)); if (dist < minDist) { minDist = dist; bestCluster = i; }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -