📄 hardpairwiseselector.java
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    HardPairwiseSelector.java
 *    Copyright (C) 2002 Mikhail Bilenko
 *
 */

package weka.core.metrics;

import java.util.*;
import java.io.Serializable;
import weka.core.*;

/**
 * HardPairwiseSelector class.  Given a metric and training data,
 * create a set of "difficult" diff-class instance pairs that correspond to metric training data
 *
 * @author Mikhail Bilenko (mbilenko@cs.utexas.edu)
 * @version $Revision: 1.3 $
 */
public class HardPairwiseSelector extends PairwiseSelector implements Serializable, OptionHandler {

  /** Selection mode: pick pairs at random */
  public static final int PAIRS_RANDOM = 1;
  /** Selection mode: pick the hardest pairs */
  public static final int PAIRS_HARDEST = 2;
  /** Selection mode: pick the easiest pairs */
  public static final int PAIRS_EASIEST = 4;
  /** Selection mode: pick pairs that fall in a percentile range */
  public static final int PAIRS_INTERVAL = 8;

  /** Human-readable tags for the selection modes above */
  public static final Tag[] TAGS_PAIR_SELECTION_MODE = {
    new Tag(PAIRS_RANDOM, "Random pairs"),
    new Tag(PAIRS_HARDEST, "Hardest pairs"),
    new Tag(PAIRS_EASIEST, "Easiest pairs"),
    new Tag(PAIRS_INTERVAL, "Pairs in a percentile range")
  };

  /** Mode used for selecting positive pairs; random by default */
  protected int m_positivesMode = PAIRS_RANDOM;
  /** Mode used for selecting negative pairs; random by default */
  protected int m_negativesMode = PAIRS_RANDOM;

  /**
   * We will need this reverse comparator class to get hardest pairs
   * (those with the largest distance).  It inverts the natural ordering
   * of the compared objects, so a sorted set built with it iterates from
   * the largest element down to the smallest.
   */
  public class ReverseComparator implements Comparator {
    /**
     * Compares two objects in reverse natural order.
     *
     * @param o1 the first object; must implement Comparable
     * @param o2 the second object
     * @return 1 if o1 is naturally smaller than o2, -1 if it is
     * naturally larger, 0 if they compare as equal
     */
    public int compare(Object o1, Object o2) {
      int natural = ((Comparable) o1).compareTo(o2);
      // Flip the sign by cases instead of multiplying by -1: negating
      // Integer.MIN_VALUE overflows and would keep the result negative.
      if (natural < 0) {
        return 1;
      }
      if (natural > 0) {
        return -1;
      }
      return 0;
    }
  }

