📄 dedupingprcurvecvresultproducersplit.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * DedupingPRCurveCVResultProducer.java * Copyright (c) 2003 Mikhail Bilenko * */package weka.experiment;import java.util.*;import java.io.*;import weka.core.*;import weka.core.OptionHandler;import weka.core.Option;import weka.core.Utils;import weka.core.AdditionalMeasureProducer;import weka.filters.Filter;import weka.filters.unsupervised.attribute.Remove;/** * N-fold cross-validation learning curve for * deduping applications * * @author Mikhail Bilenko */public class DedupingPRCurveCVResultProducerSplit implements ResultProducer, OptionHandler, AdditionalMeasureProducer { /** The dataset of interest */ protected Instances m_instances; /** SVM-light can work in classification, regression and preference ranking modes */ public static final int FOLD_CREATION_MODE_STRATIFIED = 1; public static final int FOLD_CREATION_MODE_RANDOM = 2; public static final Tag[] TAGS_FOLD_CREATION_MODE = { new Tag(FOLD_CREATION_MODE_STRATIFIED, "Stratified"), new Tag(FOLD_CREATION_MODE_RANDOM, "Random") }; protected int m_foldCreationMode = FOLD_CREATION_MODE_STRATIFIED; /** The ResultListener to send results to */ protected ResultListener m_resultListener = new CSVResultListener(); /** The number of folds in the cross-validation */ protected int m_numFolds = 2; /** Save raw output of split evaluators --- for debugging purposes */ protected boolean m_debugOutput = false; /** The output zipper to use for saving raw splitEvaluator output */ protected OutputZipper m_zipDest = null; /** The destination output file/directory for raw output */ protected File m_outputFile = new File( new File(System.getProperty("user.dir")), "splitEvalutorOut.zip"); /** The separate training file if desired */ protected String m_separateTrainingFile = new String(""); /** The SplitEvaluator used to generate results */ protected SplitEvaluator m_splitEvaluator = new DeduperSplitEvaluator(); /** The names of any additional measures to look for in SplitEvaluators */ protected String [] m_additionalMeasures = null; /** The specific points to plot, either integers representing specific numbers of training examples, * or decimal fractions representing percentages of the full training set*/ protected double[] m_plotPoints = {0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0}; /* The name of the key field containing the dataset name */ public static String DATASET_FIELD_NAME = "Dataset"; /* The name of the key field containing the run number */ public static String RUN_FIELD_NAME = "Run"; /* The name of the key field containing the fold number */ public static String FOLD_FIELD_NAME = "Fold"; /* The name of the result field containing the timestamp */ public static String TIMESTAMP_FIELD_NAME = "Date_time"; /* The name of the key field containing the learning rate step number */ public static String RECALL_FIELD_NAME = "Fraction_instances"; /** * Returns a string describing this result producer * @return a description of the result producer suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Performs a learning-curve cross validation run using a supplied " + "deduping split evaluator. Trains on increasing subsets of the training data for each split"; } /** * Sets the dataset that results will be obtained for. * * @param instances a value of type 'Instances'. */ public void setInstances(Instances instances) { m_instances = instances; } /** * Sets the object to send results of each run to. * * @param listener a value of type 'ResultListener' */ public void setResultListener(ResultListener listener) { m_resultListener = listener; } /** * Set a list of method names for additional measures to look for * in SplitEvaluators. This could contain many measures (of which only a * subset may be produceable by the current SplitEvaluator) if an experiment * is the type that iterates over a set of properties. * @param additionalMeasures an array of measure names, null if none */ public void setAdditionalMeasures(String [] additionalMeasures) { m_additionalMeasures = additionalMeasures; if (m_splitEvaluator != null) { System.err.println(" DedupingPRCurveCVResultProducerSplit: setting additional " +"measures for " +"split evaluator"); m_splitEvaluator.setAdditionalMeasures(m_additionalMeasures); } } /** * Returns an enumeration of any additional measure names that might be * in the SplitEvaluator * @return an enumeration of the measure names */ public Enumeration enumerateMeasures() { Vector newVector = new Vector(); if (m_splitEvaluator instanceof AdditionalMeasureProducer) { Enumeration en = ((AdditionalMeasureProducer)m_splitEvaluator). enumerateMeasures(); while (en.hasMoreElements()) { String mname = (String)en.nextElement(); newVector.addElement(mname); } } return newVector.elements(); } /** * Returns the value of the named measure * @param measureName the name of the measure to query for its value * @return the value of the named measure * @exception IllegalArgumentException if the named measure is not supported */ public double getMeasure(String additionalMeasureName) { if (m_splitEvaluator instanceof AdditionalMeasureProducer) { return ((AdditionalMeasureProducer)m_splitEvaluator). getMeasure(additionalMeasureName); } else { throw new IllegalArgumentException("DedupingPRCurveCVResultProducerSplit: " +"Can't return value for : "+additionalMeasureName +". "+m_splitEvaluator.getClass().getName()+" " +"is not an AdditionalMeasureProducer"); } } /** * Gets a Double representing the current date and time. * eg: 1:46pm on 20/5/1999 -> 19990520.1346 * * @return a value of type Double */ public static Double getTimestamp() { Calendar now = Calendar.getInstance(TimeZone.getTimeZone("UTC")); double timestamp = now.get(Calendar.YEAR) * 10000 + (now.get(Calendar.MONTH) + 1) * 100 + now.get(Calendar.DAY_OF_MONTH) + now.get(Calendar.HOUR_OF_DAY) / 100.0 + now.get(Calendar.MINUTE) / 10000.0; return new Double(timestamp); } /** * Prepare to generate results. * * @exception Exception if an error occurs during preprocessing. */ public void preProcess() throws Exception { if (m_splitEvaluator == null) { throw new Exception("No SplitEvalutor set"); } if (m_resultListener == null) { throw new Exception("No ResultListener set"); } m_resultListener.preProcess(this); } /** * Perform any postprocessing. When this method is called, it indicates * that no more requests to generate results for the current experiment * will be sent. * * @exception Exception if an error occurs */ public void postProcess() throws Exception { m_resultListener.postProcess(this); if (m_debugOutput) { if (m_zipDest != null) { m_zipDest.finished(); m_zipDest = null; } } } /** * Gets the keys for a specified run number. Different run * numbers correspond to different randomizations of the data. Keys * produced should be sent to the current ResultListener * * @param run the run number to get keys for. * @exception Exception if a problem occurs while getting the keys */ public void doRunKeys(int run) throws Exception { int numExtraKeys = 4; if (m_instances == null) { throw new Exception("No Instances set"); } if (m_resultListener == null) { throw new Exception("No ResultListener set"); } for (int fold = 0; fold < m_numFolds; fold++) { int pointNum = 0; // For each subsample size for (int i = 0; i < m_plotPoints.length; i++) { // Add in some fields to the key like run and fold number, dataset name Object [] seKey = m_splitEvaluator.getKey(); Object [] key = new Object [seKey.length + numExtraKeys]; key[0] = Utils.backQuoteChars(m_instances.relationName()); key[1] = "" + run; key[2] = "" + (fold + 1); key[3] = "" + m_plotPoints[i]; System.arraycopy(seKey, 0, key, numExtraKeys, seKey.length); if (m_resultListener.isResultRequired(this, key)) { try { m_resultListener.acceptResult(this, key, null); } catch (Exception ex) { // Save the train and test datasets for debugging purposes? throw ex; } } } } } /** * Gets the results for a specified run number. Different run * numbers correspond to different randomizations of the data. Results * produced should be sent to the current ResultListener * * @param run the run number to get results for. * @exception Exception if a problem occurs while getting the results */ public void doRun(int run) throws Exception { int numExtraKeys = 4; if (getRawOutput()) { if (m_zipDest == null) { m_zipDest = new OutputZipper(m_outputFile); } } if (m_instances == null) { throw new Exception("No Instances set"); } if (m_resultListener == null) { throw new Exception("No ResultListener set"); } if (!m_instances.classAttribute().isNominal()) { throw new Exception("Class attribute must be nominal - it is the true Object ID"); } // Randomize on a copy of the original dataset Instances runInstances = new Instances(m_instances); runInstances.randomize(new Random(run)); ArrayList foldList = createFoldList(runInstances, m_numFolds); // If a separate training file is used, create separate folds for it ArrayList sepFoldList = null; if (m_separateTrainingFile.length() > 0) { Instances trainInstances = new Instances(new BufferedReader(new FileReader(m_separateTrainingFile))); trainInstances.setClassIndex(trainInstances.numAttributes() - 1); trainInstances.randomize(new Random(run)); sepFoldList = createFoldList(trainInstances, m_numFolds); System.out.println("Got separate training file " + m_separateTrainingFile + " of " + trainInstances.numInstances() + " instances"); } for (int fold = 0; fold < m_numFolds; fold++) { Instances train = ((sepFoldList == null) ? getTrainingFold(foldList, fold) : getTrainingFold(sepFoldList, fold)); // Randomly shuffle the training set for fold creation train.randomize(new Random(fold)); Instances test = (Instances) foldList.get(fold); System.out.println("Run:" + run + " Fold:" + fold + " TestSize=" + test.numInstances()); Object[] prResults = m_splitEvaluator.getResult(train, test); for (int i = 0; i < m_plotPoints.length; i++) { // Add in some fields to the key like run and fold number, dataset name Object [] seKey = m_splitEvaluator.getKey(); Object [] key = new Object [seKey.length + numExtraKeys]; key[0] = Utils.backQuoteChars(m_instances.relationName()); key[1] = "" + run; key[2] = "" + (fold + 1); key[3] = "" + m_plotPoints[i]; System.arraycopy(seKey, 0, key, numExtraKeys, seKey.length); if (m_resultListener.isResultRequired(this, key)) { try { Object [] seResults = processResults(prResults, m_plotPoints[i]); System.out.println("Adding result: RLevel=" + m_plotPoints[i] + "\tR=" + seResults[1] + "\tP=" + seResults[2] + "\tFM=" + seResults[3]); Object [] results = new Object [seResults.length + 1]; results[0] = getTimestamp(); System.arraycopy(seResults, 0, results, 1, seResults.length); if (m_debugOutput) { String resultName = (""+run+"."+(fold+1)+"."+ "." + Utils.backQuoteChars(runInstances.relationName()) +"." +m_splitEvaluator.toString()).replace(' ','_'); resultName = Utils.removeSubstring(resultName,
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -