📄 simplekmeans.java
字号:
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* SimpleKMeans.java
* Copyright (C) 2000 Mark Hall
*
*/
package weka.clusterers;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.rules.DecisionTable;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
/**
* Simple k-means clustering class.
*
* Valid options are:<p>
*
* -N &lt;number of clusters&gt; <br>
* Specify the number of clusters to generate. <p>
*
* -S &lt;seed&gt; <br>
* Specify the random number seed. <p>
*
* @author Mark Hall (mhall@cs.waikato.ac.nz)
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @version $Revision$
* @see Clusterer
* @see OptionHandler
*/
public class SimpleKMeans extends Clusterer
implements NumberOfClustersRequestable,
OptionHandler, WeightedInstancesHandler {
/**
* Serialization version identifier for this class.
*/
private static final long serialVersionUID = -3235809600124455376L;
/**
* Filter that replaces missing values. It is fitted on the training data in
* buildClusterer and then reused on every instance passed to clusterInstance.
*/
private ReplaceMissingValues m_ReplaceMissingFilter;
/**
* Number of clusters to generate. buildClusterer may lower this value when it
* cannot find enough distinct initial centroids or when clusters become empty.
*/
private int m_NumClusters = 2;
/**
* Holds the cluster centroids: one instance per cluster whose values are the
* mean (numeric) or mode (nominal) of that cluster's members.
*/
private Instances m_ClusterCentroids;
/**
* Holds the standard deviations of the numeric attributes in each cluster
* (non-numeric attributes are stored as missing values).
*/
private Instances m_ClusterStdDevs;
/**
* For each cluster, holds the frequency counts for the values of each
* nominal attribute, indexed as [cluster][attribute][value].
*/
private int [][][] m_ClusterNominalCounts;
/**
* The number of training instances in each cluster after the final iteration.
*/
private int [] m_ClusterSizes;
/**
* Random seed used when picking the initial centroids.
*/
private int m_Seed = 10;
/**
* Attribute min values seen in the training data (initialised to NaN and
* updated via updateMinMax).
*/
private double [] m_Min;
/**
* Attribute max values seen in the training data (initialised to NaN and
* updated via updateMinMax).
*/
private double [] m_Max;
/**
* Keep track of the number of iterations completed before convergence
* (reset at the start of every buildClusterer call).
*/
private int m_Iterations = 0;
/**
* Per-cluster sum of the distance() values between each instance and the
* centroid it was assigned to, as accumulated by clusterProcessedInstance.
* NOTE(review): presumably squared distances given the name; distance() is
* not fully visible here, so confirm.
*/
private double [] m_squaredErrors;
/**
 * Supplies the one-line description of this clusterer that the
 * Explorer/Experimenter GUI shows to the user.
 *
 * @return a short human-readable summary of the algorithm
 */
public String globalInfo() {
  final String summary = "Cluster data using the k means algorithm";
  return summary;
}
/**
 * Generates a clusterer. Has to initialize all fields of the clusterer
 * that are not being set via options.
 *
 * <p>Missing values are replaced first. Then up to {@code m_NumClusters}
 * distinct training instances are drawn at random (seeded with
 * {@code m_Seed}) as initial centroids. Assignment and centroid-update passes
 * are repeated until no instance changes cluster. If fewer distinct instances
 * exist than clusters requested, {@code m_NumClusters} is lowered to the
 * number of centroids found. Clusters that end up empty are dropped and
 * {@code m_NumClusters} is lowered accordingly.
 *
 * @param data set of instances serving as training data
 * @exception Exception if the clusterer has not been
 * generated successfully (e.g. string attributes or no training instances)
 */
public void buildClusterer(Instances data) throws Exception {
  m_Iterations = 0;
  if (data.checkForStringAttributes()) {
    throw new Exception("Can't handle string attributes!");
  }
  m_ReplaceMissingFilter = new ReplaceMissingValues();
  m_ReplaceMissingFilter.setInputFormat(data);
  Instances instances = Filter.useFilter(data, m_ReplaceMissingFilter);
  int numInstances = instances.numInstances();
  int numAttributes = instances.numAttributes();
  if (numInstances == 0) {
    // Random.nextInt(0) below would otherwise fail with an obscure message.
    throw new Exception("No training instances!");
  }
  m_Min = new double [numAttributes];
  m_Max = new double [numAttributes];
  for (int j = 0; j < numAttributes; j++) {
    m_Min[j] = m_Max[j] = Double.NaN;
  }
  m_ClusterCentroids = new Instances(instances, m_NumClusters);
  int[] clusterAssignments = new int [numInstances];
  for (int j = 0; j < numInstances; j++) {
    updateMinMax(instances.instance(j));
  }

  // ---- pick distinct random training instances as initial centroids ----
  Random random = new Random(m_Seed);
  boolean [] selected = new boolean[numInstances];
  int instIndex;
  HashMap initC = new HashMap();
  DecisionTable.hashKey hk = null;
  boolean centroidSearchBailOut = false;
  // Total number of instances marked as selected so far, across ALL
  // centroids. It must not be reset per centroid: 'selected' persists between
  // centroids, so a per-centroid count could never reach numInstances once
  // every instance had been selected, and the search would loop forever
  // (e.g. when there are fewer distinct instances than requested clusters).
  int centroidCount = 0;
  int i;
  for (i = 0; i < m_NumClusters; i++) {
    do {
      instIndex = random.nextInt(numInstances);
      hk = new DecisionTable.hashKey(instances.instance(instIndex),
                                     numAttributes);
      if (initC.containsKey(hk) && !selected[instIndex]) {
        // duplicate of an existing centroid: never consider it again
        centroidCount++;
        selected[instIndex] = true;
      }
    } while (selected[instIndex] && centroidCount < numInstances);
    if (centroidCount >= numInstances) {
      // no distinct candidates left: keep only the i centroids found
      centroidSearchBailOut = true;
      break;
    }
    m_ClusterCentroids.add(instances.instance(instIndex));
    initC.put(hk, null);
    selected[instIndex] = true;
    centroidCount++;
  }
  if (centroidSearchBailOut) {
    m_NumClusters = i;
  }

  // ---- iterate assignment / update steps until assignments are stable ----
  boolean converged = false;
  int emptyClusterCount;
  Instances [] tempI = new Instances[m_NumClusters];
  m_squaredErrors = new double [m_NumClusters];
  m_ClusterNominalCounts = new int [m_NumClusters][numAttributes][0];
  while (!converged) {
    emptyClusterCount = 0;
    m_Iterations++;
    converged = true;
    for (i = 0; i < numInstances; i++) {
      int newC = clusterProcessedInstance(instances.instance(i));
      if (newC != clusterAssignments[i]) {
        converged = false;
      }
      clusterAssignments[i] = newC;
    }
    // update centroids
    m_ClusterCentroids = new Instances(instances, m_NumClusters);
    for (i = 0; i < m_NumClusters; i++) {
      tempI[i] = new Instances(instances, 0);
    }
    for (i = 0; i < numInstances; i++) {
      tempI[clusterAssignments[i]].add(instances.instance(i));
    }
    for (i = 0; i < m_NumClusters; i++) {
      if (tempI[i].numInstances() == 0) {
        // empty cluster: it gets no centroid and is dropped below
        emptyClusterCount++;
      } else {
        double [] vals = new double[numAttributes];
        for (int j = 0; j < numAttributes; j++) {
          vals[j] = tempI[i].meanOrMode(j);
          m_ClusterNominalCounts[i][j] =
            tempI[i].attributeStats(j).nominalCounts;
        }
        m_ClusterCentroids.add(new Instance(1.0, vals));
      }
    }
    if (emptyClusterCount > 0) {
      m_NumClusters -= emptyClusterCount;
      tempI = new Instances[m_NumClusters];
      // Dropping empty clusters renumbers the remaining ones, so the
      // assignments, nominal counts and squared errors gathered in this pass
      // no longer line up with m_ClusterCentroids, and tempI now holds only
      // nulls. Always run another pass so everything is rebuilt against the
      // compacted centroids. Previously, a pass that "converged" with an
      // empty cluster crashed with a NullPointerException below.
      converged = false;
    }
    if (!converged) {
      m_squaredErrors = new double [m_NumClusters];
      m_ClusterNominalCounts = new int [m_NumClusters][numAttributes][0];
    }
  }

  // ---- per-cluster standard deviations and sizes ----
  m_ClusterStdDevs = new Instances(instances, m_NumClusters);
  m_ClusterSizes = new int [m_NumClusters];
  for (i = 0; i < m_NumClusters; i++) {
    double [] vals2 = new double[numAttributes];
    for (int j = 0; j < numAttributes; j++) {
      if (instances.attribute(j).isNumeric()) {
        vals2[j] = Math.sqrt(tempI[i].variance(j));
      } else {
        vals2[j] = Instance.missingValue();
      }
    }
    m_ClusterStdDevs.add(new Instance(1.0, vals2));
    m_ClusterSizes[i] = tempI[i].numInstances();
  }
}
/**
 * Clusters an instance that has been through the filters, adding the
 * winning distance to {@code m_squaredErrors}. Used while training.
 *
 * @param instance the instance to assign a cluster to
 * @return a cluster number
 */
private int clusterProcessedInstance(Instance instance) {
  return clusterProcessedInstance(instance, true);
}

/**
 * Finds the cluster whose centroid is nearest to an instance that has
 * already been through the filters.
 *
 * @param instance the instance to assign a cluster to
 * @param updateErrors if true, the distance to the chosen centroid is added
 * to that cluster's entry in {@code m_squaredErrors}; this is wanted only
 * while training, so that clustering new data leaves the training statistics
 * untouched
 * @return a cluster number (0 if no centroid is strictly closer than
 * Double.MAX_VALUE, e.g. when every distance is NaN)
 */
private int clusterProcessedInstance(Instance instance, boolean updateErrors) {
  // Double.MAX_VALUE rather than Integer.MAX_VALUE: distances are doubles
  // and are not bounded here, so any distance of 2^31 or more would never
  // beat an int-sized sentinel and the instance would always land in
  // cluster 0 regardless of which centroid is actually nearest.
  double minDist = Double.MAX_VALUE;
  int bestCluster = 0;
  for (int i = 0; i < m_NumClusters; i++) {
    double dist = distance(instance, m_ClusterCentroids.instance(i));
    if (dist < minDist) {
      minDist = dist;
      bestCluster = i;
    }
  }
  if (updateErrors) {
    m_squaredErrors[bestCluster] += minDist;
  }
  return bestCluster;
}

/**
 * Assigns a cluster to a given instance. Missing values are first replaced
 * using the filter fitted during training. The training squared errors are
 * not modified.
 *
 * @param instance the instance to be assigned to a cluster
 * @return the index of the cluster whose centroid is nearest to the instance
 * @exception Exception if instance could not be clustered
 * successfully
 */
public int clusterInstance(Instance instance) throws Exception {
  m_ReplaceMissingFilter.input(instance);
  m_ReplaceMissingFilter.batchFinished();
  Instance inst = m_ReplaceMissingFilter.output();
  return clusterProcessedInstance(inst, false);
}
/**
* Calculates the distance between two instances
*
* @param test the first instance
* @param train the second instance
* @return the distance between the two given instances, between 0 and 1
*/
private double distance(Instance first, Instance second) {
double distance = 0;
int firstI, secondI;
for (int p1 = 0, p2 = 0;
p1 < first.numValues() || p2 < second.numValues();) {
if (p1 >= first.numValues()) {
firstI = m_ClusterCentroids.numAttributes();
} else {
firstI = first.index(p1);
}
if (p2 >= second.numValues()) {
secondI = m_ClusterCentroids.numAttributes();
} else {
secondI = second.index(p2);
}
if (firstI == m_ClusterCentroids.classIndex()) {
p1++; continue;
}
if (secondI == m_ClusterCentroids.classIndex()) {
p2++; continue;
}
double diff;
if (firstI == secondI) {
diff = difference(firstI,
first.valueSparse(p1),
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -