noisecurvecrossvalidationresultproducer.java
来自「wekaUT是 university texas austin 开发的基于wek」· Java 代码 · 共 1,346 行 · 第 1/3 页
JAVA
1,346 行
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * NoiseCurveCrossValidationResultProducer.java * * Project * CS 391L Machine Learning * * Nishit Shah (nishit@cs.utexas.edu) * * *//** * Notes: * Also Read attached README file * Noise will be input as PERCENT ONLY (eg: 10 20 30), No fractions supported * The grapher needs integer values of X axis in the learning curve * We Take Full Dataset for the runs * Use the 4th Key of Fraction as the Noise_Key -- Used Key_Noise_level * When we add Noise to a Feature, it does not include the Class attribute */package weka.experiment;import java.util.*;import java.io.*;import weka.experiment.*;import weka.core.*;import javax.swing.JComboBox;import javax.swing.ComboBoxModel;import javax.swing.DefaultComboBoxModel;/** * Does a N-fold cross-validation, but generates a Noise Curve * by also varying the number amount of Noise. Always uses the * same N-fold test set for testing. * * * @@author Raymond J. 
Mooney (mooney@cs.utexas.edu)
 * Changed to plot Noise Curves
 */
public class NoiseCurveCrossValidationResultProducer
  implements ResultProducer, OptionHandler, AdditionalMeasureProducer {

  /** The dataset of interest */
  protected Instances m_Instances;

  /** The ResultListener to send results to */
  protected ResultListener m_ResultListener = new CSVResultListener();

  /** The number of folds in the cross-validation */
  protected int m_NumFolds = 10;

  /** Save raw output of split evaluators --- for debugging purposes */
  protected boolean m_debugOutput = false;

  /** Add noise to Class Labels in Training Set */
  protected boolean m_classNoise = true;

  /** Add noise to Features, do not include Class as a Feature in Training Set */
  protected boolean m_featureNoise = true;

  /** Set features missing, do not include Class as a Feature in Training Set */
  protected boolean m_featureMiss = true;

  /** Add noise to Class Labels in Testing Set */
  protected boolean m_classNoiseTest = true;

  /** Add noise to Features, do not include Class as a Feature in Testing Set */
  protected boolean m_featureNoiseTest = true;

  /** Set features missing, do not include Class as a Feature in Testing Set */
  protected boolean m_featureMissTest = true;

  /** The output zipper to use for saving raw splitEvaluator output */
  protected OutputZipper m_ZipDest = null;

  /** The destination output file/directory for raw output */
  protected File m_OutputFile = new File(
    new File(System.getProperty("user.dir")), "splitEvalutorOut.zip");

  /** The SplitEvaluator used to generate results */
  protected SplitEvaluator m_SplitEvaluator = new ClassifierSplitEvaluator();

  /** The names of any additional measures to look for in SplitEvaluators */
  protected String [] m_AdditionalMeasures = null;

  /** Store Statistics of Attributes (per-attribute double[] entries, filled in doRun) */
  protected Vector m_AttributeStats = null;

  /** The specific points to plot, either integers representing specific
   * numbers of training examples, or decimal fractions representing
   * percentages of the full training set -- ONLY INTEGERS SUPPORTED */
  protected double[] m_PlotPoints;

  /** Dataset size for the runs, we take the full dataset */
  protected int m_CurrentSize = 0;

  /** Random Number, used for randomization in each run */
  protected Random m_Random = new Random(0);

  /* The name of the key field containing the dataset name */
  public static String DATASET_FIELD_NAME = "Dataset";

  /* The name of the key field containing the run number */
  public static String RUN_FIELD_NAME = "Run";

  /* The name of the key field containing the fold number */
  public static String FOLD_FIELD_NAME = "Fold";

  /* The name of the result field containing the timestamp */
  public static String TIMESTAMP_FIELD_NAME = "Date_time";

  /* The name of the key field containing the learning rate step number */
  public static String STEP_FIELD_NAME = "Total_instances";

  /* The name of the key field containing the fraction of total instances used */
  public static String NOISE_FIELD_NAME = "Noise_levels";

  /**
   * Returns a string describing this result producer
   * @return a description of the result producer suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Performs a noise-curve cross validation run using a supplied "
      +"split evaluator. Trains on different amounts of noise in the Dataset, "
      +"repeatedly testing on the test set for that split after training.";
  }

  /**
   * Sets the dataset that results will be obtained for.
   *
   * @param instances a value of type 'Instances'.
   */
  public void setInstances(Instances instances) {
    m_Instances = instances;
  }

  /**
   * Sets the object to send results of each run to.
   *
   * @param listener a value of type 'ResultListener'
   */
  public void setResultListener(ResultListener listener) {
    m_ResultListener = listener;
  }

  /**
   * Set a list of method names for additional measures to look for
   * in SplitEvaluators.
This could contain many measures (of which only a
   * subset may be producible by the current SplitEvaluator) if an experiment
   * is the type that iterates over a set of properties.
   * @param additionalMeasures an array of measure names, null if none
   */
  public void setAdditionalMeasures(String [] additionalMeasures) {
    m_AdditionalMeasures = additionalMeasures;
    if (m_SplitEvaluator != null) {
      System.err.println("NoiseCurveCrossValidationResultProducer: setting additional "
        +"measures for "
        +"split evaluator");
      m_SplitEvaluator.setAdditionalMeasures(m_AdditionalMeasures);
    }
  }

  /**
   * Returns an enumeration of any additional measure names that might be
   * in the SplitEvaluator
   * @return an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {
    Vector newVector = new Vector();
    if (m_SplitEvaluator instanceof AdditionalMeasureProducer) {
      Enumeration en = ((AdditionalMeasureProducer)m_SplitEvaluator).
        enumerateMeasures();
      while (en.hasMoreElements()) {
        String mname = (String)en.nextElement();
        newVector.addElement(mname);
      }
    }
    return newVector.elements();
  }

  /**
   * Returns the value of the named measure, delegating to the current
   * SplitEvaluator.
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @exception IllegalArgumentException if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {
    if (m_SplitEvaluator instanceof AdditionalMeasureProducer) {
      return ((AdditionalMeasureProducer)m_SplitEvaluator).
        getMeasure(additionalMeasureName);
    } else {
      throw new IllegalArgumentException("NoiseCurveCrossValidationResultProducer: "
        +"Can't return value for : "+additionalMeasureName
        +". "+m_SplitEvaluator.getClass().getName()+" "
        +"is not an AdditionalMeasureProducer");
    }
  }

  /**
   * Gets a Double representing the current date and time.
* eg: 1:46pm on 20/5/1999 -> 19990520.1346
   *
   * @return a value of type Double
   */
  public static Double getTimestamp() {
    Calendar now = Calendar.getInstance(TimeZone.getTimeZone("UTC"));
    // Encode date in the integer part and hour/minute in the decimal part.
    double timestamp = now.get(Calendar.YEAR) * 10000
      + (now.get(Calendar.MONTH) + 1) * 100
      + now.get(Calendar.DAY_OF_MONTH)
      + now.get(Calendar.HOUR_OF_DAY) / 100.0
      + now.get(Calendar.MINUTE) / 10000.0;
    return new Double(timestamp);
  }

  /**
   * Prepare to generate results.
   *
   * @exception Exception if an error occurs during preprocessing.
   */
  public void preProcess() throws Exception {
    if (m_SplitEvaluator == null) {
      throw new Exception("No SplitEvalutor set");
    }
    if (m_ResultListener == null) {
      throw new Exception("No ResultListener set");
    }
    m_ResultListener.preProcess(this);
  }

  /**
   * Perform any postprocessing. When this method is called, it indicates
   * that no more requests to generate results for the current experiment
   * will be sent.
   *
   * @exception Exception if an error occurs
   */
  public void postProcess() throws Exception {
    m_ResultListener.postProcess(this);
    if (m_debugOutput) {
      if (m_ZipDest != null) {
        m_ZipDest.finished();
        m_ZipDest = null;
      }
    }
  }

  /**
   * Gets the keys for a specified run number. Different run
   * numbers correspond to different randomizations of the data. Keys
   * produced should be sent to the current ResultListener
   *
   * @param run the run number to get keys for.
* @exception Exception if a problem occurs while getting the keys
   */
  public void doRunKeys(int run) throws Exception {
    // Extra key fields prepended before the SplitEvaluator's own key:
    // dataset name, run, fold, training-set size, noise level.
    int numExtraKeys;
    numExtraKeys = 5;
    if (m_Instances == null) {
      throw new Exception("No Instances set");
    }
    if (m_ResultListener == null) {
      throw new Exception("No ResultListener set");
    }
    if (m_PlotPoints == null) {
      throw new Exception("Enter atleast one point on Noise Curve");
    }
    for(int noiseLevel = 0; noiseLevel < m_PlotPoints.length; noiseLevel++) {
      for (int fold = 0; fold < m_NumFolds; fold++) {
        m_CurrentSize = maxTrainSize();
        // Add in some fields to the key like run and fold number, dataset name
        Object [] seKey = m_SplitEvaluator.getKey();
        Object [] key = new Object [seKey.length + numExtraKeys];
        key[0] = Utils.backQuoteChars(m_Instances.relationName());
        key[1] = "" + run;
        key[2] = "" + (fold + 1);
        key[3] = "" + m_CurrentSize;
        key[4] = "" + (int) m_PlotPoints[noiseLevel]; //Converting to Integer for Grapher
        System.arraycopy(seKey, 0, key, numExtraKeys, seKey.length);
        if (m_ResultListener.isResultRequired(this, key)) {
          try {
            m_ResultListener.acceptResult(this, key, null);
          } catch (Exception ex) {
            // Save the train and test datasets for debugging purposes?
            throw ex;
          }
        }
      }//for each fold
    }//for each noise level
  }

  /**
   * Get the maximum size of the training set based on
   * maximum training set size from the n-fold CV
   */
  protected int maxTrainSize() {
    return (int)(m_Instances.numInstances()*(1 - 1/((double)m_NumFolds)));
  }

  /**
   * Gets the results for a specified run number. Different run
   * numbers correspond to different randomizations of the data. Results
   * produced should be sent to the current ResultListener
   *
   * @param run the run number to get results for.
* @@exception Exception if a problem occurs while getting the results */ public void doRun(int run) throws Exception { int numExtraKeys; numExtraKeys = 5; if (getRawOutput()) { if (m_ZipDest == null) { m_ZipDest = new OutputZipper(m_OutputFile); } } if (m_Instances == null) { throw new Exception("No Instances set"); } if (m_ResultListener == null) { throw new Exception("No ResultListener set"); } // Check if PlotPoint is Null if (m_PlotPoints == null) { throw new Exception("Enter atleast one point on Noise Curve"); } m_AttributeStats = new Vector(m_Instances.numAttributes()); //Storing both Nominal and Numeric attributes, we use Numeric values for finding Mean and Variance for (int i = 0; i<m_Instances.numAttributes(); i++) { if(m_Instances.attribute(i).isNominal()){ int []nomCounts = (m_Instances.attributeStats(i)).nominalCounts; double []counts = new double[nomCounts.length]; double []stats = new double[counts.length - 1]; stats[0] = counts[0]; //Calculate cumulative probabilities for(int j=1; j<stats.length; j++) stats[j] = stats[j-1] + counts[j]; m_AttributeStats.add(i,stats); } if(m_Instances.attribute(i).isNumeric()) { double []stats = new double[2]; stats[0] = m_Instances.meanOrMode(i); stats[1] = Math.sqrt(m_Instances.variance(i)); m_AttributeStats.add(i, stats); } } //Initialize Random Number, The experiment will be repeatable for the same run number m_Random = new Random(run); // Randomize on a copy of the original dataset Instances runInstances = new Instances(m_Instances); runInstances.randomize(new Random(run)); if (runInstances.classAttribute().isNominal()) { runInstances.stratify(m_NumFolds); } //For Each Noise Level for (int noiseLevel = 0; noiseLevel < m_PlotPoints.length; noiseLevel++) { System.out.println("\n\nRun : " + run + " Number of Noise Levels : " + m_PlotPoints.length + " Noise Level : " + m_PlotPoints[noiseLevel] + "\n"); for (int fold = 0; fold < m_NumFolds; fold++) { Instances train = runInstances.trainCV(m_NumFolds, fold); // Randomly 
shuffle stratified training set for fold: added by Sugato train.randomize(new Random(fold)); Instances test = runInstances.testCV(m_NumFolds, fold); if (m_classNoise == true){ addClassNoise(train, test, noiseLevel); //System.out.println("Hi"); } // Check m_featureNoise, if true call addFeatureNoise(train, test) if (m_featureNoise == true){ addFeatureNoise(train, test, noiseLevel); //System.out.println("Hi"); } // Check m_featureMiss, if true call addFeatureMiss(train, test) if (m_featureMiss == true){ addFeatureMiss(train, test, noiseLevel); //System.out.println("Hi"); } m_CurrentSize = maxTrainSize(); // Add in some fields to the key like run and fold number, dataset name Object [] seKey = m_SplitEvaluator.getKey(); Object [] key = new Object [seKey.length + numExtraKeys]; key[0] = Utils.backQuoteChars(m_Instances.relationName()); key[1] = "" + run; key[2] = "" + (fold + 1); key[3] = "" + m_CurrentSize; key[4] = "" + (int)m_PlotPoints[noiseLevel]; System.arraycopy(seKey, 0, key, numExtraKeys, seKey.length); if (m_ResultListener.isResultRequired(this, key)) { try { System.out.println("Run:" + run + " Fold:" + fold + " Size:" + m_CurrentSize + " Noise Level:" + m_PlotPoints[noiseLevel]); Instances trainSubset = new Instances(train, 0, m_CurrentSize); Object [] seResults = m_SplitEvaluator.getResult(trainSubset, test); Object [] results = new Object [seResults.length + 1]; results[0] = getTimestamp(); System.arraycopy(seResults, 0, results, 1, seResults.length); if (m_debugOutput) { String resultName = (""+run+"."+(fold+1)+"."+ m_CurrentSize + "." + Utils.backQuoteChars(runInstances.relationName()) +"." 
+m_SplitEvaluator.toString()).replace(' ','_'); resultName = Utils.removeSubstring(resultName, "weka.classifiers."); resultName = Utils.removeSubstring(resultName, "weka.filters."); resultName = Utils.removeSubstring(resultName, "weka.attributeSelection."); m_ZipDest.zipit(m_SplitEvaluator.getRawResultOutput(), resultName); } m_ResultListener.acceptResult(this, key, results); } catch (Exception ex) { // Save the train and test datasets for debugging purposes? throw ex; } } }//Number of Folds }//For each Noise Level } /** Return the amount of noise for the ith point on the * curve for plotPoints as specified. Percent of NOISE Returned * Can Simplify this procedure to return m_PlotPoints[i] directly
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?