  /** A default constructor */
  public
HardPairwiseSelector() { } /** * Provide an array of metric pairs metric using given training instances * * @param metric the metric to train * @param instances data to train the metric on * @exception Exception if training has gone bad. */ public ArrayList createPairList(Instances instances, int numPosPairs, int numNegPairs, Metric metric) throws Exception { ArrayList pairList = new ArrayList(); TreeSet posPairSet = null; TreeSet negPairSet = null; double [] posPairDistances = null; double [] negPairDistances = null; Iterator iterator = null; int numActualPositives = 0, numActualNegatives = 0; // INITIALIZE initSelector(instances); System.out.println("m_numPotentialPositives=" + m_numPotentialPositives + "\tm_numPotentialNegatives=" + m_numPotentialNegatives); // SELECT POSITIVE PAIRS switch (m_positivesMode) { case PAIRS_EASIEST: posPairSet = new TreeSet(); posPairDistances = populatePositivePairSet(metric, posPairSet); pairList = getUniquePairs(posPairSet, metric, numPosPairs); break; case PAIRS_HARDEST: posPairSet = new TreeSet(new ReverseComparator()); posPairDistances = populatePositivePairSet(metric, posPairSet); pairList = getUniquePairs(posPairSet, metric, numPosPairs); break; case PAIRS_RANDOM: // go through lists of instances for each class and create a list of *all* positive pairs ArrayList posPairList = new ArrayList(); iterator = m_classInstanceMap.values().iterator(); while (iterator.hasNext()) { ArrayList instanceList = (ArrayList) iterator.next(); for (int i = 0; i < instanceList.size(); i++) { Instance instance1 = (Instance) instanceList.get(i); for (int j = i+1; j < instanceList.size(); j++) { Instance instance2 = (Instance) instanceList.get(j); TrainingPair pair = new TrainingPair(instance1, instance2, true, metric.distance(instance1, instance2)); posPairList.add(pair); } } } // if we have fewer pairs available than requested, return all the ones that were created if (posPairList.size() <= numPosPairs) { pairList = posPairList; } else { // if we 
have enough potential pairs, sample randomly with replacement Random random = new Random(); for (int i = 0; i < numPosPairs; i++) { int idx = random.nextInt(posPairList.size()); TrainingPair pair = (TrainingPair) posPairList.remove(idx); pairList.add(pair); } } break; case PAIRS_INTERVAL: System.err.println("TODO PAIRS_INTERVAL!!!"); break; default: throw new Exception("Unknown method for selecting positive pairs: " + m_positivesMode); } numActualPositives = pairList.size(); // SELECT NEGATIVE PAIRS switch (m_negativesMode) { case PAIRS_EASIEST: // Create a map with *all* negatives negPairSet = new TreeSet(new ReverseComparator()); negPairDistances = populateNegativePairSet(metric, negPairSet); pairList.addAll(getUniquePairs(negPairSet, metric, numNegPairs)); case PAIRS_HARDEST: negPairSet = new TreeSet(); negPairDistances = populateNegativePairSet(metric, negPairSet); pairList.addAll(getUniquePairs(negPairSet, metric, numNegPairs)); break; case PAIRS_RANDOM: // create all negative pairs and sample randomly ArrayList negPairList = new ArrayList(); // go through lists of instances for each class for (int i = 0; i < m_classValueList.size(); i++) { ArrayList instanceList1 = (ArrayList) m_classInstanceMap.get(m_classValueList.get(i)); for (int j = 0; j < instanceList1.size(); j++) { Instance instance1 = (Instance) instanceList1.get(j); // create all pairs from other clusters with this instance for (int k = i+1; k < m_classValueList.size(); k++) { ArrayList instanceList2 = (ArrayList) m_classInstanceMap.get(m_classValueList.get(k)); for (int l = 0; l < instanceList2.size(); l++) { Instance instance2 = (Instance) instanceList2.get(l); TrainingPair pair = new TrainingPair(instance1, instance2, false, metric.distance(instance1, instance2)); negPairList.add(pair); } } } } // if we have fewer pairs available than requested, return all the ones that were created if (negPairList.size() <= numNegPairs) { pairList.addAll(negPairList); } else { // if we have enough potential 
pairs, randomly sample with replacement Random random = new Random(); for (int i = 0; i < numNegPairs; i++) { int idx = random.nextInt(negPairList.size()); TrainingPair pair = (TrainingPair) negPairList.remove(idx); pairList.add(pair); } } break; case PAIRS_INTERVAL: System.err.println("TODO PAIRS_INTERVAL!!!"); break; default: throw new Exception("Unknown method for selecting positive pairs: " + m_positivesMode); } numActualNegatives = pairList.size() - numActualPositives; System.out.println(); System.out.println("POSITIVES: requested=" + numPosPairs + "\tpossible=" + m_numPotentialPositives + "\tactual=" + numActualPositives); System.out.println("NEGATIVES: requested=" + numNegPairs + "\tpossible=" + m_numPotentialNegatives + "\tactual=" + numActualNegatives); return pairList; } /** This helper method goes through a TreeSet containing sorted TrainingPairs * and returns a list of unique pairs * @param pairSet a sorted set of TrainingPair's * @param metric the metric that is used for creating DiffInstance's * @param numPairs the number of desired pairs * @return a list with training pairs */ protected ArrayList getUniquePairs(TreeSet pairSet, Metric metric, int numPairs) { ArrayList pairList = new ArrayList(); HashMap checksumMap = new HashMap(); Iterator iterator = pairSet.iterator(); for (int i = 0; iterator.hasNext() && i < numPairs; i++) { TrainingPair pair = (TrainingPair) iterator.next(); if (metric instanceof LearnableMetric) { Instance diffInstance = ((LearnableMetric)metric).createDiffInstance(pair.instance1, pair.instance2); double checksum = 0; for (int j = 0; j < diffInstance.numValues(); j++) { checksum += j*17 * diffInstance.value(j); } // round off to help with machine precision errors checksum = (float) checksum; // if this checksum was encountered before, get a list of instances // that have this checksum, and check if any of them are dupes of this one if (checksumMap.containsKey(new Double(checksum))) { ArrayList checksumList = (ArrayList) 
checksumMap.get(new Double(checksum)); System.out.println("Collision for " + checksum + ": " + checksumList.size()); boolean unique = true; for (int k = 0; k < checksumList.size() && unique; k++) { Instance nextDiffInstance = (Instance) checksumList.get(k); unique = false; for (int l = 0; l < nextDiffInstance.numValues() && !unique; l++) { if (((float)nextDiffInstance.value(l)) != ((float)diffInstance.value(l))) { unique = true; } } if (!unique) { // This is a dupe! System.out.println("Dupe!"); i--; break; } } if (unique) {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -