📄 crossvalidationresultproducer.java
字号:
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* CrossValidationResultProducer.java
* Copyright (C) 1999 Len Trigg
*
*/
package weka.experiment;
import java.io.File;
import java.util.Calendar;
import java.util.Enumeration;
import java.util.Random;
import java.util.TimeZone;
import java.util.Vector;
import weka.core.AdditionalMeasureProducer;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Utils;
/**
* For each run, carries out an n-fold cross-validation, using the
* configured SplitEvaluator to generate results. If the class
* attribute is nominal, the dataset is stratified. Results are generated
* for each fold, so you may wish to use this in conjunction with an
* AveragingResultProducer to obtain averages for each run.
*
* @author Len Trigg (trigg@cs.waikato.ac.nz)
* @version $Revision$
*/
public class CrossValidationResultProducer
implements ResultProducer, OptionHandler, AdditionalMeasureProducer {
/** The dataset of interest */
protected Instances m_Instances;
/** The ResultListener to send results to; initialised to a CSVResultListener */
protected ResultListener m_ResultListener = new CSVResultListener();
/** The number of folds in the cross-validation; initialised to 10 */
protected int m_NumFolds = 10;
/** Whether to save the raw output of split evaluators (for debugging purposes) */
protected boolean m_debugOutput = false;
/** The output zipper to use for saving raw splitEvaluator output; initially null */
protected OutputZipper m_ZipDest = null;
/**
 * The destination output file/directory for raw output. Initialised to
 * "splitEvalutorOut.zip" inside the directory named by the "user.dir"
 * system property.
 */
protected File m_OutputFile = new File(
new File(System.getProperty("user.dir")),
"splitEvalutorOut.zip");
/** The SplitEvaluator used to generate results; initialised to a ClassifierSplitEvaluator */
protected SplitEvaluator m_SplitEvaluator = new ClassifierSplitEvaluator();
/** The names of any additional measures to look for in SplitEvaluators; initially null */
protected String [] m_AdditionalMeasures = null;
/** The name of the key field containing the dataset name */
public static String DATASET_FIELD_NAME = "Dataset";
/** The name of the key field containing the run number */
public static String RUN_FIELD_NAME = "Run";
/** The name of the key field containing the fold number */
public static String FOLD_FIELD_NAME = "Fold";
/** The name of the result field containing the timestamp */
public static String TIMESTAMP_FIELD_NAME = "Date_time";
/**
 * Describes this result producer in a single sentence.
 *
 * @return a short description suitable for displaying in the
 * explorer/experimenter gui
 */
public String globalInfo() {
  final String description =
    "Performs a cross validation run using a supplied evaluator.";
  return description;
}
/**
 * Stores the dataset for which results will be produced.
 *
 * @param instances the dataset to generate results for
 */
public void setInstances(Instances instances) {
  this.m_Instances = instances;
}
/**
 * Stores the listener that the results of each run are sent to.
 *
 * @param listener the ResultListener that will receive the results
 */
public void setResultListener(ResultListener listener) {
  this.m_ResultListener = listener;
}
/**
 * Records the method names of additional measures to look for in
 * SplitEvaluators and, when a SplitEvaluator is set, passes the same list
 * on to it. The list may contain many measures (of which only a subset may
 * be produceable by the current SplitEvaluator) if an experiment is the
 * type that iterates over a set of properties.
 *
 * @param additionalMeasures an array of measure names, null if none
 */
public void setAdditionalMeasures(String [] additionalMeasures) {
  m_AdditionalMeasures = additionalMeasures;
  if (m_SplitEvaluator == null) {
    // Nothing to forward the measures to.
    return;
  }
  System.err.println(
    "CrossValidationResultProducer: setting additional measures for split evaluator");
  m_SplitEvaluator.setAdditionalMeasures(m_AdditionalMeasures);
}
/**
 * Lists the names of any additional measures that the SplitEvaluator can
 * supply. The enumeration is empty when the SplitEvaluator is not an
 * AdditionalMeasureProducer.
 *
 * NOTE(review): the method name looks like a misspelling of
 * "enumerateMeasures"; it is kept because the call made on
 * AdditionalMeasureProducer below uses the same spelling -- confirm against
 * that interface.
 *
 * @return an enumeration of the measure names
 */
public Enumeration emerateMeasures() {
  Vector measureNames = new Vector();
  if (!(m_SplitEvaluator instanceof AdditionalMeasureProducer)) {
    return measureNames.elements();
  }
  AdditionalMeasureProducer producer =
    (AdditionalMeasureProducer) m_SplitEvaluator;
  for (Enumeration names = producer.emerateMeasures();
       names.hasMoreElements(); ) {
    // The cast keeps the requirement that every element is a String.
    measureNames.addElement((String) names.nextElement());
  }
  return measureNames.elements();
}
/**
 * Returns the value of the named measure, as reported by the
 * SplitEvaluator.
 *
 * @param additionalMeasureName the name of the measure to query for its
 * value
 * @return the value of the named measure
 * @exception IllegalArgumentException if the SplitEvaluator is not set or
 * is not an AdditionalMeasureProducer, so the measure cannot be supplied
 */
public double getMeasure(String additionalMeasureName) {
  if (m_SplitEvaluator instanceof AdditionalMeasureProducer) {
    return ((AdditionalMeasureProducer) m_SplitEvaluator)
      .getMeasure(additionalMeasureName);
  }
  // instanceof is false for null, so the evaluator may be null here; guard
  // before calling getClass() so callers still get the documented
  // IllegalArgumentException rather than a NullPointerException.
  String evaluatorName = (m_SplitEvaluator == null)
    ? "null (no SplitEvaluator set)"
    : m_SplitEvaluator.getClass().getName();
  throw new IllegalArgumentException("CrossValidationResultProducer: "
    + "Can't return value for : " + additionalMeasureName
    + ". " + evaluatorName + " "
    + "is not an AdditionalMeasureProducer");
}
/**
 * Gets a Double representing the current date and time in UTC, with the
 * date in the integer part and the time of day in the fraction
 * (yyyyMMdd.HHmm).
 * eg: 1:46pm on 20/5/1999 -> 19990520.1346
 *
 * @return the encoded timestamp as a Double
 */
public static Double getTimestamp() {
  Calendar now = Calendar.getInstance(TimeZone.getTimeZone("UTC"));
  double timestamp = now.get(Calendar.YEAR) * 10000
    + (now.get(Calendar.MONTH) + 1) * 100 // Calendar.MONTH is zero-based
    + now.get(Calendar.DAY_OF_MONTH)
    + now.get(Calendar.HOUR_OF_DAY) / 100.0
    + now.get(Calendar.MINUTE) / 10000.0;
  // Double.valueOf replaces the deprecated Double(double) constructor.
  return Double.valueOf(timestamp);
}
/**
 * Checks that a SplitEvaluator and a ResultListener are set, then tells
 * the ResultListener to get ready for results from this producer.
 *
 * @exception Exception if either object is missing or the listener's own
 * preprocessing fails.
 */
public void preProcess() throws Exception {
  String problem = null;
  if (m_SplitEvaluator == null) {
    problem = "No SplitEvalutor set";
  } else if (m_ResultListener == null) {
    problem = "No ResultListener set";
  }
  if (problem != null) {
    throw new Exception(problem);
  }
  m_ResultListener.preProcess(this);
}
/**
 * Perform any postprocessing. When this method is called, it indicates
 * that no more requests to generate results for the current experiment
 * will be sent. The ResultListener is notified first; when raw output is
 * being saved, the output zipper is then finished and released, even if
 * the listener's postprocessing fails.
 *
 * @exception Exception if an error occurs in the listener or while
 * finishing the zipper. If both fail, the zipper's exception is the one
 * that propagates.
 */
public void postProcess() throws Exception {
  try {
    m_ResultListener.postProcess(this);
  } finally {
    // Runs on the failure path too, so an exception from the listener no
    // longer leaves the zipper unfinished.
    if (m_debugOutput && m_ZipDest != null) {
      m_ZipDest.finished();
      m_ZipDest = null;
    }
  }
}
/**
 * Sends the key of every fold of the given run to the current
 * ResultListener, with null in place of the results. Different run numbers
 * correspond to different randomizations of the data. No evaluation is
 * carried out here.
 *
 * @param run the run number to get keys for.
 * @exception Exception if no dataset is set or a problem occurs while
 * getting the keys
 */
public void doRunKeys(int run) throws Exception {
  if (m_Instances == null) {
    throw new Exception("No Instances set");
  }
  for (int foldIndex = 0; foldIndex < m_NumFolds; foldIndex++) {
    // Key layout: dataset name, run, 1-based fold, then the evaluator's key.
    Object [] evaluatorKey = m_SplitEvaluator.getKey();
    Object [] fullKey = new Object [3 + evaluatorKey.length];
    fullKey[0] = Utils.backQuoteChars(m_Instances.relationName());
    fullKey[1] = String.valueOf(run);
    fullKey[2] = String.valueOf(foldIndex + 1);
    System.arraycopy(evaluatorKey, 0, fullKey, 3, evaluatorKey.length);
    if (!m_ResultListener.isResultRequired(this, fullKey)) {
      continue;
    }
    m_ResultListener.acceptResult(this, fullKey, null);
  }
}
/**
 * Produces the results for every fold of the given run and sends them to
 * the current ResultListener. Different run numbers correspond to
 * different randomizations of the data: the run number seeds the random
 * generator used to shuffle a copy of the dataset. Folds that the listener
 * does not require are skipped.
 *
 * @param run the run number to get results for.
 * @exception Exception if no dataset is set or a problem occurs while
 * getting the results
 */
public void doRun(int run) throws Exception {
  if (getRawOutput() && m_ZipDest == null) {
    m_ZipDest = new OutputZipper(m_OutputFile);
  }
  if (m_Instances == null) {
    throw new Exception("No Instances set");
  }
  // Shuffle a copy so that the original dataset is left as it was.
  Instances shuffled = new Instances(m_Instances);
  Random rng = new Random(run);
  shuffled.randomize(rng);
  if (shuffled.classAttribute().isNominal()) {
    shuffled.stratify(m_NumFolds);
  }
  for (int foldIndex = 0; foldIndex < m_NumFolds; foldIndex++) {
    // Key layout: dataset name, run, 1-based fold, then the evaluator's key.
    Object [] evaluatorKey = m_SplitEvaluator.getKey();
    Object [] fullKey = new Object [3 + evaluatorKey.length];
    fullKey[0] = Utils.backQuoteChars(m_Instances.relationName());
    fullKey[1] = String.valueOf(run);
    fullKey[2] = String.valueOf(foldIndex + 1);
    System.arraycopy(evaluatorKey, 0, fullKey, 3, evaluatorKey.length);
    if (!m_ResultListener.isResultRequired(this, fullKey)) {
      continue;
    }
    Instances trainSet = shuffled.trainCV(m_NumFolds, foldIndex, rng);
    Instances testSet = shuffled.testCV(m_NumFolds, foldIndex);
    // Result layout: timestamp first, then the evaluator's own results.
    Object [] evaluatorResults = m_SplitEvaluator.getResult(trainSet, testSet);
    Object [] fullResults = new Object [1 + evaluatorResults.length];
    fullResults[0] = getTimestamp();
    System.arraycopy(evaluatorResults, 0, fullResults, 1,
                     evaluatorResults.length);
    if (m_debugOutput) {
      // Zip entry name: run.fold.dataset.evaluator, spaces replaced and
      // common package prefixes stripped.
      String entryName = "" + run + "." + (foldIndex + 1) + "."
        + Utils.backQuoteChars(shuffled.relationName())
        + "."
        + m_SplitEvaluator.toString();
      entryName = entryName.replace(' ', '_');
      String [] strippedPrefixes = {
        "weka.classifiers.",
        "weka.filters.",
        "weka.attributeSelection."
      };
      for (int p = 0; p < strippedPrefixes.length; p++) {
        entryName = Utils.removeSubstring(entryName, strippedPrefixes[p]);
      }
      m_ZipDest.zipit(m_SplitEvaluator.getRawResultOutput(), entryName);
    }
    m_ResultListener.acceptResult(this, fullKey, fullResults);
  }
}
/**
 * Gets the names of the key columns: the dataset, run and fold field
 * names, followed by the key names reported by the SplitEvaluator.
 * This method should really be static.
 *
 * @return an array containing the name of each key column
 */
public String [] getKeyNames() {
  String [] evaluatorKeyNames = m_SplitEvaluator.getKeyNames();
  // The producer's own key fields come first.
  String [] ownKeyNames = {
    DATASET_FIELD_NAME, RUN_FIELD_NAME, FOLD_FIELD_NAME
  };
  String [] allKeyNames =
    new String [ownKeyNames.length + evaluatorKeyNames.length];
  System.arraycopy(ownKeyNames, 0, allKeyNames, 0, ownKeyNames.length);
  System.arraycopy(evaluatorKeyNames, 0, allKeyNames, ownKeyNames.length,
                   evaluatorKeyNames.length);
  return allKeyNames;
}
/**
* Gets the data types of each of the columns produced for a single run.
* This method should really be static.
*
* @return an array containing objects of the type of each column. The
* objects should be Strings, or Doubles.
*/
public Object [] getKeyTypes() {
Object [] keyTypes = m_SplitEvaluator.getKeyTypes();
// Add in the types of our extra fields
Object [] newKeyTypes = new String [keyTypes.length + 3];
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -