📄 semisupincompletelabelcurvecvresultproducer.java
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    SemiSupIncompleteLabelCurveCVResultProducer.java
 *    Copyright (C) 2002 Sugato Basu
 *
 */

package weka.experiment;

import java.util.*;
import java.io.*;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionHandler;
import weka.core.Option;
import weka.core.Utils;
import weka.core.AdditionalMeasureProducer;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * N-fold cross-validation learning curve for semi-supervised learners
 * (clusterers and classifiers), where labeled data is not present for
 * all the categories.
 *
 * For every run and fold, the producer steps through an increasing number
 * of categories whose labels are withheld (see doRunKeys and doRun) and
 * reports one result key per step to the configured ResultListener.
 *
 * @author Sugato Basu
 */
public class SemiSupIncompleteLabelCurveCVResultProducer
  implements ResultProducer, OptionHandler, AdditionalMeasureProducer {

  /** The dataset of interest; must be set before doRunKeys/doRun are called */
  protected Instances m_Instances;

  /** The ResultListener to send results to (a CSVResultListener by default) */
  protected ResultListener m_ResultListener = new CSVResultListener();

  /** The number of folds in the cross-validation */
  protected int m_NumFolds = 10;

  /** Whether transductive evaluation is to be performed */
  protected boolean m_IsTransductive = true;

  /**
   * Save raw output of split evaluators --- for debugging purposes.
   * postProcess only closes the output zipper when this flag is set.
   */
  protected boolean m_debugOutput = false;

  /**
   * The output zipper to use for saving raw splitEvaluator output.
   * Created in doRun when getRawOutput() is true and still null; finished
   * and reset to null in postProcess.
   */
  protected OutputZipper m_ZipDest = null;

  /** The
destination output file/directory for raw output */
  protected File m_OutputFile = new File(
                                new File(System.getProperty("user.dir")),
                                "splitEvalutorOut.zip");

  /** The SplitEvaluator used to generate results */
  protected SplitEvaluator m_SplitEvaluator = new SemiSupClustererSplitEvaluator();

  /** The names of any additional measures to look for in SplitEvaluators */
  protected String [] m_AdditionalMeasures = null;

  /**
   * The minimum number of labeled categories to drop. If this is zero, the first
   * step will drop m_StepSize categories
   */
  protected int m_LowerSize = 0;

  /**
   * The maximum number of labeled categories to drop. -1 indicates no maximum
   * (other than the total number of categories).
   *
   * NOTE(review): in the code visible here this value is only read by
   * maxTrainSize(), which returns it as a size cap; the stepping loops in
   * doRunKeys/doRun stop at numClasses() and do not consult it -- confirm
   * the intended meaning.
   */
  protected int m_UpperSize = -1;

  /**
   * The number of labeled categories to drop at each step; added to
   * m_NumMissingLabels after every step when no plot points are set
   */
  protected int m_StepSize = 2;

  /**
   * The specific points to plot, integers representing specific
   * numbers of incomplete labels. When non-null, the stepping loops obtain
   * each step from plotPoint(...) instead of m_LowerSize/m_StepSize.
   */
  protected double[] m_PlotPoints;

  /**
   * The current labeled training dataset size (as fraction of totalSize).
   *
   * NOTE(review): the field is declared int, so it cannot hold a fraction --
   * confirm which unit is intended.
   */
  protected int m_CurrentSize = 0;

  /**
   * Number of categories for which labeled data is not provided. Used as
   * the loop variable of the stepping loops and written into the result key.
   */
  protected int m_NumMissingLabels = 0;

  /** The name of the key field containing the dataset name */
  public static String DATASET_FIELD_NAME = "Dataset";

  /** The name of the key field containing the run number */
  public static String RUN_FIELD_NAME = "Run";

  /** The name of the key field containing the fold number */
  public static String FOLD_FIELD_NAME = "Fold";

  /** The name of the result field containing the timestamp */
  public static String TIMESTAMP_FIELD_NAME = "Date_time";

  /** The name of the key field containing the learning rate step number */
  public static String STEP_FIELD_NAME = "Total_instances";

  /** The name of the key field containing the fraction of total instances used */
  public static String FRACTION_FIELD_NAME = "Fraction_instances";

  /**
   * Indicates whether fractions or actual number of instances have been
   * specified. When true, result keys carry one extra field holding the
   * current plot point.
   */
  protected boolean
m_IsFraction = false;

  /**
   * Describes this result producer in a form suitable for display in the
   * explorer/experimenter gui.
   *
   * @return the human-readable description of the result producer
   */
  public String globalInfo() {
    // Assembled from compile-time constant pieces; the resulting text is
    // a single string, one piece per clause for readability.
    final String description =
      "Performs a semi-supervised learning-curve cross validation run using a "
      + "supplied semi-supervised split evaluator, where labeled data is not "
      + "avaiable for all the categories. "
      + "In the inductive framework, the semi-supervised learner for each split "
      + "is trained on a fixed size of labeled + unlabeled training data, with "
      + "the proportion of incomplete labels in the training data being "
      + "increased at each point along the learning curve. "
      + "Testing is performed on the test set for that split after training, "
      + "only on the test data with labels that had been removed in the "
      + "training data. "
      + "In the transductive framework, the unlabeled testing data is also "
      + "added to the pool of unlabeled training data, and as in the inductive "
      + "framework the proportion of incomplete labels in the training data is "
      + "increased at each point along the learning curve. "
      + "Testing is performed as usual on the test set for that split after "
      + "training. ";
    return description;
  }

  /**
   * Stores the dataset that results will be obtained for.
   *
   * @param instances the dataset to evaluate on
   */
  public void setInstances(Instances instances) {
    this.m_Instances = instances;
  }

  /**
   * Stores the object that the results of each run are sent to.
   *
   * @param listener the listener receiving result keys and results
   */
  public void setResultListener(ResultListener listener) {
    this.m_ResultListener = listener;
  }

  /**
   * Set a list of method names for additional measures to look for
   * in SplitEvaluators. This could contain many measures (of which only a
   * subset may be produceable by the current SplitEvaluator) if an experiment
   * is the type that iterates over a set of properties.
* @param additionalMeasures an array of measure names, null if none */ public void setAdditionalMeasures(String [] additionalMeasures) { m_AdditionalMeasures = additionalMeasures; if (m_SplitEvaluator != null) { System.err.println(" SemiSupIncompleteLabelCurveCVResultProducer: setting additional " +"measures for " +"split evaluator"); m_SplitEvaluator.setAdditionalMeasures(m_AdditionalMeasures); } } /** * Returns an enumeration of any additional measure names that might be * in the SplitEvaluator * @return an enumeration of the measure names */ public Enumeration enumerateMeasures() { Vector newVector = new Vector(); if (m_SplitEvaluator instanceof AdditionalMeasureProducer) { Enumeration en = ((AdditionalMeasureProducer)m_SplitEvaluator). enumerateMeasures(); while (en.hasMoreElements()) { String mname = (String)en.nextElement(); newVector.addElement(mname); } } return newVector.elements(); } /** * Returns the value of the named measure * @param measureName the name of the measure to query for its value * @return the value of the named measure * @exception IllegalArgumentException if the named measure is not supported */ public double getMeasure(String additionalMeasureName) { if (m_SplitEvaluator instanceof AdditionalMeasureProducer) { return ((AdditionalMeasureProducer)m_SplitEvaluator). getMeasure(additionalMeasureName); } else { throw new IllegalArgumentException(" SemiSupIncompleteLabelCurveCVResultProducer: " +"Can't return value for : "+additionalMeasureName +". "+m_SplitEvaluator.getClass().getName()+" " +"is not an AdditionalMeasureProducer"); } } /** * Gets a Double representing the current date and time. 
* eg: 1:46pm on 20/5/1999 -> 19990520.1346
   * The calendar fields are read in the UTC time zone.
   *
   * @return a value of type Double
   */
  public static Double getTimestamp() {
    Calendar utcNow = Calendar.getInstance(TimeZone.getTimeZone("UTC"));
    int year = utcNow.get(Calendar.YEAR);
    int month = utcNow.get(Calendar.MONTH) + 1; // Calendar months are zero-based
    int day = utcNow.get(Calendar.DAY_OF_MONTH);
    int hour = utcNow.get(Calendar.HOUR_OF_DAY);
    int minute = utcNow.get(Calendar.MINUTE);

    // yyyymmdd forms the integer part, hhmm the fractional part.
    int datePart = year * 10000 + month * 100 + day;
    double stamp = datePart + hour / 100.0 + minute / 10000.0;
    return new Double(stamp);
  }

  /**
   * Checks that a SplitEvaluator and a ResultListener are available and
   * then lets the ResultListener run its own preprocessing.
   *
   * @exception Exception if either collaborator is missing or the listener's
   * preprocessing fails.
   */
  public void preProcess() throws Exception {
    final boolean evaluatorMissing = (m_SplitEvaluator == null);
    if (evaluatorMissing) {
      throw new Exception("No SplitEvalutor set");
    }
    final boolean listenerMissing = (m_ResultListener == null);
    if (listenerMissing) {
      throw new Exception("No ResultListener set");
    }
    m_ResultListener.preProcess(this);
  }

  /**
   * Perform any postprocessing. When this method is called, it indicates
   * that no more requests to generate results for the current experiment
   * will be sent. The listener is notified first; afterwards, if debug
   * output is enabled and an output zipper is open, it is finished and
   * released.
   *
   * @exception Exception if an error occurs
   */
  public void postProcess() throws Exception {
    m_ResultListener.postProcess(this);

    if (m_debugOutput && m_ZipDest != null) {
      m_ZipDest.finished();
      m_ZipDest = null;
    }
  }

  /**
   * Gets the keys for a specified run number. Different run
   * numbers correspond to different randomizations of the data. Keys
   * produced should be sent to the current ResultListener
   *
   * @param run the run number to get keys for.
* @exception Exception if a problem occurs while getting the keys */ public void doRunKeys(int run) throws Exception { int numExtraKeys; if(m_IsFraction) numExtraKeys = 5; else numExtraKeys = 4; if (m_Instances == null) { throw new Exception("No Instances set"); } if (m_ResultListener == null) { throw new Exception("No ResultListener set"); } for (int fold = 0; fold < m_NumFolds; fold++) { int pointNum = 0; // For each subsample size if (m_PlotPoints != null) { m_NumMissingLabels = plotPoint(0); } else if (m_LowerSize == 0) { m_NumMissingLabels = m_StepSize; } else { m_NumMissingLabels = m_LowerSize; } while (m_NumMissingLabels < m_Instances.numClasses()) { // Add in some fields to the key like run and fold number, dataset name Object [] seKey = m_SplitEvaluator.getKey(); Object [] key = new Object [seKey.length + numExtraKeys]; key[0] = Utils.backQuoteChars(m_Instances.relationName()); key[1] = "" + run; key[2] = "" + (fold + 1); key[3] = "" + m_NumMissingLabels; if(m_IsFraction) key[4] = "" + m_PlotPoints[pointNum]; System.arraycopy(seKey, 0, key, numExtraKeys, seKey.length); if (m_ResultListener.isResultRequired(this, key)) { try { m_ResultListener.acceptResult(this, key, null); } catch (Exception ex) { // Save the train and test datasets for debugging purposes? throw ex; } } if (m_PlotPoints != null) { pointNum ++; m_NumMissingLabels = plotPoint(pointNum); } else { m_NumMissingLabels += m_StepSize; } } } } protected int maxTrainSize() { if (m_UpperSize == -1 || m_PlotPoints != null) return (int)(m_Instances.numInstances()*(1 - 1/((double)m_NumFolds))); else return m_UpperSize; } /** * Gets the results for a specified run number. Different run * numbers correspond to different randomizations of the data. Results * produced should be sent to the current ResultListener * * @param run the run number to get results for. 
* @exception Exception if a problem occurs while getting the results */ public void doRun(int run) throws Exception { int numExtraKeys; int numClasses = m_Instances.numClasses(); if(m_IsFraction) numExtraKeys = 5; else numExtraKeys = 4; if (getRawOutput()) { if (m_ZipDest == null) { m_ZipDest = new OutputZipper(m_OutputFile); } } if (m_Instances == null) { throw new Exception("No Instances set"); } if (m_ResultListener == null) { throw new Exception("No ResultListener set"); } // Randomize on a copy of the original dataset Instances runInstances = new Instances(m_Instances); runInstances.randomize(new Random(run)); if (runInstances.classAttribute().isNominal()) { runInstances.stratify(m_NumFolds); } for (int fold = 0; fold < m_NumFolds; fold++) { Instances train = runInstances.trainCV(m_NumFolds, fold); // Randomly shuffle stratified training set for fold train.randomize(new Random(fold)); Instances testSet = runInstances.testCV(m_NumFolds, fold); /* for (int i=0; i<train.numInstances(); i++) { System.out.println("Train instance has class: " + train.instance(i).classValue()); } for (int i=0; i<testSet.numInstances(); i++) { System.out.println("Test instance has class: " + testSet.instance(i).classValue()); } */ // For each subsample size int pointNum = 0; // For each subsample size if (m_PlotPoints != null) { m_NumMissingLabels = plotPoint(0); } else if (m_LowerSize == 0) { m_NumMissingLabels = m_StepSize; } else { m_NumMissingLabels = m_LowerSize; } while (m_NumMissingLabels < numClasses) { // Add in some fields to the key like run and fold number, dataset name Object [] seKey = m_SplitEvaluator.getKey(); Object [] key = new Object [seKey.length + numExtraKeys]; key[0] = Utils.backQuoteChars(m_Instances.relationName()); key[1] = "" + run; key[2] = "" + (fold + 1); key[3] = "" + m_NumMissingLabels; if(m_IsFraction) key[4] = "" + m_PlotPoints[pointNum]; System.arraycopy(seKey, 0, key, numExtraKeys, seKey.length); if (m_ResultListener.isResultRequired(this, key)) { 
try { if(m_IsFraction) System.out.println("Run:" + run + " Fold:" + fold + " Size:" + m_CurrentSize + " Fraction:" + m_PlotPoints[pointNum] + " NumMissingLablels:" + m_NumMissingLabels);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -