⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 hardpairwiseselector.java

📁 wekaUT 是德克萨斯大学奥斯汀分校(University of Texas at Austin)开发的基于 Weka 的半监督学习(semi-supervised learning)分类器
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *    HardPairwiseSelector.java *    Copyright (C) 2002 Mikhail Bilenko * */package weka.core.metrics;import java.util.*;import java.io.Serializable;import weka.core.*;/**  *  HardPairwiseSelector class.  Given a metric and training data, * create a set of "difficult" diff-class instance pairs that correspond to metric training data * * @author Mikhail Bilenko (mbilenko@cs.utexas.edu) * @version $Revision: 1.3 $ */public class HardPairwiseSelector extends PairwiseSelector implements Serializable, OptionHandler {  public static final int PAIRS_RANDOM = 1;  public static final int PAIRS_HARDEST = 2;  public static final int PAIRS_EASIEST = 4;  public static final int PAIRS_INTERVAL = 8;  public static final Tag[] TAGS_PAIR_SELECTION_MODE = {    new Tag(PAIRS_RANDOM, "Random pairs"),    new Tag(PAIRS_HARDEST, "Hardest pairs"),    new Tag(PAIRS_EASIEST, "Easiest pairs"),    new Tag(PAIRS_INTERVAL, "Pairs in a percentile range")  };  protected int m_positivesMode = PAIRS_RANDOM;  protected int m_negativesMode = PAIRS_RANDOM;  /** We will need this reverse comparator class to get hardest pairs (those with the largest distance */  public class ReverseComparator implements Comparator {    public int compare(Object o1, Object o2) {      Comparable c = (Comparable) 
o1;      return -1 * c.compareTo(o2);    }  }  /** A default constructor */  public HardPairwiseSelector() {  }   /**   * Provide an array of metric pairs metric using given training instances   *   * @param metric the metric to train   * @param instances data to train the metric on   * @exception Exception if training has gone bad.   */  public ArrayList createPairList(Instances instances, int numPosPairs, int numNegPairs, Metric metric) throws Exception {    ArrayList pairList = new ArrayList();    TreeSet posPairSet = null;    TreeSet negPairSet = null;    double [] posPairDistances = null;    double [] negPairDistances = null;    Iterator iterator = null;    int numActualPositives = 0, numActualNegatives = 0;    // INITIALIZE    initSelector(instances);    System.out.println("m_numPotentialPositives=" + m_numPotentialPositives + "\tm_numPotentialNegatives=" + m_numPotentialNegatives);    // SELECT POSITIVE PAIRS    switch (m_positivesMode) {    case PAIRS_EASIEST:      posPairSet = new TreeSet();      posPairDistances = populatePositivePairSet(metric, posPairSet);      pairList = getUniquePairs(posPairSet, metric, numPosPairs);      break;    case PAIRS_HARDEST:      posPairSet = new TreeSet(new ReverseComparator());      posPairDistances = populatePositivePairSet(metric, posPairSet);      pairList = getUniquePairs(posPairSet, metric, numPosPairs);      break;    case PAIRS_RANDOM:      // go through lists of instances for each class and create a list of *all* positive pairs      ArrayList posPairList = new ArrayList();      iterator = m_classInstanceMap.values().iterator();      while (iterator.hasNext()) {	ArrayList instanceList = (ArrayList) iterator.next();	for (int i = 0; i < instanceList.size(); i++) {	  Instance instance1 = (Instance) instanceList.get(i);	  for (int j = i+1; j < instanceList.size(); j++) {	    Instance instance2 = (Instance) instanceList.get(j);	    TrainingPair pair = new TrainingPair(instance1, instance2, true, 
metric.distance(instance1, instance2));	    posPairList.add(pair);	  } 	}      }      // if we have fewer pairs available than requested, return all the ones that were created      if (posPairList.size() <= numPosPairs) {	pairList = posPairList;      } else { // if we have enough potential pairs, sample randomly with replacement	Random random = new Random();	for (int i = 0; i < numPosPairs; i++) {	  int idx = random.nextInt(posPairList.size());	  TrainingPair pair = (TrainingPair) posPairList.remove(idx);	  pairList.add(pair);	}      }      break;    case PAIRS_INTERVAL:      System.err.println("TODO PAIRS_INTERVAL!!!");      break;    default:      throw new Exception("Unknown method for selecting positive pairs: " + m_positivesMode);    }    numActualPositives = pairList.size();        // SELECT NEGATIVE PAIRS    switch (m_negativesMode) {    case PAIRS_EASIEST:      // Create a map with *all* negatives      negPairSet = new TreeSet(new ReverseComparator());      negPairDistances = populateNegativePairSet(metric, negPairSet);      pairList.addAll(getUniquePairs(negPairSet, metric, numNegPairs));    case PAIRS_HARDEST:      negPairSet = new TreeSet();      negPairDistances = populateNegativePairSet(metric, negPairSet);      pairList.addAll(getUniquePairs(negPairSet, metric, numNegPairs));      break;    case PAIRS_RANDOM:         // create all negative pairs and sample randomly      ArrayList negPairList = new ArrayList();      // go through lists of instances for each class      for (int i = 0; i < m_classValueList.size(); i++) {	ArrayList instanceList1 = (ArrayList) m_classInstanceMap.get(m_classValueList.get(i));	for (int j = 0; j < instanceList1.size(); j++) {	  Instance instance1 = (Instance) instanceList1.get(j);	  // create all pairs from other clusters with this instance	  for (int k = i+1; k < m_classValueList.size(); k++) {	    ArrayList instanceList2 = (ArrayList) m_classInstanceMap.get(m_classValueList.get(k));	    for (int l = 0; l < 
instanceList2.size(); l++) {	      Instance instance2 = (Instance) instanceList2.get(l);	      TrainingPair pair = new TrainingPair(instance1, instance2, false, metric.distance(instance1, instance2));	      negPairList.add(pair);	    }	  }	}      }      // if we have fewer pairs available than requested, return all the ones that were created      if (negPairList.size() <= numNegPairs) {	pairList.addAll(negPairList);      } else { // if we have enough potential pairs, randomly sample with replacement	Random random = new Random();	for (int i = 0; i < numNegPairs; i++) {	  int idx = random.nextInt(negPairList.size());	  TrainingPair pair = (TrainingPair) negPairList.remove(idx);	  pairList.add(pair);	}      }      break;    case PAIRS_INTERVAL:       System.err.println("TODO PAIRS_INTERVAL!!!");      break;    default:      throw new Exception("Unknown method for selecting positive pairs: " + m_positivesMode);    }    numActualNegatives = pairList.size() - numActualPositives;    System.out.println();    System.out.println("POSITIVES:  requested=" + numPosPairs + "\tpossible=" + m_numPotentialPositives +		       "\tactual=" + numActualPositives);    System.out.println("NEGATIVES:  requested=" + numNegPairs + "\tpossible=" + m_numPotentialNegatives +		       "\tactual=" + numActualNegatives);    return pairList;  }  /** This helper method goes through a TreeSet containing sorted TrainingPairs   * and returns a list of unique pairs   * @param pairSet a sorted set of TrainingPair's   * @param metric the metric that is used for creating DiffInstance's   * @param numPairs the number of desired pairs   * @return a list with training pairs   */        protected ArrayList getUniquePairs(TreeSet pairSet, Metric metric, int numPairs) {    ArrayList pairList = new ArrayList();    HashMap checksumMap = new HashMap();    Iterator iterator = pairSet.iterator();    for (int i = 0; iterator.hasNext() && i < numPairs; i++) {      TrainingPair pair = (TrainingPair) iterator.next();      
if (metric instanceof LearnableMetric) {	Instance diffInstance = ((LearnableMetric)metric).createDiffInstance(pair.instance1, pair.instance2);	double checksum = 0;	for (int j = 0; j < diffInstance.numValues(); j++) {	  checksum += j*17 * diffInstance.value(j);	}	// round off to help with machine precision errors	checksum = (float) checksum;	// if this checksum was encountered before, get a list of instances	// that have this checksum, and check if any of them are dupes of this one	if (checksumMap.containsKey(new Double(checksum))) {	  ArrayList checksumList = (ArrayList) checksumMap.get(new Double(checksum));	  System.out.println("Collision for " + checksum + ": " + checksumList.size());	  boolean unique = true;	  for (int k = 0; k < checksumList.size() && unique; k++) {	    Instance nextDiffInstance = (Instance) checksumList.get(k);	    unique = false;	    for (int l = 0; l < nextDiffInstance.numValues() && !unique; l++) {	      if (((float)nextDiffInstance.value(l)) != ((float)diffInstance.value(l))) {		unique = true;	      } 	    }	    if (!unique) {	      // This is a dupe!	      System.out.println("Dupe!");	      i--;	      break;	    } 	  }	  if (unique) {

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -