⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 .#mpckmeans.java.1.110

📁 wekaUT 是 University of Texas at Austin 开发的基于 weka 的半监督学习（semi-supervised learning）分类器
💻 110
📖 第 1 页 / 共 5 页
字号:
    // NOTE(review): this chunk was scraped with newlines stripped; line structure
    // below is reconstructed token-for-token.  The first lines are the TAIL of a
    // method whose signature lies before this chunk — presumably
    // penaltyForInstance(int instIdx, int centroidIdx) (it is called with that
    // shape from calculateObjectiveFunction below) — TODO confirm against the full file.
      m_objNormalizerCurrPoint = -m_logTerms[centroidIdx];
    }
    // only add the constraints if seedable or constrained
    if (m_Seedable || (m_Trainable != TRAINING_NONE)) {
      Object list =  m_instanceConstraintHash.get(new Integer(instIdx));
      if (list != null) {   // there are constraints associated with this instance
        ArrayList constraintList = (ArrayList) list;
        for (int i = 0; i < constraintList.size(); i++) {
          InstancePair pair = (InstancePair) constraintList.get(i);
          int firstIdx = pair.first;
          int secondIdx = pair.second;
          Instance instance1 = m_Instances.instance(firstIdx);
          Instance instance2 = m_Instances.instance(secondIdx);
          // the constrained point that is NOT instIdx; -1 below means "not yet assigned"
          int otherIdx = (firstIdx == instIdx) ?
            m_ClusterAssignments[secondIdx] : m_ClusterAssignments[firstIdx];

          // check whether the constraint is violated
          if (otherIdx != -1 && otherIdx < m_NumClusters) {
            if (otherIdx != centroidIdx &&
                pair.linkType == InstancePair.MUST_LINK) {
              // split penalty in half between the two involved clusters
              if (m_useMultipleMetrics) {
                double penalty1 = m_metrics[otherIdx].penaltySymmetric(instance1, instance2);
                double penalty2 = m_metrics[centroidIdx].penaltySymmetric(instance1, instance2);
                m_objMustLinksCurrPoint += 0.5 * m_MLweight * (penalty1 + penalty2);
              } else {
                double penalty = m_metric.penaltySymmetric(instance1, instance2);
                m_objMustLinksCurrPoint += m_MLweight * penalty;
              }
            } else if (otherIdx == centroidIdx &&
                       pair.linkType == InstancePair.CANNOT_LINK) {
              // cannot-link violations are penalized by how FAR the pair is from the
              // current maximum penalty, so close CL pairs cost the most
              double penalty = m_metrics[centroidIdx].penaltySymmetric(instance1, instance2);
              m_objCannotLinksCurrPoint +=  m_CLweight *
                (m_maxCLPenalties[centroidIdx] - penalty);
              if (m_maxCLPenalties[centroidIdx] - penalty < 0) {
                // should not happen if m_maxCLPenalties is a true maximum; kept as a sanity trace
                System.out.println("***NEGATIVE*** penalty: " + penalty + " for CL constraint");
              }
            }
          }
        }
      }
    }
    // total per-point penalty = variance + cannot-link + must-link + normalizer components;
    // the m_obj*CurrPoint fields are also read back by calculateObjectiveFunction
    double total = m_objVarianceCurrPoint  + m_objCannotLinksCurrPoint
      + m_objMustLinksCurrPoint + m_objNormalizerCurrPoint;
    if(m_verbose) {
      System.out.println("Final penalty for instance " + instIdx + " and centroid "
                         + centroidIdx + " is: " + total);
    }
    return total;
  }

  /** M-step of the KMeans clustering algorithm -- updates cluster centroids.
   *  Rebuilds m_ClusterCentroids from the current m_ClusterAssignments: each new
   *  centroid is the mean (or mode, for nominal attributes) of its cluster's
   *  instances, optionally smoothed and/or normalized depending on the metric.
   *  @throws Exception if centroid computation or metric smoothing fails
   */
  protected void updateClusterCentroids() throws Exception {
    //    System.out.println("CENTROIDS BEFORE: " + m_ClusterCentroids);

    // M-step: update cluster centroids
    Instances [] tempI = new Instances[m_NumClusters];
    Instances tempCentroids = m_ClusterCentroids;
    Instances tempNewCentroids = new Instances(m_Instances, m_NumClusters);
    m_ClusterCentroids = new Instances(m_Instances, m_NumClusters);

    for (int i = 0; i < m_NumClusters; i++) {
      tempI[i] = new Instances(m_Instances, 0); // tempI[i] stores the cluster instances for cluster i
    }
    for (int i = 0; i < m_Instances.numInstances(); i++) {
      tempI[m_ClusterAssignments[i]].add(m_Instances.instance(i));
    }

    // Calculates cluster centroids
    for (int i = 0; i < m_NumClusters; i++) {
      double [] values = new double[m_Instances.numAttributes()];
      if (m_isSparseInstance) {
        values = ClusterUtils.meanOrMode(tempI[i]); // uses fast meanOrMode
      }
      else {
        for (int j = 0; j < m_Instances.numAttributes(); j++) {
          values[j] = tempI[i].meanOrMode(j); // uses usual meanOrMode
        }
      }
      // cluster centroids are dense in SPKMeans
      Instance centroid;
      if (m_isSparseInstance) {
        centroid = new SparseInstance(1.0, values);
      } else { // non-sparse
        centroid = new Instance(1.0, values);
      }
      // if we are using a smoothing metric, smooth the centroids
      if (m_metric instanceof SmoothingMetric &&
          ((SmoothingMetric) m_metric).getUseSmoothing()) {
        SmoothingMetric smoothingMetric = (SmoothingMetric) m_metric;
        centroid = smoothingMetric.smoothInstance(centroid);
        System.out.println("Using Smoothing ...");
      }
      m_ClusterCentroids.add(centroid);
//        tempNewCentroids.add(centroid);
//        m_ClusterCentroids.delete();
//        for (int j = 0; j <= i; j++) {
//  	m_ClusterCentroids.add(tempNewCentroids.instance(j));
//        }
//        for (int j = i+1; j < m_NumClusters; j++) {
//  	m_ClusterCentroids.add(tempCentroids.instance(j));
//        }
//        double objBackup = m_Objective;
//        System.out.println(calculateObjectiveFunction(false));
//        m_Objective = objBackup;

      // in SPKMeans, cluster centroids need to be normalized
      if (m_metric.doesNormalizeData()) {
        m_metric.normalizeInstanceWeighted(m_ClusterCentroids.instance(i));
      }
    }
    if (m_metric instanceof SmoothingMetric &&
        ((SmoothingMetric) m_metric).getUseSmoothing())
      updateSmoothingMetrics();

    for (int i = 0; i < m_NumClusters; i++)
      tempI[i] = null; // free memory
    //    System.out.println("CENTROIDS AFTER: " + m_ClusterCentroids);
  }

  /** M-step of the KMeans clustering algorithm -- updates metric
   *  weights. Invoked only when we're using non-Potts model
   *  and metric is trainable.  With multiple metrics, each cluster's metric
   *  is trained on its own cluster index; otherwise the single shared metric
   *  is trained with -1 (presumably meaning "all clusters" — TODO confirm
   *  against the metric learner implementation).
   *  @throws Exception if metric training fails
   */
  protected void updateMetricWeights() throws Exception {
    if (m_useMultipleMetrics) {
      for (int i = 0; i < m_NumClusters; i++) {
        m_metricLearners[i].trainMetric(i);
      }
    } else {
      m_metricLearner.trainMetric(-1);
    }
    InitNormalizerRegularizer();
  }

  /** checks for convergence.
   *  Converged when any of these holds: the objective changed by less than
   *  m_ObjFunConvergenceDifference, m_maxBlankIterations iterations passed with
   *  no point reassignments, or m_maxIterations total iterations were reached.
   *  @param oldObjective objective value from the previous iteration
   *  @param newObjective objective value from the current iteration
   *  @return true if clustering should stop
   */
  public boolean convergenceCheck(double oldObjective,
                                  double newObjective) throws Exception {
    boolean converged = false;
    // Convergence check
    if(Math.abs(oldObjective - newObjective) < m_ObjFunConvergenceDifference) {
      System.out.println("Final objective function is: " + newObjective);
      converged = true;
    }
    // number of iterations check
    if (m_numBlankIterations >= m_maxBlankIterations) {
      System.out.println("Max blank iterations reached ...\n");
      converged = true;
    }
    if (m_Iterations >= m_maxIterations) {
      System.out.println("Max iterations reached ...\n");
      converged = true;
    }
    return converged;
  }

  /** calculates objective function.
   *  Sums per-instance penalties (variance, must-link, cannot-link, normalizer
   *  components) over all instances and subtracts the regularizer.  Side
   *  effects: refreshes m_Objective and the m_obj* component fields.
   *  @param isComplete if true, the previous m_Objective is saved into
   *         m_OldObjective before recomputation (i.e. the previous estimate
   *         was a complete one)
   *  @return the new value of m_Objective
   */
  public double calculateObjectiveFunction(boolean isComplete) throws Exception {
    if (m_verbose) {
      System.out.println("Calculating objective function ...");
    }
    // update the oldObjective only if previous estimate of m_Objective
    // was complete
    if (isComplete) {
      m_OldObjective = m_Objective;
    }
    m_Objective = 0;
    m_objVariance = 0;
    m_objMustLinks = 0;
    m_objCannotLinks = 0;
    m_objNormalizer = 0;
    // temporarily halve weights since every constraint is counted twice
    // (once from each endpoint of the pair); restored below
    double tempML = m_MLweight;
    double tempCL = m_CLweight;
    m_MLweight = tempML/2;
    m_CLweight = tempCL/2;
    if (m_verbose) {
      System.out.println("Must link weight: " + m_MLweight);
      System.out.println("Cannot link weight: " + m_CLweight);
    }
    for (int i=0; i<m_Instances.numInstances(); i++) {
      if (m_isOfflineMetric) {
        double dist = m_metric.penalty(m_Instances.instance(i),
                                       m_ClusterCentroids.instance(m_ClusterAssignments[i]));
        m_Objective += dist;
        if (m_verbose) {
          System.out.println("Component for " + i + " = " + dist);
        }
      }
      else {
        // penaltyForInstance also fills the m_obj*CurrPoint fields read below
        m_Objective += penaltyForInstance(i, m_ClusterAssignments[i]);
        m_objVariance += m_objVarianceCurrPoint;
        m_objMustLinks += m_objMustLinksCurrPoint;
        m_objCannotLinks += m_objCannotLinksCurrPoint;
        m_objNormalizer += m_objNormalizerCurrPoint;
      }
    }
    m_Objective -= m_objRegularizer;
    m_MLweight = tempML;
    m_CLweight = tempCL; // reset the values of the constraint weights
    // Oscillation check — objective should be monotonically non-increasing;
    // float casts forgive tiny numerical noise
    if ((float)m_OldObjective < (float)m_Objective) {
      System.out.println("WHOA!!!  Oscillations => bug in EM step?");
      System.out.println("Old objective:" + (float)m_OldObjective
                         + " < New objective: " + (float)m_Objective);
    }
//      // TEMPORARY BLAH
//      System.out.println("\tvar=" + ((float)m_objVariance)
//  			 + "\tC=" + ((float)m_objCannotLinks)
//  			 + "\tM=" + ((float)m_objMustLinks)
//  			 + "\tLOG=" + ((float)m_objNormalizer)
//  			 + "\tREG=" + ((float)m_objRegularizer));
    return m_Objective;
  }

  /** Actual KMeans function.
   *  EM-style main loop: E-step assigns points (findBestAssignments), M-step
   *  recomputes centroids (updateClusterCentroids) and, if the metric is
   *  internally trainable, metric weights (updateMetricWeights); repeats until
   *  convergenceCheck says stop.  Prints objective diagnostics each iteration.
   *  @throws Exception if any E/M-step computation fails
   */
  protected void runKMeans() throws Exception {
    boolean converged = false;
    m_Iterations = 0;
    m_numBlankIterations = 0;
    m_Objective = Double.POSITIVE_INFINITY;
    if (!m_isOfflineMetric) {
      // initialize normalizer and regularizer
      m_metric.resetMetric();
      // initialize max CL penalties
      if (m_ConstraintsHash.size() > 0) {
        m_maxCLPenalties = calculateMaxCLPenalties();
      }
    }
    while (!converged) {
      m_OldObjective = m_Objective;

      // E-step
      int numMovedPoints = findBestAssignments();
      // a "blank" iteration is one where no point changed cluster
      m_numBlankIterations = (numMovedPoints == 0) ? m_numBlankIterations+1 : 0;
      {//if (m_verbose) {
        calculateObjectiveFunction(false);
      }
        System.out.println((float)m_Objective + " - Objective function after point assignment(CALC)");
        System.out.println("\tvar=" + ((float)m_objVariance)
                           + "\tC=" + ((float)m_objCannotLinks)
                           + "\tM=" + ((float)m_objMustLinks)
                           + "\tLOG=" + ((float)m_objNormalizer)
                           + "\tREG=" + ((float)m_objRegularizer));
      // M-step
      updateClusterCentroids();
      System.out.println("\n" + m_Iterations + ". Objective function: " + ((float)m_Objective));
      {//if (m_verbose) {
        calculateObjectiveFunction(true);
      }
        System.out.println((float)m_Objective + " - Objective function after centroid estimation");
        System.out.println("\tvar=" + ((float)m_objVariance)
                           + "\tC=" + ((float)m_objCannotLinks)
                           + "\tM=" + ((float)m_objMustLinks)
                           + "\tLOG=" + ((float)m_objNormalizer)
                           + "\tREG=" + ((float)m_objRegularizer));
      if (m_Trainable == TRAINING_INTERNAL && !m_isOfflineMetric) {
        updateMetricWeights();
        {//if (m_verbose) {
          calculateObjectiveFunction(true);
        }
          System.out.println((float)m_Objective + " - Objective function after metric update");
          System.out.println("\tvar=" + ((float)m_objVariance) + "\tC=" + ((float)m_objCannotLinks) +
                             "\tM=" + ((float)m_objMustLinks)  + "\tLOG=" + ((float)m_objNormalizer) +
                             "\tREG=" + ((float)m_objRegularizer));
        // metric changed, so cached max cannot-link penalties must be refreshed
        if (m_ConstraintsHash.size() > 0) {
          m_maxCLPenalties = calculateMaxCLPenalties();
        }
      }
      converged = convergenceCheck(m_OldObjective, m_Objective);
      m_Iterations++;
    }
    if (m_verbose) {
      // dump the 5 highest-valued attributes of each centroid
      System.out.println("Done clustering; top cluster features: ");
      for (int i = 0; i < m_NumClusters; i++){
        System.out.println("Centroid " + i);
        TreeMap map = new TreeMap(Collections.reverseOrder());
        Instance centroid= m_ClusterCentroids.instance(i);
        for (int j = 0; j < centroid.numValues(); j++) {
          Attribute attr = centroid.attributeSparse(j);
          map.put(new Double(centroid.value(attr)), attr.name());
        }
        Iterator it = map.entrySet().iterator();
        for (int j=0; j < 5 && it.hasNext(); j++) {
          Map.Entry entry = (Map.Entry) it.next();
          System.out.println("\t" + entry.getKey() + "\t" + entry.getValue());
        }
      }
    }
  }

  /** reset the value of the objective function and all of its components */
  public void resetObjective() {
    m_Objective = 0;
    m_objVariance = 0;
    m_objCannotLinks = 0;
    m_objMustLinks = 0;
    m_objNormalizer = 0;
    m_objRegularizer = 0;
  }

  /** Go through the cannot-link constraints and find the current maximum distance
   * @return an array of maximum weighted distances.  If a single metric is used, maximum distance
   * is calculated over the entire dataset */
  // TODO:  non-datasetWide case is not debugged currently!!!
  // NOTE(review): this method continues past the end of this chunk; the body
  // below is only its visible head.
  protected double[] calculateMaxCLPenalties() throws Exception {
    double [] maxPenalties = null;
    double [][] minValues = null;
    double [][] maxValues = null;
    int[] attrIdxs = null;
    maxPenalties = new double[m_NumClusters];
    m_maxCLPoints = new Instance[m_NumClusters][2];
    m_maxCLDiffInstances = new Instance[m_NumClusters];
    for (int i = 0; i < m_NumClusters; i++) {
      m_maxCLPoints[i][0] = new Instance(m_Instances.numAttributes());
      m_maxCLPoints[i][1] = new Instance(m_Instances.numAttributes());
      m_maxCLPoints[i][0].setDataset(m_Instances);
      m_maxCLPoints[i][1].setDataset(m_Instances);
      m_maxCLDiffInstances[i] = new Instance(m_Instances.numAttributes());
      m_maxCLDiffInstances[i].setDataset(m_Instances);
    }
    // TEMPORARY PLUG:  this was supposed to take care of WeightedDotp,
    // but it turns out that with weighting similarity can be > 1.
//      if (m_metric.m_fixedMaxDistance) {
//        for (int i = 0; i < m_NumClusters; i++) {
//  	maxPenalties[i] = m_metric.getMaxDistance();
//        }
//        return maxPenalties;
//      }

    minValues = new double[m_NumClusters][m_metrics[0].getNumAttributes()];
    maxValues = new double[m_NumClusters][m_metrics[0].getNumAttributes()];
    attrIdxs = m_metrics[0].getAttrIndxs();
    // temporary plug:  if this if the first iteration when no instances were assigned to clusters,
    // dataset-wide (not cluster-wide!) minimum and maximum are used even for the case with
    // multiple metrics
    boolean datasetWide = true;
    if (m_useMultipleMetrics && m_Iterations > 0) {
      datasetWide = false;
    }
    // TODO:  Mahalanobis - check with getMaxPoints
    // go through all points
//    if (m_useMultipleMetrics) {
//        for (int i = 0; i < m_metrics.length; i++) {
//  	double[][] maxPoints = ((WeightedMahalanobis)m_metrics[i]).getMaxPoints(m_ConstraintsHash, m_Instances);
//  	minValues[i] = maxPoints[0];
//  	  maxValues[i] = maxPoints[1];
//  	  //  	  System.out.println("Max points " + i);
//  	  //  	  for (int j = 0; j < maxPoints[0].length; j++) { System.out.println(maxPoints[0][j] + " - " + maxPoints[1][j]);}
//  	}
//        } else {

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -