📄 semisuppairactivecurvecvresultproducer.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * SemiSupPairActiveCurveCVResultProducer.java * Copyright (C) 2002 Sugato Basu * */package weka.experiment;import java.util.*;import java.io.*;import weka.core.*;import java.text.SimpleDateFormat;import weka.filters.Filter;import weka.filters.unsupervised.attribute.Remove;import weka.clusterers.*;/** * N-fold cross-validation learning curve for pairwise active learning * in semi-supervised learners (clusterers and classifiers) * * @author Sugato Basu */public class SemiSupPairActiveCurveCVResultProducer implements ResultProducer, OptionHandler, AdditionalMeasureProducer { /** The dataset of interest */ protected Instances m_Instances; /** The ResultListener to send results to */ protected ResultListener m_ResultListener = new CSVResultListener(); /** The number of folds in the cross-validation */ protected int m_NumFolds = 2; /** Whether transductive evaluation is to be performed */ protected boolean m_IsTransductive = true; /** Proportion of must-link pairs in the training set; varies from 0 to 1; * -1 corresponds to random sampling*/ protected double m_fractionMustLinks = -1; /** Whether active learning is to be performed */ protected boolean m_DoActive = false; /** Save raw output of split evaluators --- for debugging purposes */ protected boolean m_debugOutput = false; /** The output zipper to use for saving raw splitEvaluator output */ protected OutputZipper m_ZipDest = null; /** The destination output file/directory for raw output */ protected File m_OutputFile = new File( new File(System.getProperty("user.dir")), "splitEvalutorOut.zip"); /** The SplitEvaluator used to generate results */ protected SplitEvaluator m_SplitEvaluator = new SemiSupClustererSplitEvaluator(); /** The names of any additional measures to look for in SplitEvaluators */ protected String [] m_AdditionalMeasures = null; /** * The minimum number of instances to use. If this is zero, the first * step will contain m_StepSize instances */ protected int m_LowerSize = 0; // /** // * Algorithm types - no active learning, single-point active learning // * or pairwise active learning // */ // public static final int NO_ACTIVE = 1; // public static final int SINGLE_ACTIVE = 2; // public static final int PAIR_ACTIVE = 3; // public static final Tag [] TAGS_ACTIVE = { // new Tag(NO_ACTIVE, "None"), // new Tag(SINGLE_ACTIVE, "Single point"), // new Tag(PAIR_ACTIVE, "Pairwise") // }; // /** Type of active algorithm */ // protected int m_ActiveType = SINGLE_ACTIVE; /** * The maximum number of instances to use. -1 indicates no maximum * (other than the total number of instances) */ protected int m_UpperSize = -1; /** The number of instances to add at each step */ protected int m_StepSize = 10; /** The specific points to plot, either integers representing specific numbers of training examples, * or decimal fractions representing percentages of the full training set*/ protected double[] m_PlotPoints = new double[] {0, 200, 500}; /** The current dataset size during stepping */ protected int m_CurrentSize = 0; /* The name of the key field containing the dataset name */ public static String DATASET_FIELD_NAME = "Dataset"; /* The name of the key field containing the run number */ public static String RUN_FIELD_NAME = "Run"; /* The name of the key field containing the fold number */ public static String FOLD_FIELD_NAME = "Fold"; /* The name of the result field containing the timestamp */ public static String TIMESTAMP_FIELD_NAME = "Date_time"; /* The name of the key field containing the learning rate step number */ public static String STEP_FIELD_NAME = "Total_instances"; /* The name of the key field containing the fraction of total instances used */ public static String FRACTION_FIELD_NAME = "Fraction_instances"; /* Indicates whether fractions or actual number of instances have been specified */ protected boolean m_IsFraction = false; /** * Returns a string describing this result producer * @return a description of the result producer suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Performs a semi-supervised learning-curve cross validation run using a supplied semi-supervised split evaluator. In the inductive framework, the semi-supervised learner for each split is trained on a fixed size of labeled + unlabeled training data, with the proportion of labeled training data being increased at each point along the learning curve. Testing is performed on the test set for that split after training. In the transductive framework, the unlabeled testing data is also added to the pool of unlabeled training data, and as in the inductive framework the proportion of labeled training data is increased at each point along the learning curve. Testing is performed as usual on the test set for that split after training. "; } /** * Sets the dataset that results will be obtained for. * * @param instances a value of type 'Instances'. */ public void setInstances(Instances instances) { m_Instances = instances; } /** * Sets the object to send results of each run to. * * @param listener a value of type 'ResultListener' */ public void setResultListener(ResultListener listener) { m_ResultListener = listener; } /** * Set a list of method names for additional measures to look for * in SplitEvaluators. This could contain many measures (of which only a * subset may be produceable by the current SplitEvaluator) if an experiment * is the type that iterates over a set of properties. * @param additionalMeasures an array of measure names, null if none */ public void setAdditionalMeasures(String [] additionalMeasures) { m_AdditionalMeasures = additionalMeasures; if (m_SplitEvaluator != null) { System.err.println(" SemiSupPairActiveCurveCVResultProducer: setting additional " +"measures for " +"split evaluator"); m_SplitEvaluator.setAdditionalMeasures(m_AdditionalMeasures); } } /** * Returns an enumeration of any additional measure names that might be * in the SplitEvaluator * @return an enumeration of the measure names */ public Enumeration enumerateMeasures() { Vector newVector = new Vector(); if (m_SplitEvaluator instanceof AdditionalMeasureProducer) { Enumeration en = ((AdditionalMeasureProducer)m_SplitEvaluator). enumerateMeasures(); while (en.hasMoreElements()) { String mname = (String)en.nextElement(); newVector.addElement(mname); } } return newVector.elements(); } /** * Returns the value of the named measure * @param measureName the name of the measure to query for its value * @return the value of the named measure * @exception IllegalArgumentException if the named measure is not supported */ public double getMeasure(String additionalMeasureName) { if (m_SplitEvaluator instanceof AdditionalMeasureProducer) { return ((AdditionalMeasureProducer)m_SplitEvaluator). getMeasure(additionalMeasureName); } else { throw new IllegalArgumentException(" SemiSupPairActiveCurveCVResultProducer: " +"Can't return value for : "+additionalMeasureName +". "+m_SplitEvaluator.getClass().getName()+" " +"is not an AdditionalMeasureProducer"); } } /** * Gets a Double representing the current date and time. * eg: 1:46pm on 20/5/1999 -> 19990520.1346 * * @return a value of type Double */ public static Double getTimestamp() { Calendar now = Calendar.getInstance(TimeZone.getTimeZone("UTC")); double timestamp = now.get(Calendar.YEAR) * 10000 + (now.get(Calendar.MONTH) + 1) * 100 + now.get(Calendar.DAY_OF_MONTH) + now.get(Calendar.HOUR_OF_DAY) / 100.0 + now.get(Calendar.MINUTE) / 10000.0; return new Double(timestamp); } /** * Prepare to generate results. * * @exception Exception if an error occurs during preprocessing. */ public void preProcess() throws Exception { if (m_SplitEvaluator == null) { throw new Exception("No SplitEvalutor set"); } if (m_ResultListener == null) { throw new Exception("No ResultListener set"); } m_ResultListener.preProcess(this); } /** * Perform any postprocessing. When this method is called, it indicates * that no more requests to generate results for the current experiment * will be sent. * * @exception Exception if an error occurs */ public void postProcess() throws Exception { m_ResultListener.postProcess(this); if (m_debugOutput) { if (m_ZipDest != null) { m_ZipDest.finished(); m_ZipDest = null; } } } /** * Get the maximum size of the training set based on pairwise * points, using upperSize limit or maximum pairs in training set * size from the n-fold CV */ protected int maxTrainSize() { if (m_UpperSize == -1 || m_PlotPoints != null) { int numTrain = (int) (m_Instances.numInstances()*(1 - 1/((double)m_NumFolds))); int maxTrainSize = (int)(numTrain*(numTrain-1)/2); // System.out.println("NumTrain: " + numTrain + ", maxTrainSize: " + maxTrainSize); return maxTrainSize; } else { return m_UpperSize; } } /** * Gets the keys for a specified run number. Different run * numbers correspond to different randomizations of the data. Keys * produced should be sent to the current ResultListener * * @param run the run number to get keys for. * @exception Exception if a problem occurs while getting the keys */ public void doRunKeys(int run) throws Exception { int numExtraKeys; if(m_IsFraction) numExtraKeys = 5; else numExtraKeys = 4; if (m_Instances == null) { throw new Exception("No Instances set"); } if (m_ResultListener == null) { throw new Exception("No ResultListener set"); } for (int fold = 0; fold < m_NumFolds; fold++) { int pointNum = 0; // For each subsample size if (m_PlotPoints != null) { m_CurrentSize = plotPoint(0); } else if (m_LowerSize == 0) { m_CurrentSize = m_StepSize; } else { m_CurrentSize = m_LowerSize; } while (m_CurrentSize <= maxTrainSize()) { // Add in some fields to the key like run and fold number, dataset name Object [] seKey = m_SplitEvaluator.getKey(); Object [] key = new Object [seKey.length + numExtraKeys]; key[0] = Utils.backQuoteChars(m_Instances.relationName()); key[1] = "" + run; key[2] = "" + (fold + 1); key[3] = "" + m_CurrentSize; if(m_IsFraction) key[4] = "" + m_PlotPoints[pointNum]; System.arraycopy(seKey, 0, key, numExtraKeys, seKey.length); if (m_ResultListener.isResultRequired(this, key)) { try { m_ResultListener.acceptResult(this, key, null); } catch (Exception ex) { // Save the train and test datasets for debugging purposes? throw ex; } } if (m_PlotPoints != null) { pointNum ++; m_CurrentSize = plotPoint(pointNum); } else { m_CurrentSize += m_StepSize; } } } } /** * Gets the results for a specified run number. Different run * numbers correspond to different randomizations of the data. Results * produced should be sent to the current ResultListener * * @param run the run number to get results for. * @exception Exception if a problem occurs while getting the results */ public void doRun(int run) throws Exception { int numExtraKeys; if(m_IsFraction) numExtraKeys = 5; else numExtraKeys = 4; if (getRawOutput()) { if (m_ZipDest == null) { m_ZipDest = new OutputZipper(m_OutputFile); } } if (m_Instances == null) { throw new Exception("No Instances set"); } if (m_ResultListener == null) { throw new Exception("No ResultListener set"); } // Randomize on a copy of the original dataset Instances runInstances = new Instances(m_Instances); runInstances.randomize(new Random(run)); if (runInstances.classAttribute().isNominal()) { runInstances.stratify(m_NumFolds); } for (int fold = 0; fold < m_NumFolds; fold++) { Instances train = runInstances.trainCV(m_NumFolds, fold); // Randomly shuffle stratified training set for fold train.randomize(new Random(fold)); Instances test = runInstances.testCV(m_NumFolds, fold); // For each subsample size int pointNum = 0; // For each subsample size if (m_PlotPoints != null) { m_CurrentSize = plotPoint(0); } else if (m_LowerSize == 0) { m_CurrentSize = m_StepSize; } else { m_CurrentSize = m_LowerSize; } // create the pair list ArrayList labeledTrainPairs = InstancePair.getPairs(train, m_CurrentSize, m_fractionMustLinks); System.out.println("Size of train: " + train.numInstances() + ", Total number of labeledTrainpairs: " + labeledTrainPairs.size()); while (m_CurrentSize <= maxTrainSize()) { // Add in some fields to the key like run and fold number, dataset name Object [] seKey = m_SplitEvaluator.getKey(); Object [] key = new Object [seKey.length + numExtraKeys]; key[0] = Utils.backQuoteChars(m_Instances.relationName()); key[1] = "" + run; key[2] = "" + (fold + 1); key[3] = "" + m_CurrentSize; if(m_IsFraction) key[4] = "" + m_PlotPoints[pointNum]; System.arraycopy(seKey, 0, key, numExtraKeys, seKey.length); if (m_ResultListener.isResultRequired(this, key)) { try { System.out.println((new SimpleDateFormat("HH:mm:ss:")).format(new Date())); if(m_IsFraction) { System.out.println("Run:" + run + " Fold:" + fold + " Size:" + m_CurrentSize + " Fraction:" + m_PlotPoints[pointNum]); } else { System.out.println("\n****\nRun:" + run + " Fold:" + fold + " Size:" + m_CurrentSize + " Dataset: " + train.relationName()); } // Need to remove the class labels from the unlabeledTrainSubsetWithLabels data before training learner Instances unlabeledTrain = new Instances(train);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -