📄 kmeansclusterer.java
字号:
String msg = "Number of iterations must be non-negative."
                + " Found maxIterations=" + maxIterations;
            throw new IllegalArgumentException(msg);
        }
        mFeatureExtractor = featureExtractor;
        mNumClusters = numClusters;
        mMaxIterations = maxIterations;
    }

    /**
     * Returns the feature extractor for this clusterer.
     *
     * @return The feature extractor for this clusterer.
     */
    public FeatureExtractor<E> featureExtractor() {
        return mFeatureExtractor;
    }

    /**
     * Returns the number of clusters this clusterer will return.
     * This is the "<code>k</code>" in "k-means".
     *
     * @return The number of clusters this clusterer will return.
     */
    public int numClusters() {
        return mNumClusters;
    }

    /**
     * Recluster the specified clustering using up to the specified
     * number of k-means iterations.  This method allows users to
     * specify their own initial clusterings, which are then reallocated
     * using the standard k-means algorithm.
     *
     * <p>The number of clusters used while reclustering is the number
     * of clusters in the supplied clustering, which need not equal
     * {@link #numClusters()}.  Clusters that end up empty are dropped
     * from the result.
     *
     * @param clustering Clustering to recluster.
     * @param maxIterations Maximum number of reclustering iterations;
     * a non-positive value returns the input partition (as new sets,
     * minus empty clusters) without reallocation.
     * @return New clustering of input elements.
     */
    @SuppressWarnings("unchecked")
    public Set<Set<E>> recluster(Set<Set<E>> clustering, int maxIterations) {
        int numElements = numElements(clustering);
        int numInputClusters = clustering.size();
        Object[] elements = new Object[numElements];
        Map<String,? extends Number>[] featureMaps
            = (Map<String,? extends Number>[]) new Map[numElements];
        // one centroid / member list per *input* cluster, not per mNumClusters
        ObjectToDoubleMap<String>[] centroids = createCentroids(numInputClusters);
        List<Integer>[] clusterElements = createClusterElements(numInputClusters);
        int eltIndex = 0;
        int clusterIndex = 0;
        for (Set<E> cluster : clustering) {
            for (E e : cluster) {
                elements[eltIndex] = e;
                featureMaps[eltIndex] = mFeatureExtractor.features(e);
                // centroids hold unscaled feature sums here; clusterIterations
                // divides them by cluster size at the start of each iteration
                add(centroids[clusterIndex], featureMaps[eltIndex]);
                clusterElements[clusterIndex].add(Integer.valueOf(eltIndex));
                ++eltIndex;
            }
            ++clusterIndex;
        }
        return clusterIterations(centroids, clusterElements, elements,
                                 featureMaps, maxIterations);
    }

    /**
     * Returns the sum of the sizes of the clusters in the specified
     * clustering.
     *
     * @param clustering Clustering whose elements are counted.
     * @return Total number of elements over all clusters.
     */
    int numElements(Set<Set<E>> clustering) {
        int count = 0;
        for (Set<E> cluster : clustering)
            count += cluster.size();
        return count;
    }

    /**
     * Return a k-means clustering of the specified set of elements.
     * Note that this method is randomized and may return different
     * results over different runs.  See the class documentation for
     * more details.
     *
     * <p>If there are no more elements than clusters, each element is
     * returned in its own singleton cluster.
     *
     * @param elementSet Set of elements to cluster.
     * @return Clustering of the specified elements.
     */
    @SuppressWarnings("unchecked")
    public Set<Set<E>> cluster(Set<? extends E> elementSet) {
        // small input: every element becomes its own cluster
        if (elementSet.size() <= mNumClusters) {
            Set<Set<E>> clustering = new HashSet<Set<E>>((3 * elementSet.size()) / 2);
            for (E elt : elementSet) {
                Set<E> cluster = SmallSet.<E>create(elt);
                clustering.add(cluster);
            }
            return clustering;
        }

        // shuffle the elements so the round-robin initial assignment
        // below is random
        Object[] elements = new Object[elementSet.size()];
        elementSet.toArray(elements);
        Arrays.<Object>permute(elements);

        // parallel array of extracted features
        Map<String,? extends Number>[] featureMaps
            = (Map<String,? extends Number>[]) new Map[elements.length];
        for (int i = 0; i < featureMaps.length; ++i)
            featureMaps[i] = mFeatureExtractor.features((E) elements[i]);

        // initial clusters: element i goes to cluster i mod k
        ObjectToDoubleMap<String>[] centroids = createCentroids();
        List<Integer>[] clusterElements = createClusterElements();
        for (int i = 0; i < elements.length; ++i) {
            int clusterId = i % mNumClusters;
            add(centroids[clusterId], featureMaps[i]);
            clusterElements[clusterId].add(Integer.valueOf(i));
        }
        return clusterIterations(centroids, clusterElements,
                                 elements, featureMaps, mMaxIterations);
    }

    /**
     * Runs k-means reallocation until no element changes cluster or the
     * iteration limit is reached, and returns the resulting non-empty
     * clusters.
     *
     * <p>The number of clusters is taken from {@code centroids.length}
     * (which must equal {@code clusterElements.length}), so callers may
     * supply any number of initial clusters.
     *
     * @param centroids Unscaled centroids (sums of member feature maps).
     * @param clusterElements Indices into {@code elements} for each cluster.
     * @param elements Elements being clustered.
     * @param featureMaps Feature maps parallel to {@code elements}.
     * @param maxIterations Maximum number of reallocation iterations.
     * @return Clustering of the elements, excluding empty clusters.
     */
    @SuppressWarnings("unchecked")
    Set<Set<E>> clusterIterations(ObjectToDoubleMap<String>[] centroids,
                                  List<Integer>[] clusterElements,
                                  Object[] elements,
                                  Map<String,? extends Number>[] featureMaps,
                                  int maxIterations) {
        // k comes from the arrays, not mNumClusters, so recluster() works
        // for clusterings whose size differs from numClusters()
        int numClusters = centroids.length;
        for (int iteration = 0; iteration < maxIterations; ++iteration) {
            scale(centroids, clusterElements); // always unscaled coming in
            ObjectToDoubleMap<String>[] nextCentroids = createCentroids(numClusters);
            List<Integer>[] nextClusterElements = createClusterElements(numClusters);
            boolean fixed = true;
            for (int i = 0; i < numClusters; ++i) {
                for (Integer eltIndexInt : clusterElements[i]) {
                    int eltIndex = eltIndexInt.intValue();
                    int closestIndex = closestCentroid(centroids, featureMaps[eltIndex]);
                    add(nextCentroids[closestIndex], featureMaps[eltIndex]);
                    nextClusterElements[closestIndex].add(eltIndexInt);
                    if (closestIndex != i) fixed = false;
                }
            }
            if (fixed) break; // no element moved: current assignment is final
            centroids = nextCentroids;
            clusterElements = nextClusterElements;
        }
        Set<Set<E>> clustering = new HashSet<Set<E>>((3 * numClusters) / 2);
        for (int i = 0; i < numClusters; ++i) {
            HashSet<E> cluster = new HashSet<E>();
            for (Integer k : clusterElements[i])
                cluster.add((E) elements[k.intValue()]);
            if (cluster.size() > 0)
                clustering.add(cluster);
        }
        return clustering;
    }

    /**
     * Returns the index of the centroid closest to the feature map,
     * preferring the lowest index on ties.  Returns 0 when no distance
     * is strictly below positive infinity (e.g. all distances are NaN).
     */
    private static int closestCentroid(ObjectToDoubleMap<String>[] centroids,
                                       Map<String,? extends Number> featureMap) {
        double closestDistance = Double.POSITIVE_INFINITY;
        int closestIndex = 0; // fallback; could alternatively assign at random
        for (int j = 0; j < centroids.length; ++j) {
            double distance = euclideanDistance(centroids[j], featureMap);
            if (distance < closestDistance) {
                closestDistance = distance;
                closestIndex = j;
            }
        }
        return closestIndex;
    }

    /**
     * Divides each centroid's values by the size of the corresponding
     * cluster, turning feature sums into means.  Empty clusters are
     * skipped to avoid dividing by zero.
     *
     * @param centroids Centroids to scale in place.
     * @param nextClusterElements Member lists parallel to {@code centroids}.
     */
    void scale(ObjectToDoubleMap<String>[] centroids,
               List<Integer>[] nextClusterElements) {
        for (int i = 0; i < centroids.length; ++i) {
            double numElts = nextClusterElements[i].size();
            if (numElts == 0.0) continue;
            ObjectToDoubleMap<String> centroid = centroids[i];
            for (String s : centroid.keySet())
                centroid.set(s, centroid.getValue(s) / numElts);
        }
    }

    /** Debug helper: prints a feature map and its closest centroid index. */
    void printFeaturesClosest(Map<String,? extends Number> featureMap,
                              int closestIndex) {
        System.out.println(" features="
                           + featureMap.toString().trim().replaceAll("\n",", "));
        System.out.println(" closest centroid=" + closestIndex);
    }

    /** Debug helper: prints every centroid with its index. */
    void printCentroids(ObjectToDoubleMap<String>[] centroids) {
        System.out.println("\nCentroids");
        for (int i = 0; i < centroids.length; ++i)
            System.out.println("  " + i + " " + centroids[i]);
    }

    /**
     * No-op: the body is empty.  NOTE(review): not called from the visible
     * code; kept for compatibility with any package-level callers.
     */
    void scale(ObjectToDoubleMap<String> centroid, double scalar) {
    }

    /** Returns {@code mNumClusters} fresh, empty member lists. */
    List<Integer>[] createClusterElements() {
        return createClusterElements(mNumClusters);
    }

    /** Returns {@code numClusters} fresh, empty member lists. */
    @SuppressWarnings("unchecked")
    List<Integer>[] createClusterElements(int numClusters) {
        List<Integer>[] clusterElements = (List<Integer>[]) new List[numClusters];
        for (int i = 0; i < clusterElements.length; ++i)
            clusterElements[i] = new ArrayList<Integer>();
        return clusterElements;
    }

    /** Returns {@code mNumClusters} fresh, empty centroids. */
    ObjectToDoubleMap<String>[] createCentroids() {
        return createCentroids(mNumClusters);
    }

    /** Returns {@code numClusters} fresh, empty centroids. */
    @SuppressWarnings("unchecked")
    ObjectToDoubleMap<String>[] createCentroids(int numClusters) {
        ObjectToDoubleMap<String>[] centroids
            = (ObjectToDoubleMap<String>[]) new ObjectToDoubleMap[numClusters];
        for (int i = 0; i < centroids.length; ++i)
            centroids[i] = new ObjectToDoubleMap<String>();
        return centroids;
    }

    /**
     * Adds each feature value into the centroid under the same key
     * (via {@code increment}).
     */
    static void add(ObjectToDoubleMap<String> centroid,
                    Map<String,? extends Number> featureMap) {
        for (Map.Entry<String,? extends Number> entry : featureMap.entrySet())
            centroid.increment(entry.getKey(), entry.getValue().doubleValue());
    }

    /**
     * Returns the <em>squared</em> Euclidean distance between the centroid
     * and the feature map.  Keys present only in the centroid contribute
     * their squared centroid value; for keys present only in the feature
     * map this assumes {@code centroid.getValue} yields 0.0 (TODO confirm).
     * The square root is omitted since it is monotonic and the visible
     * caller only compares distances.
     */
    static double euclideanDistance(ObjectToDoubleMap<String> centroid,
                                    Map<String,? extends Number> featureMap) {
        double sqDist = 0.0;
        for (Map.Entry<String,? extends Number> featureEntry : featureMap.entrySet()) {
            double diff = featureEntry.getValue().doubleValue()
                - centroid.getValue(featureEntry.getKey());
            sqDist += diff * diff;
        }
        for (Map.Entry<String,Double> centroidEntry : centroid.entrySet()) {
            if (featureMap.containsKey(centroidEntry.getKey())) continue;
            double diff = centroidEntry.getValue().doubleValue();
            sqDist += diff * diff;
        }
        return sqDist; // monotonically related to dist = Math.sqrt(sqDist)
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -