seededkmeans.java

来自「wekaUT是 university texas austin 开发的基于wek」· Java 代码 · 共 2,032 行 · 第 1/5 页

JAVA
2,032
字号
          // NOTE(review): this chunk opens mid-method -- the fragment below is the
          // tail of a centroid-initialization routine (global-centroid computation
          // followed by "randomPerturbInit" seeding); its signature lies on an
          // earlier page of this file. TODO: confirm the enclosing method name
          // against the full source.
          }
          else {
            // non-fast path: per-attribute mean (or mode) over all instances
            for (int j = 0; j < m_Instances.numAttributes(); j++) {
              globalValues[j] = m_Instances.meanOrMode(j); // uses usual meanOrMode
            }
          }
          // global centroid is dense in SPKMeans
          m_GlobalCentroid = new Instance(1.0, globalValues);
          m_GlobalCentroid.setDataset(m_Instances);
          if (m_Algorithm == ALGORITHM_SPHERICAL) {
            // spherical KMeans operates on the unit sphere, so normalize the centroid
            try {
              ((LearnableMetric)m_metric).normalizeInstanceWeighted(m_GlobalCentroid);
            }
            catch (Exception e) {
              // normalization failure is only reported, not propagated
              e.printStackTrace();
            }
          }
          globalCentroidComputed = true;
          if (m_Verbose) {
            System.out.println("Global centroid is: " + m_GlobalCentroid);
          }
        }
        // randomPerturbInit
        if (m_Verbose) {
          System.out.println("RandomPerturbInit seeding for centroid " + i);
        }
        // multiplicative jitter: each attribute of the global centroid is scaled
        // by a factor in [1 - m_DefaultPerturb/2, 1 + m_DefaultPerturb/2)
        for (int j = 0; j < m_Instances.numAttributes(); j++) {
          values[j] = m_GlobalCentroid.value(j) * (1 + m_DefaultPerturb * (random.nextFloat() - 0.5));
        }
      }
      // cluster centroids are dense in SPKMeans
      m_ClusterCentroids.add(new Instance(1.0, values));
      if (m_Algorithm == ALGORITHM_SPHERICAL) {
        try {
          ((LearnableMetric) m_metric).normalizeInstanceWeighted(m_ClusterCentroids.instance(i));
        }
        catch (Exception e) {
          // normalization failure is only reported, not propagated
          e.printStackTrace();
        }
      }
    }
  }

  /** E-step of the KMeans clustering algorithm -- find best cluster assignments.
   *  Every instance is (re)assigned to a cluster: under SEEDING_CONSTRAINED,
   *  instances present in m_SeedHash keep their seed cluster; all other instances
   *  get the cluster chosen by assignClusterToInstance().  As a side effect
   *  m_Objective is accumulated (summed similarity when the objective increases
   *  monotonically, summed squared distance when it decreases) and the number of
   *  instances that changed cluster is printed.
   */
  protected void findBestAssignments() throws Exception{
    m_Objective = 0;
    int moved=0; // count of instances whose assignment changed in this E-step
    for (int i = 0; i < m_Instances.numInstances(); i++) {
      m_currIdx = i;
      Instance inst = m_Instances.instance(i);
      boolean assigned = false;
      // Constrained KMeans algorithm
      if(m_SeedingMethod == SEEDING_CONSTRAINED) {
        if (m_SeedHash == null) {
          // NOTE(review): missing seed info is only reported; the instance then
          // falls through to the unseeded assignment path below
          System.err.println("Needs seed information for constrained SeededKMeans");
        }
        else if(m_SeedHash.containsKey(inst)) { // Seeded instances
          m_ClusterAssignments[i] = ((Integer) m_SeedHash.get(inst)).intValue();
          assigned = true;
          if (m_Verbose) {
            System.out.println("Assigning cluster " + m_ClusterAssignments[i] + " for seed instance " + i + ": " + inst);
          }
        }
      }
      try {
        if (!assigned) { // Unseeded instances
          int newAssignment = assignClusterToInstance(inst);
          if (newAssignment != m_ClusterAssignments[i]) {
            moved++;
            if (m_Verbose) {
              System.out.println("Reassigning instance " + i + " old cluster=" + m_ClusterAssignments[i] + " new cluster=" + newAssignment);
            }
          }
          m_ClusterAssignments[i] = newAssignment;
        }
        // Update objective function
        if (!m_objFunDecreasing) { // objective function increases monotonically
          double newSimilarity = m_metric.similarity(inst, m_ClusterCentroids.instance(m_ClusterAssignments[i]));
          m_Objective += newSimilarity;
        }
        else { // objective function decreases monotonically
          double newDistance = m_metric.distance(inst, m_ClusterCentroids.instance(m_ClusterAssignments[i]));
          m_Objective += newDistance * newDistance;
        }
      }
      catch (Exception e) {
        // metric failures are reported per-instance and the loop continues
        System.out.println("Could not find distance. Exception: " + e);
        e.printStackTrace();
      }
    }
    if(m_Verbose) {
      System.out.println("\nAfter iteration " + m_Iterations + ":\n");
      /*
      for (int k=0; k<m_ClusterCentroids.numInstances(); k++) {
        System.out.println ("  Centroid " + k + " is " + m_ClusterCentroids.instance(k));
      }
      */
    }
    System.out.println("Number of points moved in this E-step: " + moved);
  }

  /** M-step of the KMeans clustering algorithm -- updates cluster centroids.
   *  Rebuilds m_ClusterCentroids from scratch as the mean (or mode) of each
   *  cluster's currently assigned instances.
   */
  protected void updateClusterCentroids() {
    // M-step: update cluster centroids
    Instances [] tempI = new Instances[m_NumClusters];
    m_ClusterCentroids = new Instances(m_Instances, m_NumClusters);
    for (int i = 0; i < m_NumClusters; i++) {
      tempI[i] = new Instances(m_Instances, 0); // tempI[i] stores the cluster instances for cluster i
    }
    // bucket every instance into its assigned cluster
    for (int i = 0; i < m_Instances.numInstances(); i++) {
      tempI[m_ClusterAssignments[i]].add(m_Instances.instance(i));
      if (m_Verbose) {
        System.out.println("Instance " + i + " added to cluster " +
m_ClusterAssignments[i]);
      }
    }
    // Calculates cluster centroids
    for (int i = 0; i < m_NumClusters; i++) {
      double [] values = new double[m_Instances.numAttributes()];
      if (m_FastMode && isSparseInstance) {
        values = meanOrMode(tempI[i]); // uses fast meanOrMode
      }
      else {
        // non-fast path: per-attribute mean (or mode) over the cluster members
        for (int j = 0; j < m_Instances.numAttributes(); j++) {
          values[j] = tempI[i].meanOrMode(j); // uses usual meanOrMode
        }
      }
      // cluster centroids are dense in SPKMeans
      m_ClusterCentroids.add(new Instance(1.0, values));
      if (m_Algorithm == ALGORITHM_SPHERICAL) {
        // keep centroids on the unit sphere for spherical KMeans
        try {
          ((LearnableMetric) m_metric).normalizeInstanceWeighted(m_ClusterCentroids.instance(i));
        }
        catch (Exception e) {
          // normalization failure is only reported, not propagated
          e.printStackTrace();
        }
      }
    }
  }

  /** calculates objective function -- recomputes m_Objective from scratch as the
   *  sum of squared instance-to-centroid distances (decreasing objective) or the
   *  sum of instance-to-centroid similarities (increasing objective)
   */
  protected void calculateObjectiveFunction() throws Exception {
    m_Objective = 0;
    for (int i=0; i<m_Instances.numInstances(); i++) {
      if (m_objFunDecreasing) {
        double dist = m_metric.distance(m_Instances.instance(i), m_ClusterCentroids.instance(m_ClusterAssignments[i]));
        m_Objective += dist*dist;
      }
      else {
        //m_Objective += similarity(i, m_ClusterAssignments[i]);
        m_Objective += m_metric.similarity(m_Instances.instance(i), m_ClusterCentroids.instance(m_ClusterAssignments[i]));
      }
    }
  }

  /**
   * Generates a clusterer. 
Instances in data have to be
   * either all sparse or all non-sparse
   *
   * @param data set of instances serving as training data 
   * @exception Exception if the clusterer has not been 
   * generated successfully
   */
  public void buildClusterer(Instances data) throws Exception {
    setInstances(data);
    // Don't rebuild the metric if it was already trained
    if (!m_metricBuilt) {
      m_metric.buildMetric(data);
    }
    m_ClusterCentroids = new Instances(m_Instances, m_NumClusters);
    m_ClusterAssignments = new int [m_Instances.numInstances()];
    if (m_Verbose && m_SeedHash != null) {
      System.out.println("Using seeding ...");
    }
    // NOTE(review): this guard only fires when the data has BOTH nominal and
    // string attributes, yet the message mentions nominal only -- possibly
    // intended as ||; verify against the original wekaUT sources before changing.
    if (m_Instances.checkForNominalAttributes() && m_Instances.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("Cannot handle nominal attributes\n");
    }
    initializeClusterer(); // Initializes cluster centroids (initial M-step)
    System.out.println("Done initializing clustering ...");
    getIndexClusters();
    printIndexClusters();
    if (m_Verbose) {
      for (int i=0; i<m_NumClusters; i++) {
        System.out.println("Centroid " + i + ": " + m_ClusterCentroids.instance(i));
      }
    }
    boolean converged = false;
    m_Iterations = 0;
    // worst-possible start value so the first iteration always counts as an improvement
    double oldObjective = m_objFunDecreasing ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
    // alternate E- and M-steps until the objective stops changing (or oscillates)
    while (!converged) {
      // E-step: updates m_Objective
      System.out.println("Doing E-step ...");
      findBestAssignments();
      // M-step
      System.out.println("Doing M-step ...");
      updateClusterCentroids();
      m_Iterations++;
      calculateObjectiveFunction();
      // Convergence check
      if(Math.abs(oldObjective - m_Objective) > m_ObjFunConvergenceDifference) {
        // objective moved in the wrong direction by more than the threshold:
        // report oscillation and stop iterating
        if (m_objFunDecreasing ? (oldObjective <  m_Objective) : (oldObjective >  m_Objective)) {
          converged = true;
          System.out.println("\nOSCILLATING, oldObjective=" + oldObjective + " newObjective=" + m_Objective);
          System.out.println("Seeding=" + m_Seedable + " SeedingMethod=" + m_SeedingMethod );
        } else {
          converged = false;
          System.out.println("Objective function is: " + m_Objective);
        }
      }
      else {
        // change below the convergence threshold: done
        converged = true;
        System.out.println("Old Objective function was: " + oldObjective);
        System.out.println("Final Objective function is: " + m_Objective);
      }
      oldObjective = m_Objective;
    }
  }

  /** Active learning via pairwise queries is not supported by this clusterer.
   *  @exception Exception always
   */
  public InstancePair[] bestPairsForActiveLearning(int numActive) throws Exception {
    throw new Exception("Not implemented for SeededKMeans");
  }

  /** Returns the indices of the best numActive instances for active learning.
   *  Phase 1 runs a farthest-first traversal until every cluster has at least
   *  one selected point (plus optional extra runs); phase 2 repeatedly queries
   *  the unvisited point nearest to the centroid of the smallest cluster.
   */
  public int[] bestInstancesForActiveLearning(int numActive) throws Exception{
    int numInstances = m_Instances.numInstances();
    int [] clusterSizes = new int[m_NumClusters];      // selected/seeded points per cluster
    int [] activeLearningPoints = new int[numActive];  // result: chosen query indices
    int [] clusterAssignments = new int[numInstances]; // local assignment copy; -1 = unlabeled
    Instance [] sumOfClusterInstances = new Instance[m_NumClusters]; // running per-cluster instance sums
    HashSet visitedPoints = new HashSet(numInstances); // indices already seeded or queried
    boolean allClustersFound = false;
    int numPointsSelected = 0;
    // initialize clusterAssignments, clusterSizes, visitedPoints, sumOfClusterInstances
    // from the seeded (labeled) instances in m_SeedHash
    for (int i=0; i<numInstances; i++) {
      Instance inst = m_Instances.instance(i);
      if (m_SeedHash != null && m_SeedHash.containsKey(inst)) {
        clusterAssignments[i] = ((Integer) m_SeedHash.get(inst)).intValue();
        clusterSizes[clusterAssignments[i]]++;
        visitedPoints.add(new Integer(i));
        sumOfClusterInstances[clusterAssignments[i]] = sumWithInstance(sumOfClusterInstances[clusterAssignments[i]], inst);
        if (m_Verbose) {
          // (trace output disabled)
          //	  System.out.println("Init: adding point " + i + " to cluster " + clusterAssignments[i]);
        }
      }
      else {
        clusterAssignments[i] = -1;
      }
    }
    // set allClustersFound
    allClustersFound = setAllClustersFound(clusterSizes);
    int totalPointsSpecified=0;
    for (int i=0; i<m_NumClusters; i++) {
      totalPointsSpecified += clusterSizes[i]; // HACK!!!
    }
    // NOTE(review): despite its name, m_ExtraPhase1RunFraction is used here as an
    // absolute point count (and below as a loop bound), not a fraction -- the
    // original author already flagged this with "HACK!!!"; confirm intent.
    System.out.println("Total points specified: " + totalPointsSpecified + ", limit: " + m_ExtraPhase1RunFraction);
    if (totalPointsSpecified < m_ExtraPhase1RunFraction) {
      allClustersFound = false; // force more Phase-1 exploration
    }
    while (numPointsSelected < numActive) {
      if (!allClustersFound) { // PHASE 1
        System.out.println("In Phase 1");
        // find next point, farthest from visited points
        int nextPoint = farthestFromSet(visitedPoints, null);
        if (nextPoint >= m_StartingIndexOfTest) {
          throw new Exception ("Test point " + nextPoint + " selected, something went wrong -- starting index of test is: " + m_StartingIndexOfTest);
        }
        visitedPoints.add(new Integer(nextPoint));
        activeLearningPoints[numPointsSelected] = nextPoint;
        numPointsSelected++;
        // update cluster stats for this point; the oracle answer is the point's
        // true class label from m_TotalTrainWithLabels
        int classLabel = (int) m_TotalTrainWithLabels.instance(nextPoint).classValue();
        clusterAssignments[nextPoint] = classLabel;
        clusterSizes[classLabel]++;
        sumOfClusterInstances[classLabel] = sumWithInstance(sumOfClusterInstances[classLabel], m_Instances.instance(nextPoint));
        // set allClustersFound
        //	if (m_Verbose) {
        System.out.println("Active learning point number: " + numPointsSelected + " is: " + nextPoint + ", with class label: " + classLabel);
        //	}
        allClustersFound = setAllClustersFound(clusterSizes);
        if (numPointsSelected >= numActive) {
          System.out.println("Out of queries before phase 1 extra loop. Queries so far: " + numPointsSelected);
          return activeLearningPoints; // go out of function
        }
        if (allClustersFound) {
          // Extra RUNS OF PHASE 1: repeat farthest-first traversal, excluding
          // already-labeled/visited points, until the run budget is exhausted
          int [] tempClusterSizes = new int[m_NumClusters]; // temp cluster sizes
          boolean tempAllClustersFound = false;
          HashSet points = new HashSet(numInstances); // points visited in this farthest first loop
          points.add(new Integer(nextPoint)); // mark only last point as visited
          tempClusterSizes[classLabel]++; // update temp cluster sizes for this point
          HashSet eliminationSet = new HashSet(numInstances); // don't include these points in farthest first search
          for (int i=0; i<numInstances; i++) {
            Instance inst = m_Instances.instance(i);
            if (m_SeedHash != null && m_SeedHash.containsKey(inst)) {
              eliminationSet.add(new Integer(i)); // add labeled data to elimination set
            }
          }
          Iterator iter = visitedPoints.iterator();
          while(iter.hasNext()) {
            eliminationSet.add(iter.next()); // add already visited points to elim set
          }
          for (int i=0; i<m_ExtraPhase1RunFraction; i++) {
            System.out.println("Continuing Phase 1 run: " + i + " after all clusters visited");
            // find next point, farthest from points, eliminating points in eliminationSet
            nextPoint = farthestFromSet(points, eliminationSet);
            if (nextPoint >= m_StartingIndexOfTest) {
              throw new Exception ("Test point " + nextPoint + " selected, something went wrong -- starting index of test is: " + m_StartingIndexOfTest);
            }
            visitedPoints.add(new Integer(nextPoint)); // add to total set of visited points
            points.add(new Integer(nextPoint)); // add to points visited in this farthest first loop
            activeLearningPoints[numPointsSelected] = nextPoint;
            numPointsSelected++;
            // update cluster stats for this point
            classLabel = (int) m_TotalTrainWithLabels.instance(nextPoint).classValue();
            clusterAssignments[nextPoint] = classLabel;
            clusterSizes[classLabel]++;
            sumOfClusterInstances[classLabel] = sumWithInstance(sumOfClusterInstances[classLabel], m_Instances.instance(nextPoint));
            tempClusterSizes[classLabel]++;
            //	if (m_Verbose) {
            System.out.println("Active learning point number: " + numPointsSelected + " is: " + nextPoint + ", with class label: " + classLabel);
            //	}
            tempAllClustersFound = setAllClustersFound(tempClusterSizes);
            if (tempAllClustersFound) { // found all clusters, reset local variables
              System.out.println("Resetting variables for next round of farthest first");
              tempClusterSizes = new int[m_NumClusters];
              tempAllClustersFound = false;
              Iterator tempIter = points.iterator();
              while(tempIter.hasNext()) {
                eliminationSet.add((Integer) tempIter.next()); // add already visited points to elim set
              }
              points.clear(); // clear current set
              points.add(new Integer(nextPoint)); // add the last point
              tempClusterSizes[classLabel]++; // for the last point
            }
            if (numPointsSelected >= numActive) {
              System.out.println("Out of queries within phase 1 extra loop. Queries so far: " + numPointsSelected);
              return activeLearningPoints; // go out of function
            }
          }
        }
      }
      else { // PHASE 2
        // find smallest cluster
        System.out.println("In Phase 2");
        int smallestSize = Integer.MAX_VALUE, smallestCluster = -1;
        for (int i=0; i<m_NumClusters; i++) {
          if (clusterSizes[i] < smallestSize) {
            smallestSize = clusterSizes[i];
            smallestCluster = i;
          }
        }
        if (m_Verbose) {
          System.out.println("Smallest cluster now: " + smallestCluster + ", with size: " + smallestSize);
        }
        // compute centroid of smallest cluster from its running instance sum
        Instance centroidOfSmallestCluster;
        if (isSparseInstance) {
          centroidOfSmallestCluster = new SparseInstance(sumOfClusterInstances[smallestCluster]);
        }
        else {
          centroidOfSmallestCluster = new Instance(sumOfClusterInstances[smallestCluster]);
        }
        centroidOfSmallestCluster.setDataset(m_Instances);
        if (!m_objFunDecreasing) {
          normalize(centroidOfSmallestCluster);
        }
        else {
          normalizeByWeight(centroidOfSmallestCluster);
        }
        // find next point, closest to centroid of smallest cluster
        int nextPoint = nearestFromPoint(centroidOfSmallestCluster, visitedPoints);
        if (nextPoint >= m_StartingIndexOfTest) {
          throw new Exception ("Test point selected, something went wrong!");
        }
        visitedPoints.add(new Integer(nextPoint));
        activeLearningPoints[numPointsSelected] = nextPoint;
        numPointsSelected++;
        // update cluster stats for this point
        int classLabel = (int) m_TotalTrainWithLabels.instance(nextPoint).classValue();
        clusterAssignments[nextPoint] = classLabel;
        clusterSizes[classLabel]++;
        sumOfClusterInstances[classLabel] = sumWithInstance(sumOfClusterInstances[classLabel], m_Instances.instance(nextPoint));
        // NOTE(review): method continues on the next page of this extraction.

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?