📄 averagingresultproducer.java
字号:
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* AveragingResultProducer.java
* Copyright (C) 1999 Len Trigg
*
*/
package weka.experiment;
import java.util.Enumeration;
import java.util.Hashtable;
import java.util.Vector;
import weka.core.AdditionalMeasureProducer;
import weka.core.FastVector;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Utils;
/**
* AveragingResultProducer takes the results from a ResultProducer
* and submits the average to the result listener. For non-numeric
* result fields, the first value is used.
*
* @author Len Trigg (trigg@cs.waikato.ac.nz)
* @version $Revision$
*/
public class AveragingResultProducer
implements ResultListener, ResultProducer, OptionHandler,
AdditionalMeasureProducer {
/**
 * The dataset of interest. It is handed to the sub-ResultProducer via
 * setInstances() before each run; determineTemplate() throws if it is
 * still null.
 */
protected Instances m_Instances;
/** The ResultListener that receives averaged keys/results (defaults to a CSVResultListener) */
protected ResultListener m_ResultListener = new CSVResultListener();
/** The sub-ResultProducer whose results are averaged (defaults to a CrossValidationResultProducer) */
protected ResultProducer m_ResultProducer
= new CrossValidationResultProducer();
/** The names of any additional measures to look for in SplitEvaluators; null when none are set */
protected String [] m_AdditionalMeasures = null;
/**
 * The number of sub-results expected to be averaged into each averaged
 * result (default 10). A mismatch in the number of matching results is
 * reported as an Exception.
 */
protected int m_ExpectedResultsPerAverage = 10;
/** True if standard deviation fields should be produced (Java default: false) */
protected boolean m_CalculateStdDevs;
/**
 * The name of the result field that will contain the number of results
 * averaged over. Defaults to "Num_" followed by
 * CrossValidationResultProducer.FOLD_FIELD_NAME.
 */
protected String m_CountFieldName = "Num_" + CrossValidationResultProducer
.FOLD_FIELD_NAME;
/**
 * The name of the sub-producer key field to average over. Defaults to
 * CrossValidationResultProducer.FOLD_FIELD_NAME.
 */
protected String m_KeyFieldName = CrossValidationResultProducer
.FOLD_FIELD_NAME;
/**
 * The position of m_KeyFieldName within the sub-producer's key names, as
 * computed by findKeyIndex(); -1 when no matching key field was found.
 */
protected int m_KeyIndex = -1;
/**
 * Keys (each an Object[]) collected from the sub-producer during a single
 * run. Parallel to m_Results: entry i of m_Results belongs to entry i here.
 */
protected FastVector m_Keys = new FastVector();
/** Results (each an Object[]) collected during a single run; parallel to m_Keys */
protected FastVector m_Results = new FastVector();
/**
 * Provides a short human-readable summary of this result producer.
 *
 * @return the description shown for this producer in the
 * explorer/experimenter GUI
 */
public String globalInfo() {
  String summary = "Takes the results from a ResultProducer and submits the "
    + "average to the result listener. Normally used with a "
    + "CrossValidationResultProducer to perform n x m fold cross validation.";
  return summary;
}
/**
 * Scans through the key field names of the result producer to find
 * the index of the key field to average over. Sets the value of
 * m_KeyIndex to the index, or -1 if no matching key field was found.
 * The result is also -1 when no ResultProducer or key field name is set,
 * or when the producer's key names cannot be obtained.
 *
 * @return the index of the key field to average over, or -1
 */
protected int findKeyIndex() {
  m_KeyIndex = -1;
  if ((m_ResultProducer == null) || (m_KeyFieldName == null)) {
    return m_KeyIndex;
  }
  String [] keyNames;
  try {
    keyNames = m_ResultProducer.getKeyNames();
  } catch (Exception ignored) {
    // Deliberate best effort: this lookup runs during configuration, and a
    // producer that cannot report its keys yet is treated as "key field not
    // found" (-1). Callers that need the index report that condition.
    return m_KeyIndex;
  }
  if (keyNames != null) {
    for (int i = 0; i < keyNames.length; i++) {
      // Compare from the non-null side so a null key name is skipped
      // instead of aborting the whole search with a NullPointerException.
      if (m_KeyFieldName.equals(keyNames[i])) {
        m_KeyIndex = i;
        break;
      }
    }
  }
  return m_KeyIndex;
}
/**
 * Determines if there are any constraints (imposed by the
 * destination) on the result columns to be produced by
 * resultProducers. Null should be returned if there are NO
 * constraints, otherwise a list of column names should be
 * returned as an array of Strings.
 *
 * This producer imposes no constraints, so it always returns null.
 *
 * @param rp the ResultProducer to which the constraints will apply
 * @return an array of column names to which the ResultProducer's
 * results will be restricted; always null here
 * @exception Exception if constraints can't be determined (never thrown
 * by this implementation)
 */
public String [] determineColumnConstraints(ResultProducer rp)
  throws Exception {
  return null;
}
/**
 * Simulates a run to collect the keys the sub-resultproducer could
 * generate. Does some checking on the keys and determines the
 * template key.
 *
 * The collected keys and results are cleared, and the sub-producer's
 * doRunKeys(run) is invoked. A copy of the first collected key is then
 * taken, with the field being averaged replaced by null.
 *
 * @param run the run number
 * @return a template key (null for the field being averaged)
 * @exception Exception if no Instances or ResultProducer are set, if the
 * sub-producer generated no keys, if the key field to average over is not
 * present in the generated keys, or if the key consistency checks fail
 */
protected Object [] determineTemplate(int run) throws Exception {
  if (m_Instances == null) {
    throw new Exception("No Instances set");
  }
  if (m_ResultProducer == null) {
    throw new Exception("No ResultProducer set");
  }
  m_ResultProducer.setInstances(m_Instances);
  // Clear the collected results so only keys from this run are examined
  m_Keys.removeAllElements();
  m_Results.removeAllElements();
  m_ResultProducer.doRunKeys(run);
  if (m_Keys.size() == 0) {
    throw new Exception("ResultProducer "
                        + m_ResultProducer.getClass().getName()
                        + " generated no keys for run " + run);
  }
  Object [] firstKey = (Object [])m_Keys.elementAt(0);
  // Without a valid index, the template indexing below would fail with an
  // uninformative ArrayIndexOutOfBoundsException.
  if ((m_KeyIndex < 0) || (m_KeyIndex >= firstKey.length)) {
    throw new Exception("No key field called \"" + m_KeyFieldName
                        + "\" produced by "
                        + m_ResultProducer.getClass().getName());
  }
  checkForMultipleDifferences();
  Object [] template = (Object [])firstKey.clone();
  template[m_KeyIndex] = null;
  // Check for duplicate keys
  checkForDuplicateKeys(template);
  return template;
}
/**
 * Gets the keys for a specified run number. Different run
 * numbers correspond to different randomizations of the data. Keys
 * produced should be sent to the current ResultListener.
 *
 * A key-only dry run of the sub-producer yields a template. The template
 * minus the averaged field is passed to the listener with a null result.
 *
 * @param run the run number to get keys for.
 * @exception Exception if a problem occurs while getting the keys
 */
public void doRunKeys(int run) throws Exception {
  Object [] template = determineTemplate(run);
  int keyPos = m_KeyIndex;
  int tailLength = template.length - keyPos - 1;
  // Drop the averaged field: copy everything before it, then everything after
  String [] reducedKey = new String [template.length - 1];
  System.arraycopy(template, 0, reducedKey, 0, keyPos);
  System.arraycopy(template, keyPos + 1, reducedKey, keyPos, tailLength);
  m_ResultListener.acceptResult(this, reducedKey, null);
}
/**
 * Gets the results for a specified run number. Different run
 * numbers correspond to different randomizations of the data. Results
 * produced should be sent to the current ResultListener.
 *
 * A key-only dry run first builds the averaged key. If the listener does
 * not require that key, nothing more is done. Otherwise the sub-producer
 * is run for real, the collected keys are re-checked, and the average is
 * computed and submitted.
 *
 * @param run the run number to get results for.
 * @exception Exception if a problem occurs while getting the results
 */
public void doRun(int run) throws Exception {
  Object [] dryRunTemplate = determineTemplate(run);
  int keyPos = m_KeyIndex;
  String [] averagedKey = new String [dryRunTemplate.length - 1];
  System.arraycopy(dryRunTemplate, 0, averagedKey, 0, keyPos);
  System.arraycopy(dryRunTemplate, keyPos + 1, averagedKey, keyPos,
                   dryRunTemplate.length - keyPos - 1);
  if (!m_ResultListener.isResultRequired(this, averagedKey)) {
    return;
  }
  // Start from empty collections so only this run's output is averaged
  m_Keys.removeAllElements();
  m_Results.removeAllElements();
  m_ResultProducer.doRun(run);
  // The real run's keys may only differ on the averaged field
  checkForMultipleDifferences();
  Object [] runTemplate = (Object [])((Object [])m_Keys.elementAt(0)).clone();
  runTemplate[m_KeyIndex] = null;
  checkForDuplicateKeys(runTemplate);
  doAverageResult(runTemplate);
}
/**
 * Tests whether a key agrees with a template. A null entry in the
 * template acts as a wildcard and matches any value in the key.
 *
 * @param template the pattern to compare against; null entries are wildcards
 * @param test the key being checked
 * @return false if the arrays differ in length or any non-null template
 * entry is not equal to the corresponding key entry; true otherwise
 */
protected boolean matchesTemplate(Object [] template, Object [] test) {
  boolean matched = (template.length == test.length);
  for (int pos = 0; matched && (pos < test.length); pos++) {
    Object expected = template[pos];
    matched = (expected == null) || expected.equals(test[pos]);
  }
  return matched;
}
/**
* Asks the resultlistener whether an average result is required, and
* if so, calculates it.
*
* @param template the template to match keys against when calculating the
* average
* @exception Exception if an error occurs
*/
protected void doAverageResult(Object [] template) throws Exception {
// Generate the key and ask whether the result is required
String [] newKey = new String [template.length - 1];
System.arraycopy(template, 0, newKey, 0, m_KeyIndex);
System.arraycopy(template, m_KeyIndex + 1,
newKey, m_KeyIndex,
template.length - m_KeyIndex - 1);
if (m_ResultListener.isResultRequired(this, newKey)) {
Object [] resultTypes = m_ResultProducer.getResultTypes();
Stats [] stats = new Stats [resultTypes.length];
for (int i = 0; i < stats.length; i++) {
stats[i] = new Stats();
}
Object [] result = getResultTypes();
int numMatches = 0;
for (int i = 0; i < m_Keys.size(); i++) {
Object [] currentKey = (Object [])m_Keys.elementAt(i);
// Skip non-matching keys
if (!matchesTemplate(template, currentKey)) {
continue;
}
// Add the results to the stats accumulator
Object [] currentResult = (Object [])m_Results.elementAt(i);
numMatches++;
for (int j = 0; j < resultTypes.length; j++) {
if (resultTypes[j] instanceof Double) {
if (currentResult[j] == null) {
// set the stats object for this result to null---
// more than likely this is an additional measure field
// not supported by the low level split evaluator
if (stats[j] != null) {
stats[j] = null;
}
/* throw new Exception("Null numeric result field found:\n"
+ DatabaseUtils.arrayToString(currentKey)
+ " -- "
+ DatabaseUtils
.arrayToString(currentResult)); */
}
if (stats[j] != null) {
double currentVal = ((Double)currentResult[j]).doubleValue();
stats[j].add(currentVal);
}
}
}
}
if (numMatches != m_ExpectedResultsPerAverage) {
throw new Exception("Expected " + m_ExpectedResultsPerAverage
+ " results matching key \""
+ DatabaseUtils.arrayToString(template)
+ "\" but got "
+ numMatches);
}
result[0] = new Double(numMatches);
Object [] currentResult = (Object [])m_Results.elementAt(0);
int k = 1;
for (int j = 0; j < resultTypes.length; j++) {
if (resultTypes[j] instanceof Double) {
if (stats[j] != null) {
stats[j].calculateDerived();
result[k++] = new Double(stats[j].mean);
} else {
result[k++] = null;
}
if (getCalculateStdDevs()) {
if (stats[j] != null) {
result[k++] = new Double(stats[j].stdDev);
} else {
result[k++] = null;
}
}
} else {
result[k++] = currentResult[j];
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -