⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 averagingresultproducer.java

📁 一个数据挖掘软件ALPHAMINERR的整个过程的JAVA版源代码
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    AveragingResultProducer.java
 *    Copyright (C) 1999 Len Trigg
 *
 */


package weka.experiment;

import java.util.Enumeration;
import java.util.Hashtable;
import java.util.Vector;

import weka.core.AdditionalMeasureProducer;
import weka.core.FastVector;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Utils;

/**
 * AveragingResultProducer takes the results from a ResultProducer
 * and submits the average to the result listener. For non-numeric
 * result fields, the first value is used.
 *
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 * @version $Revision$
 */
public class AveragingResultProducer 
  implements ResultListener, ResultProducer, OptionHandler,
	     AdditionalMeasureProducer {

  /**
   * The dataset of interest. It is handed to the sub-ResultProducer in
   * determineTemplate, which throws if it has not been set.
   */
  protected Instances m_Instances;

  /**
   * The ResultListener to send results to. It is also asked, through
   * isResultRequired, whether an averaged result is needed at all.
   * Defaults to a CSVResultListener.
   */
  protected ResultListener m_ResultListener = new CSVResultListener();

  /**
   * The sub-ResultProducer whose results are averaged.
   * Defaults to a CrossValidationResultProducer.
   */
  protected ResultProducer m_ResultProducer
    = new CrossValidationResultProducer();

  /**
   * The names of any additional measures to look for in SplitEvaluators,
   * or null when none have been set.
   */
  protected String [] m_AdditionalMeasures = null;
  
  /**
   * The number of results expected to average over for each run.
   * doAverageResult throws if the number of matching results differs.
   */
  protected int m_ExpectedResultsPerAverage = 10;

  /**
   * True if standard deviation fields should be produced. It is left at
   * the Java default (false) unless changed.
   */
  protected boolean m_CalculateStdDevs;
    
  /**
   * The name of the field that will contain the number of results
   * averaged over. Defaults to "Num_" followed by
   * CrossValidationResultProducer.FOLD_FIELD_NAME.
   */
  protected String m_CountFieldName = "Num_" + CrossValidationResultProducer
    .FOLD_FIELD_NAME;

  /**
   * The name of the key field to average over. Keys that differ only in
   * this field are grouped into one average. Defaults to
   * CrossValidationResultProducer.FOLD_FIELD_NAME.
   */
  protected String m_KeyFieldName = CrossValidationResultProducer
    .FOLD_FIELD_NAME;

  /**
   * The position of m_KeyFieldName within the result producer's key
   * names, as computed by findKeyIndex(). A value of -1 means not found.
   */
  protected int m_KeyIndex = -1;

  /**
   * Keys collected from the sub-ResultProducer during a single run.
   * Entry i pairs with entry i of m_Results.
   */
  protected FastVector m_Keys = new FastVector();
  
  /**
   * Results collected from the sub-ResultProducer during a single run,
   * parallel to m_Keys.
   */
  protected FastVector m_Results = new FastVector();

  /**
   * Gives a short, human-readable summary of what this result producer
   * does, for display in the explorer/experimenter GUI.
   *
   * @return the description text
   */
  public String globalInfo() {
    StringBuffer text = new StringBuffer();
    text.append("Takes the results from a ResultProducer ");
    text.append("and submits the average to the result listener. Normally used with ");
    text.append("a CrossValidationResultProducer to perform n x m fold cross ");
    text.append("validation.");
    return text.toString();
  }

  /**
   * Looks up the position of m_KeyFieldName among the key names reported
   * by the current result producer, and stores the answer in m_KeyIndex.
   *
   * @return the index of the key field to average over. Returns -1 when
   * there is no result producer, no key has that name, or querying or
   * scanning the key names fails.
   */
  protected int findKeyIndex() {

    m_KeyIndex = -1;
    if (m_ResultProducer == null) {
      return m_KeyIndex;
    }
    try {
      String [] names = m_ResultProducer.getKeyNames();
      int pos = 0;
      while ((pos < names.length) && !names[pos].equals(m_KeyFieldName)) {
	pos++;
      }
      if (pos < names.length) {
	m_KeyIndex = pos;
      }
    } catch (Exception ignored) {
      // Deliberately best-effort: any failure while obtaining or scanning
      // the key names simply leaves m_KeyIndex at -1 ("not found").
    }
    return m_KeyIndex;
  }

  /**
   * Reports the column constraints this object, acting as a result
   * destination, places on the result producers that feed it. Null means
   * there are no constraints. A non-null array would list the only column
   * names a producer's results may be restricted to.
   *
   * @param rp the ResultProducer the constraints would apply to
   * @return always null, since no columns are restricted here
   * @exception Exception if constraints can't be determined
   */
  public String [] determineColumnConstraints(ResultProducer rp) 
    throws Exception {
    // Every result column is acceptable, so there is nothing to restrict.
    final String [] unconstrained = null;
    return unconstrained;
  }

  /**
   * Simulates a run to collect the keys the sub-resultproducer could
   * generate. Does some checking on the keys and determines the 
   * template key.
   *
   * @param run the run number
   * @return a template key: a copy of the first collected key with the
   * field being averaged over set to null
   * @exception Exception if no Instances have been set, if the
   * sub-resultproducer generated no keys for the run, or if the checks
   * on the collected keys fail
   */
  protected Object [] determineTemplate(int run) throws Exception {

    if (m_Instances == null) {
      throw new Exception("No Instances set");
    }
    m_ResultProducer.setInstances(m_Instances);

    // Clear the collected results
    m_Keys.removeAllElements();
    m_Results.removeAllElements();
    
    m_ResultProducer.doRunKeys(run);

    // The template is built from the first key, so at least one key is
    // required. Report an empty run explicitly instead of failing with an
    // obscure indexing error on elementAt(0).
    if (m_Keys.size() == 0) {
      throw new Exception("No keys generated by "
			  + m_ResultProducer.getClass().getName()
			  + " for run " + run);
    }
    checkForMultipleDifferences();

    Object [] template = (Object [])((Object [])m_Keys.elementAt(0)).clone();
    // Null out the averaged-over field. matchesTemplate ignores null
    // template fields, so the template then matches every key that
    // differs only on that field.
    template[m_KeyIndex] = null;
    // Check for duplicate keys
    checkForDuplicateKeys(template);

    return template;
  }

  /**
   * Reports the key for a run to the current ResultListener without
   * computing any results. The sub-producer's keys are gathered through
   * determineTemplate. The averaged-over field is then removed, so the
   * reported key is one element shorter than the sub-producer's keys.
   *
   * @param run the run number to get keys for. Different run numbers
   * correspond to different randomizations of the data.
   * @exception Exception if a problem occurs while getting the keys
   */
  public void doRunKeys(int run) throws Exception {

    Object [] tmpl = determineTemplate(run);
    int keyPos = m_KeyIndex;
    int tailLength = tmpl.length - keyPos - 1;
    String [] reducedKey = new String [tmpl.length - 1];
    // Copy everything before and after the averaged-over field
    System.arraycopy(tmpl, 0, reducedKey, 0, keyPos);
    System.arraycopy(tmpl, keyPos + 1, reducedKey, keyPos, tailLength);
    m_ResultListener.acceptResult(this, reducedKey, null);
  }

  /**
   * Gets the results for a specified run number. Different run
   * numbers correspond to different randomizations of the data. Results
   * produced should be sent to the current ResultListener.
   *
   * First the sub-ResultProducer's keys are gathered (via
   * determineTemplate), so the reduced key can be offered to the listener.
   * The sub-ResultProducer is only run, and its results averaged, if the
   * listener still requires that result.
   *
   * @param run the run number to get results for.
   * @exception Exception if a problem occurs while getting the results,
   * or if the sub-ResultProducer produced no results for this run
   */
  public void doRun(int run) throws Exception {

    // Generate the key and ask whether the result is required
    Object [] template = determineTemplate(run);
    String [] newKey = new String [template.length - 1];
    System.arraycopy(template, 0, newKey, 0, m_KeyIndex);
    System.arraycopy(template, m_KeyIndex + 1,
		     newKey, m_KeyIndex,
		     template.length - m_KeyIndex - 1);

    if (m_ResultListener.isResultRequired(this, newKey)) {
      // Discard whatever determineTemplate collected before the real run
      m_Keys.removeAllElements();
      m_Results.removeAllElements();
      
      m_ResultProducer.doRun(run);

      // Without at least one result there is nothing to average and no
      // key to build the template from. Fail with a clear message rather
      // than an obscure indexing error on elementAt(0) below.
      if (m_Keys.size() == 0) {
	throw new Exception("No results generated by "
			    + m_ResultProducer.getClass().getName()
			    + " for run " + run);
      }
      
      // Check that the keys only differ on the selected key field
      checkForMultipleDifferences();
      
      template = (Object [])((Object [])m_Keys.elementAt(0)).clone();
      template[m_KeyIndex] = null;
      // Check for duplicate keys
      checkForDuplicateKeys(template);
      // Calculate the average and submit it if necessary
      doAverageResult(template);
    }
  }

  
  /**
   * Tests whether a key agrees with a template. A null entry in the
   * template acts as a wildcard and matches anything at that position.
   *
   * @param template the pattern key; its null fields are ignored
   * @param test the key to check
   * @return false if the arrays differ in length or any non-null template
   * field is not equal to the corresponding test field, otherwise true
   */
  protected boolean matchesTemplate(Object [] template, Object [] test) {

    boolean matches = (template.length == test.length);
    for (int pos = 0; matches && (pos < test.length); pos++) {
      Object wanted = template[pos];
      matches = (wanted == null) || wanted.equals(test[pos]);
    }
    return matches;
  }
  
  /**
   * Asks the resultlistener whether an average result is required, and
   * if so, calculates it.
   *
   * @param template the template to match keys against when calculating the
   * average
   * @exception Exception if an error occurs
   */
  protected void doAverageResult(Object [] template) throws Exception {

    // Generate the key and ask whether the result is required
    String [] newKey = new String [template.length - 1];
    System.arraycopy(template, 0, newKey, 0, m_KeyIndex);
    System.arraycopy(template, m_KeyIndex + 1,
		     newKey, m_KeyIndex,
		     template.length - m_KeyIndex - 1);
    if (m_ResultListener.isResultRequired(this, newKey)) {
      Object [] resultTypes = m_ResultProducer.getResultTypes();
      Stats [] stats = new Stats [resultTypes.length];
      for (int i = 0; i < stats.length; i++) {
	stats[i] = new Stats();
      }
      Object [] result = getResultTypes();
      int numMatches = 0;
      for (int i = 0; i < m_Keys.size(); i++) {
	Object [] currentKey = (Object [])m_Keys.elementAt(i);
	// Skip non-matching keys
	if (!matchesTemplate(template, currentKey)) {
	  continue;
	}
	// Add the results to the stats accumulator
	Object [] currentResult = (Object [])m_Results.elementAt(i);
	numMatches++;
	for (int j = 0; j < resultTypes.length; j++) {
	  if (resultTypes[j] instanceof Double) {
	    if (currentResult[j] == null) {

	      // set the stats object for this result to null---
	      // more than likely this is an additional measure field
	      // not supported by the low level split evaluator
	      if (stats[j] != null) {
		stats[j] = null;
	      }
	      
	      /* throw new Exception("Null numeric result field found:\n"
		 + DatabaseUtils.arrayToString(currentKey)
		 + " -- "
		 + DatabaseUtils
		 .arrayToString(currentResult)); */
	    }
	    if (stats[j] != null) {
	      double currentVal = ((Double)currentResult[j]).doubleValue();
	      stats[j].add(currentVal);
	    }
	  }
	}
      }
      if (numMatches != m_ExpectedResultsPerAverage) {
	throw new Exception("Expected " + m_ExpectedResultsPerAverage
			    + " results matching key \""
			    + DatabaseUtils.arrayToString(template)
			    + "\" but got "
			    + numMatches);
      }
      result[0] = new Double(numMatches);
      Object [] currentResult = (Object [])m_Results.elementAt(0);
      int k = 1;
      for (int j = 0; j < resultTypes.length; j++) {
	if (resultTypes[j] instanceof Double) {
	  if (stats[j] != null) {
	    stats[j].calculateDerived();
	    result[k++] = new Double(stats[j].mean);
	  } else {
	    result[k++] = null;
	  }
	  if (getCalculateStdDevs()) {
	    if (stats[j] != null) {
	      result[k++] = new Double(stats[j].stdDev);
	    } else {
	      result[k++] = null;
	    }
	  }
	} else {
	  result[k++] = currentResult[j];
	}
      }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -