⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 weuclideangdlearner.java

📁 wekaUT 是 University of Texas at Austin 开发的基于 weka 的半监督学习(semi-supervised learning)分类器
💻 JAVA
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    WEuclideanLearnerGD.java
 *    Copyright (C) 2004 Mikhail Bilenko and Sugato Basu
 *
 */

package weka.clusterers.metriclearners;

import java.util.*;

import weka.core.*;
import weka.core.metrics.*;
import weka.clusterers.MPCKMeans;
import weka.clusterers.InstancePair;

/**
 * A gradient-descent based learner for WeightedEuclidean.
 *
 * <p>Accumulates a per-attribute gradient from three sources: the
 * Euclidean distance of each instance to its cluster centroid, a
 * log-term on the current weights, and penalties for violated
 * must-link / cannot-link constraints.  The accumulated gradient is
 * handed to {@code GDUpdate} (inherited from {@code GDMetricLearner})
 * to produce the new metric weights.
 *
 * @author Mikhail Bilenko (mbilenko@cs.utexas.edu) and Sugato Basu
 * (sugato@cs.utexas.edu)
 * @version $Revision: 1.5 $
 */
public class WEuclideanGDLearner extends GDMetricLearner {

  /**
   * Performs one gradient-descent training step on the metric weights.
   *
   * <p>If clusterIdx is -1, all instances are used (a single metric
   * for all clusters is trained); otherwise only instances assigned
   * to cluster clusterIdx contribute.
   *
   * @param clusterIdx index of the cluster to train, or -1 for all
   * @return true (always; the return value mirrors the superclass contract)
   * @throws Exception if metric or instance access fails
   */
  public boolean trainMetric(int clusterIdx) throws Exception {
    Init(clusterIdx);
    double[] gradients = new double[m_numAttributes];
    Instance diffInstance;
    int violatedConstraints = 0;
    double[] currentWeights = m_metric.getWeights();
    double[] regularizerComponents = new double[m_numAttributes];
    if (m_regularize) {
      regularizerComponents = InitRegularizerComponents(currentWeights);
    }

    for (int instIdx = 0; instIdx < m_instances.numInstances(); instIdx++) {
      int assignment = m_clusterAssignments[instIdx];
      // only instances assigned to this cluster are of importance
      if (assignment == clusterIdx || clusterIdx == -1) {
        Instance instance = m_instances.instance(instIdx);
        if (clusterIdx < 0) {
          // single-metric mode: fetch the centroid of this instance's own cluster
          m_centroid = m_kmeans.getClusterCentroids().instance(assignment);
        }

        diffInstance = m_metric.createDiffInstance(instance, m_centroid);
        for (int attr = 0; attr < m_numAttributes; attr++) {
          gradients[attr] += diffInstance.value(attr); // Euclidean components
          if (currentWeights[attr] > 0) {
            gradients[attr] -= m_logTermWeight / currentWeights[attr]; // log components
          }
        }

        // go through violated constraints
        Object list = m_instanceConstraintMap.get(new Integer(instIdx));
        if (list != null) {  // there are constraints associated with this instance
          ArrayList constraintList = (ArrayList) list;
          for (int i = 0; i < constraintList.size(); i++) {
            InstancePair pair = (InstancePair) constraintList.get(i);
            int linkType = pair.linkType;
            int firstIdx = pair.first;
            int secondIdx = pair.second;
            Instance instance1 = m_instances.instance(firstIdx);
            Instance instance2 = m_instances.instance(secondIdx);
            // cluster assignment of the *other* endpoint of this constraint
            int otherIdx = (firstIdx == instIdx)
              ? m_clusterAssignments[secondIdx] : m_clusterAssignments[firstIdx];
            // check whether the constraint is violated
            if (otherIdx != -1) {
              if (otherIdx != assignment && linkType == InstancePair.MUST_LINK) {
                // violated must-link; seen once from each endpoint, hence 0.5
                diffInstance = m_metric.createDiffInstance(instance1, instance2);
                for (int attr = 0; attr < m_numAttributes; attr++) {
                  gradients[attr] += 0.5 * m_MLweight * diffInstance.value(attr);
                }
                violatedConstraints++;
              } else if (otherIdx == assignment && linkType == InstancePair.CANNOT_LINK) {
                diffInstance = m_metric.createDiffInstance(instance1, instance2);
                for (int attr = 0; attr < m_numAttributes; attr++) {
                  // this constraint will be counted twice, hence 0.5
                  gradients[attr] += 0.5 * m_CLweight * m_maxCLDiffInstance.value(attr);
                  gradients[attr] -= 0.5 * m_CLweight * diffInstance.value(attr);
                }
                violatedConstraints++;
              }
            } // end if (otherIdx != -1)
          }
        }
      }
    }

    double[] newWeights = GDUpdate(currentWeights, gradients, regularizerComponents);
    m_metric.setWeights(newWeights);
    // each violated constraint is counted from both endpoints, so halve the total
    System.out.println("   Total constraints violated: " + violatedConstraints / 2);
    return true;
  }

  /**
   * Gets the current settings of the learner.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String[] getOptions() {
    // no configurable options yet; return a single empty entry as before
    return new String[] { "" };
  }

  /**
   * Parses a given list of options.
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    // TODO: add later
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return currently null; no options are exposed yet
   */
  public Enumeration listOptions() {
    // TODO: add later
    return null;
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -