📄 averagingresultproducer.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * AveragingResultProducer.java * Copyright (C) 1999 Len Trigg * */package weka.experiment;import weka.core.AdditionalMeasureProducer;import weka.core.FastVector;import weka.core.Instances;import weka.core.Option;import weka.core.OptionHandler;import weka.core.Utils;import java.util.Enumeration;import java.util.Hashtable;import java.util.Vector;/** <!-- globalinfo-start --> * Takes the results from a ResultProducer and submits the average to the result listener. Normally used with a CrossValidationResultProducer to perform n x m fold cross validation. For non-numeric result fields, the first value is used. * <p/> <!-- globalinfo-end --> * <!-- options-start --> * Valid options are: <p/> * * <pre> -F <field name> * The name of the field to average over. * (default "Fold")</pre> * * <pre> -X <num results> * The number of results expected per average. * (default 10)</pre> * * <pre> -S * Calculate standard deviations. * (default only averages)</pre> * * <pre> -W <class name> * The full class name of a ResultProducer. * eg: weka.experiment.CrossValidationResultProducer</pre> * * <pre> * Options specific to result producer weka.experiment.CrossValidationResultProducer: * </pre> * * <pre> -X <number of folds> * The number of folds to use for the cross-validation. 
* (default 10)</pre>
 *
 * <pre> -D
 * Save raw split evaluator output.</pre>
 *
 * <pre> -O <file/directory name/path>
 * The filename where raw output will be stored.
 * If a directory name is specified then individual
 * outputs will be gzipped, otherwise all output will be
 * zipped to the named file. Use in conjunction with -D. (default splitEvalutorOut.zip)</pre>
 *
 * <pre> -W <class name>
 * The full class name of a SplitEvaluator.
 * eg: weka.experiment.ClassifierSplitEvaluator</pre>
 *
 * <pre>
 * Options specific to split evaluator weka.experiment.ClassifierSplitEvaluator:
 * </pre>
 *
 * <pre> -W <class name>
 * The full class name of the classifier.
 * eg: weka.classifiers.bayes.NaiveBayes</pre>
 *
 * <pre> -C <index>
 * The index of the class for which IR statistics
 * are to be output. (default 1)</pre>
 *
 * <pre> -I <index>
 * The index of an attribute to output in the
 * results. This attribute should identify an
 * instance in order to know which instances are
 * in the test set of a cross validation. if 0
 * no output (default 0).</pre>
 *
 * <pre> -P
 * Add target and prediction columns to the result
 * for each fold.</pre>
 *
 * <pre>
 * Options specific to classifier weka.classifiers.rules.ZeroR:
 * </pre>
 *
 * <pre> -D
 * If set, classifier is run in debug mode and
 * may output additional info to the console</pre>
 * <!-- options-end -->
 *
 * All options after -- will be passed to the result producer.
* * @author Len Trigg (trigg@cs.waikato.ac.nz) * @version $Revision: 1.16 $ */public class AveragingResultProducer implements ResultListener, ResultProducer, OptionHandler, AdditionalMeasureProducer { /** for serialization */ static final long serialVersionUID = 2551284958501991352L; /** The dataset of interest */ protected Instances m_Instances; /** The ResultListener to send results to */ protected ResultListener m_ResultListener = new CSVResultListener(); /** The ResultProducer used to generate results */ protected ResultProducer m_ResultProducer = new CrossValidationResultProducer(); /** The names of any additional measures to look for in SplitEvaluators */ protected String [] m_AdditionalMeasures = null; /** The number of results expected to average over for each run */ protected int m_ExpectedResultsPerAverage = 10; /** True if standard deviation fields should be produced */ protected boolean m_CalculateStdDevs; /** * The name of the field that will contain the number of results * averaged over. */ protected String m_CountFieldName = "Num_" + CrossValidationResultProducer .FOLD_FIELD_NAME; /** The name of the key field to average over */ protected String m_KeyFieldName = CrossValidationResultProducer .FOLD_FIELD_NAME; /** The index of the field to average over in the resultproducers key */ protected int m_KeyIndex = -1; /** Collects the keys from a single run */ protected FastVector m_Keys = new FastVector(); /** Collects the results from a single run */ protected FastVector m_Results = new FastVector(); /** * Returns a string describing this result producer * @return a description of the result producer suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Takes the results from a ResultProducer " +"and submits the average to the result listener. Normally used with " +"a CrossValidationResultProducer to perform n x m fold cross " +"validation. 
For non-numeric result fields, the first value is used."; } /** * Scans through the key field names of the result producer to find * the index of the key field to average over. Sets the value of * m_KeyIndex to the index, or -1 if no matching key field was found. * * @return the index of the key field to average over */ protected int findKeyIndex() { m_KeyIndex = -1; try { if (m_ResultProducer != null) { String [] keyNames = m_ResultProducer.getKeyNames(); for (int i = 0; i < keyNames.length; i++) { if (keyNames[i].equals(m_KeyFieldName)) { m_KeyIndex = i; break; } } } } catch (Exception ex) { } return m_KeyIndex; } /** * Determines if there are any constraints (imposed by the * destination) on the result columns to be produced by * resultProducers. Null should be returned if there are NO * constraints, otherwise a list of column names should be * returned as an array of Strings. * @param rp the ResultProducer to which the constraints will apply * @return an array of column names to which resutltProducer's * results will be restricted. * @throws Exception if constraints can't be determined */ public String [] determineColumnConstraints(ResultProducer rp) throws Exception { return null; } /** * Simulates a run to collect the keys the sub-resultproducer could * generate. Does some checking on the keys and determines the * template key. 
* * @param run the run number * @return a template key (null for the field being averaged) * @throws Exception if an error occurs */ protected Object [] determineTemplate(int run) throws Exception { if (m_Instances == null) { throw new Exception("No Instances set"); } m_ResultProducer.setInstances(m_Instances); // Clear the collected results m_Keys.removeAllElements(); m_Results.removeAllElements(); m_ResultProducer.doRunKeys(run); checkForMultipleDifferences(); Object [] template = (Object [])((Object [])m_Keys.elementAt(0)).clone(); template[m_KeyIndex] = null; // Check for duplicate keys checkForDuplicateKeys(template); return template; } /** * Gets the keys for a specified run number. Different run * numbers correspond to different randomizations of the data. Keys * produced should be sent to the current ResultListener * * @param run the run number to get keys for. * @throws Exception if a problem occurs while getting the keys */ public void doRunKeys(int run) throws Exception { // Generate the template Object [] template = determineTemplate(run); String [] newKey = new String [template.length - 1]; System.arraycopy(template, 0, newKey, 0, m_KeyIndex); System.arraycopy(template, m_KeyIndex + 1, newKey, m_KeyIndex, template.length - m_KeyIndex - 1); m_ResultListener.acceptResult(this, newKey, null); } /** * Gets the results for a specified run number. Different run * numbers correspond to different randomizations of the data. Results * produced should be sent to the current ResultListener * * @param run the run number to get results for. 
* @throws Exception if a problem occurs while getting the results */ public void doRun(int run) throws Exception { // Generate the key and ask whether the result is required Object [] template = determineTemplate(run); String [] newKey = new String [template.length - 1]; System.arraycopy(template, 0, newKey, 0, m_KeyIndex); System.arraycopy(template, m_KeyIndex + 1, newKey, m_KeyIndex, template.length - m_KeyIndex - 1); if (m_ResultListener.isResultRequired(this, newKey)) { // Clear the collected keys m_Keys.removeAllElements(); m_Results.removeAllElements(); m_ResultProducer.doRun(run); // Average the results collected //System.err.println("Number of results collected: " + m_Keys.size()); // Check that the keys only differ on the selected key field checkForMultipleDifferences(); template = (Object [])((Object [])m_Keys.elementAt(0)).clone(); template[m_KeyIndex] = null; // Check for duplicate keys checkForDuplicateKeys(template); // Calculate the average and submit it if necessary doAverageResult(template); } } /** * Compares a key to a template to see whether they match. Null * fields in the template are ignored in the matching. * * @param template the template to match against * @param test the key to test * @return true if the test key matches the template on all non-null template * fields */ protected boolean matchesTemplate(Object [] template, Object [] test) { if (template.length != test.length) { return false; } for (int i = 0; i < test.length; i++) { if ((template[i] != null) && (!template[i].equals(test[i]))) { return false; } } return true; } /** * Asks the resultlistener whether an average result is required, and * if so, calculates it. 
* * @param template the template to match keys against when calculating the * average * @throws Exception if an error occurs */ protected void doAverageResult(Object [] template) throws Exception { // Generate the key and ask whether the result is required String [] newKey = new String [template.length - 1]; System.arraycopy(template, 0, newKey, 0, m_KeyIndex); System.arraycopy(template, m_KeyIndex + 1, newKey, m_KeyIndex, template.length - m_KeyIndex - 1); if (m_ResultListener.isResultRequired(this, newKey)) { Object [] resultTypes = m_ResultProducer.getResultTypes(); Stats [] stats = new Stats [resultTypes.length]; for (int i = 0; i < stats.length; i++) { stats[i] = new Stats(); } Object [] result = getResultTypes(); int numMatches = 0; for (int i = 0; i < m_Keys.size(); i++) { Object [] currentKey = (Object [])m_Keys.elementAt(i); // Skip non-matching keys if (!matchesTemplate(template, currentKey)) { continue; } // Add the results to the stats accumulator Object [] currentResult = (Object [])m_Results.elementAt(i); numMatches++; for (int j = 0; j < resultTypes.length; j++) { if (resultTypes[j] instanceof Double) { if (currentResult[j] == null) { // set the stats object for this result to null--- // more than likely this is an additional measure field // not supported by the low level split evaluator if (stats[j] != null) { stats[j] = null; } /* throw new Exception("Null numeric result field found:\n" + DatabaseUtils.arrayToString(currentKey) + " -- " + DatabaseUtils .arrayToString(currentResult)); */ } if (stats[j] != null) { double currentVal = ((Double)currentResult[j]).doubleValue(); stats[j].add(currentVal); }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -