seededkmeans.java
来自「wekaUT是 university texas austin 开发的基于wek」· Java 代码 · 共 2,032 行 · 第 1/5 页
JAVA
2,032 行
} else { for (int j = 0; j < m_Instances.numAttributes(); j++) { globalValues[j] = m_Instances.meanOrMode(j); // uses usual meanOrMode } } // global centroid is dense in SPKMeans m_GlobalCentroid = new Instance(1.0, globalValues); m_GlobalCentroid.setDataset(m_Instances); if (m_Algorithm == ALGORITHM_SPHERICAL) { try { ((LearnableMetric)m_metric).normalizeInstanceWeighted(m_GlobalCentroid); } catch (Exception e) { e.printStackTrace(); } } globalCentroidComputed = true; if (m_Verbose) { System.out.println("Global centroid is: " + m_GlobalCentroid); } } // randomPerturbInit if (m_Verbose) { System.out.println("RandomPerturbInit seeding for centroid " + i); } for (int j = 0; j < m_Instances.numAttributes(); j++) { values[j] = m_GlobalCentroid.value(j) * (1 + m_DefaultPerturb * (random.nextFloat() - 0.5)); } } // cluster centroids are dense in SPKMeans m_ClusterCentroids.add(new Instance(1.0, values)); if (m_Algorithm == ALGORITHM_SPHERICAL) { try { ((LearnableMetric) m_metric).normalizeInstanceWeighted(m_ClusterCentroids.instance(i)); } catch (Exception e) { e.printStackTrace(); } } } } /** E-step of the KMeans clustering algorithm -- find best cluster assignments */ protected void findBestAssignments() throws Exception{ m_Objective = 0; int moved=0; for (int i = 0; i < m_Instances.numInstances(); i++) { m_currIdx = i; Instance inst = m_Instances.instance(i); boolean assigned = false; // Constrained KMeans algorithm if(m_SeedingMethod == SEEDING_CONSTRAINED) { if (m_SeedHash == null) { System.err.println("Needs seed information for constrained SeededKMeans"); } else if(m_SeedHash.containsKey(inst)) { // Seeded instances m_ClusterAssignments[i] = ((Integer) m_SeedHash.get(inst)).intValue(); assigned = true; if (m_Verbose) { System.out.println("Assigning cluster " + m_ClusterAssignments[i] + " for seed instance " + i + ": " + inst); } } } try { if (!assigned) { // Unseeded instances int newAssignment = assignClusterToInstance(inst); if (newAssignment != m_ClusterAssignments[i]) { moved++; if (m_Verbose) { System.out.println("Reassigning instance " + i + " old cluster=" + m_ClusterAssignments[i] + " new cluster=" + newAssignment); } } m_ClusterAssignments[i] = newAssignment; } // Update objective function if (!m_objFunDecreasing) { // objective function increases monotonically double newSimilarity = m_metric.similarity(inst, m_ClusterCentroids.instance(m_ClusterAssignments[i])); m_Objective += newSimilarity; } else { // objective function decreases monotonically double newDistance = m_metric.distance(inst, m_ClusterCentroids.instance(m_ClusterAssignments[i])); m_Objective += newDistance * newDistance; } } catch (Exception e) { System.out.println("Could not find distance. Exception: " + e); e.printStackTrace(); } } if(m_Verbose) { System.out.println("\nAfter iteration " + m_Iterations + ":\n"); /* for (int k=0; k<m_ClusterCentroids.numInstances(); k++) { System.out.println (" Centroid " + k + " is " + m_ClusterCentroids.instance(k)); } */ } System.out.println("Number of points moved in this E-step: " + moved); } /** M-step of the KMeans clustering algorithm -- updates cluster centroids */ protected void updateClusterCentroids() { // M-step: update cluster centroids Instances [] tempI = new Instances[m_NumClusters]; m_ClusterCentroids = new Instances(m_Instances, m_NumClusters); for (int i = 0; i < m_NumClusters; i++) { tempI[i] = new Instances(m_Instances, 0); // tempI[i] stores the cluster instances for cluster i } for (int i = 0; i < m_Instances.numInstances(); i++) { tempI[m_ClusterAssignments[i]].add(m_Instances.instance(i)); if (m_Verbose) { System.out.println("Instance " + i + " added to cluster " + m_ClusterAssignments[i]); } } // Calculates cluster centroids for (int i = 0; i < m_NumClusters; i++) { double [] values = new double[m_Instances.numAttributes()]; if (m_FastMode && isSparseInstance) { values = meanOrMode(tempI[i]); // uses fast meanOrMode } else { for (int j = 0; j < m_Instances.numAttributes(); j++) { values[j] = tempI[i].meanOrMode(j); // uses usual meanOrMode } } // cluster centroids are dense in SPKMeans m_ClusterCentroids.add(new Instance(1.0, values)); if (m_Algorithm == ALGORITHM_SPHERICAL) { try { ((LearnableMetric) m_metric).normalizeInstanceWeighted(m_ClusterCentroids.instance(i)); } catch (Exception e) { e.printStackTrace(); } } } } /** calculates objective function */ protected void calculateObjectiveFunction() throws Exception { m_Objective = 0; for (int i=0; i<m_Instances.numInstances(); i++) { if (m_objFunDecreasing) { double dist = m_metric.distance(m_Instances.instance(i), m_ClusterCentroids.instance(m_ClusterAssignments[i])); m_Objective += dist*dist; } else { //m_Objective += similarity(i, m_ClusterAssignments[i]); m_Objective += m_metric.similarity(m_Instances.instance(i), m_ClusterCentroids.instance(m_ClusterAssignments[i])); } } } /** * Generates a clusterer. Instances in data have to be * either all sparse or all non-sparse * * @param data set of instances serving as training data * @exception Exception if the clusterer has not been * generated successfully */ public void buildClusterer(Instances data) throws Exception { setInstances(data); // Don't rebuild the metric if it was already trained if (!m_metricBuilt) { m_metric.buildMetric(data); } m_ClusterCentroids = new Instances(m_Instances, m_NumClusters); m_ClusterAssignments = new int [m_Instances.numInstances()]; if (m_Verbose && m_SeedHash != null) { System.out.println("Using seeding ..."); } if (m_Instances.checkForNominalAttributes() && m_Instances.checkForStringAttributes()) { throw new UnsupportedAttributeTypeException("Cannot handle nominal attributes\n"); } initializeClusterer(); // Initializes cluster centroids (initial M-step) System.out.println("Done initializing clustering ..."); getIndexClusters(); printIndexClusters(); if (m_Verbose) { for (int i=0; i<m_NumClusters; i++) { System.out.println("Centroid " + i + ": " + m_ClusterCentroids.instance(i)); } } boolean converged = false; m_Iterations = 0; double oldObjective = m_objFunDecreasing ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; while (!converged) { // E-step: updates m_Objective System.out.println("Doing E-step ..."); findBestAssignments(); // M-step System.out.println("Doing M-step ..."); updateClusterCentroids(); m_Iterations++; calculateObjectiveFunction(); // Convergence check if(Math.abs(oldObjective - m_Objective) > m_ObjFunConvergenceDifference) { if (m_objFunDecreasing ? (oldObjective < m_Objective) : (oldObjective > m_Objective)) { converged = true; System.out.println("\nOSCILLATING, oldObjective=" + oldObjective + " newObjective=" + m_Objective); System.out.println("Seeding=" + m_Seedable + " SeedingMethod=" + m_SeedingMethod ); } else { converged = false; System.out.println("Objective function is: " + m_Objective); } } else { converged = true; System.out.println("Old Objective function was: " + oldObjective); System.out.println("Final Objective function is: " + m_Objective); } oldObjective = m_Objective; } } public InstancePair[] bestPairsForActiveLearning(int numActive) throws Exception { throw new Exception("Not implemented for SeededKMeans"); } /** Returns the indices of the best numActive instances for active learning */ public int[] bestInstancesForActiveLearning(int numActive) throws Exception{ int numInstances = m_Instances.numInstances(); int [] clusterSizes = new int[m_NumClusters]; int [] activeLearningPoints = new int[numActive]; int [] clusterAssignments = new int[numInstances]; Instance [] sumOfClusterInstances = new Instance[m_NumClusters]; HashSet visitedPoints = new HashSet(numInstances); boolean allClustersFound = false; int numPointsSelected = 0; // initialize clusterAssignments, clusterSizes, visitedPoints, sumOfClusterInstances for (int i=0; i<numInstances; i++) { Instance inst = m_Instances.instance(i); if (m_SeedHash != null && m_SeedHash.containsKey(inst)) { clusterAssignments[i] = ((Integer) m_SeedHash.get(inst)).intValue(); clusterSizes[clusterAssignments[i]]++; visitedPoints.add(new Integer(i)); sumOfClusterInstances[clusterAssignments[i]] = sumWithInstance(sumOfClusterInstances[clusterAssignments[i]], inst); if (m_Verbose) { // System.out.println("Init: adding point " + i + " to cluster " + clusterAssignments[i]); } } else { clusterAssignments[i] = -1; } } // set allClustersFound allClustersFound = setAllClustersFound(clusterSizes); int totalPointsSpecified=0; for (int i=0; i<m_NumClusters; i++) { totalPointsSpecified += clusterSizes[i]; // HACK!!! } System.out.println("Total points specified: " + totalPointsSpecified + ", limit: " + m_ExtraPhase1RunFraction); if (totalPointsSpecified < m_ExtraPhase1RunFraction) { allClustersFound = false; } while (numPointsSelected < numActive) { if (!allClustersFound) { // PHASE 1 System.out.println("In Phase 1"); // find next point, farthest from visited points int nextPoint = farthestFromSet(visitedPoints, null); if (nextPoint >= m_StartingIndexOfTest) { throw new Exception ("Test point " + nextPoint + " selected, something went wrong -- starting index of test is: " + m_StartingIndexOfTest); } visitedPoints.add(new Integer(nextPoint)); activeLearningPoints[numPointsSelected] = nextPoint; numPointsSelected++; // update cluster stats for this point int classLabel = (int) m_TotalTrainWithLabels.instance(nextPoint).classValue(); clusterAssignments[nextPoint] = classLabel; clusterSizes[classLabel]++; sumOfClusterInstances[classLabel] = sumWithInstance(sumOfClusterInstances[classLabel], m_Instances.instance(nextPoint)); // set allClustersFound // if (m_Verbose) { System.out.println("Active learning point number: " + numPointsSelected + " is: " + nextPoint + ", with class label: " + classLabel); // } allClustersFound = setAllClustersFound(clusterSizes); if (numPointsSelected >= numActive) { System.out.println("Out of queries before phase 1 extra loop. Queries so far: " + numPointsSelected); return activeLearningPoints; // go out of function } if (allClustersFound) { // Extra RUNS OF PHASE 1 int [] tempClusterSizes = new int[m_NumClusters]; // temp cluster sizes boolean tempAllClustersFound = false; HashSet points = new HashSet(numInstances); // points visited in this farthest first loop points.add(new Integer(nextPoint)); // mark only last point as visited tempClusterSizes[classLabel]++; // update temp cluster sizes for this point HashSet eliminationSet = new HashSet(numInstances); // don't include these points in farthest first search for (int i=0; i<numInstances; i++) { Instance inst = m_Instances.instance(i); if (m_SeedHash != null && m_SeedHash.containsKey(inst)) { eliminationSet.add(new Integer(i)); // add labeled data to elimination set } } Iterator iter = visitedPoints.iterator(); while(iter.hasNext()) { eliminationSet.add(iter.next()); // add already visited points to elim set } for (int i=0; i<m_ExtraPhase1RunFraction; i++) { System.out.println("Continuing Phase 1 run: " + i + " after all clusters visited"); // find next point, farthest from points, eliminating points in eliminationSet nextPoint = farthestFromSet(points, eliminationSet); if (nextPoint >= m_StartingIndexOfTest) { throw new Exception ("Test point " + nextPoint + " selected, something went wrong -- starting index of test is: " + m_StartingIndexOfTest); } visitedPoints.add(new Integer(nextPoint)); // add to total set of visited points points.add(new Integer(nextPoint)); // add to points visited in this farthest first loop activeLearningPoints[numPointsSelected] = nextPoint; numPointsSelected++; // update cluster stats for this point classLabel = (int) m_TotalTrainWithLabels.instance(nextPoint).classValue(); clusterAssignments[nextPoint] = classLabel; clusterSizes[classLabel]++; sumOfClusterInstances[classLabel] = sumWithInstance(sumOfClusterInstances[classLabel], m_Instances.instance(nextPoint)); tempClusterSizes[classLabel]++; // if (m_Verbose) { System.out.println("Active learning point number: " + numPointsSelected + " is: " + nextPoint + ", with class label: " + classLabel); // } tempAllClustersFound = setAllClustersFound(tempClusterSizes); if (tempAllClustersFound) { // found all clusters, reset local variables System.out.println("Resetting variables for next round of farthest first"); tempClusterSizes = new int[m_NumClusters]; tempAllClustersFound = false; Iterator tempIter = points.iterator(); while(tempIter.hasNext()) { eliminationSet.add((Integer) tempIter.next()); // add already visited points to elim set } points.clear(); // clear current set points.add(new Integer(nextPoint)); // add the last point tempClusterSizes[classLabel]++; // for the last point } if (numPointsSelected >= numActive) { System.out.println("Out of queries within phase 1 extra loop. Queries so far: " + numPointsSelected); return activeLearningPoints; // go out of function } } } } else { // PHASE 2 // find smallest cluster System.out.println("In Phase 2"); int smallestSize = Integer.MAX_VALUE, smallestCluster = -1; for (int i=0; i<m_NumClusters; i++) { if (clusterSizes[i] < smallestSize) { smallestSize = clusterSizes[i]; smallestCluster = i; } } if (m_Verbose) { System.out.println("Smallest cluster now: " + smallestCluster + ", with size: " + smallestSize); } // compute centroid of smallest cluster Instance centroidOfSmallestCluster; if (isSparseInstance) { centroidOfSmallestCluster = new SparseInstance(sumOfClusterInstances[smallestCluster]); } else { centroidOfSmallestCluster = new Instance(sumOfClusterInstances[smallestCluster]); } centroidOfSmallestCluster.setDataset(m_Instances); if (!m_objFunDecreasing) { normalize(centroidOfSmallestCluster); } else { normalizeByWeight(centroidOfSmallestCluster); } // find next point, closest to centroid of smallest cluster int nextPoint = nearestFromPoint(centroidOfSmallestCluster, visitedPoints); if (nextPoint >= m_StartingIndexOfTest) { throw new Exception ("Test point selected, something went wrong!"); } visitedPoints.add(new Integer(nextPoint)); activeLearningPoints[numPointsSelected] = nextPoint; numPointsSelected++; // update cluster stats for this point int classLabel = (int) m_TotalTrainWithLabels.instance(nextPoint).classValue(); clusterAssignments[nextPoint] = classLabel; clusterSizes[classLabel]++; sumOfClusterInstances[classLabel] = sumWithInstance(sumOfClusterInstances[classLabel], m_Instances.instance(nextPoint));
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?