⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 averagingresultproducer.java

📁 wekaUT 是由 University of Texas at Austin（德克萨斯大学奥斯汀分校）开发的、基于 Weka 的半监督学习（semi-supervised learning）分类器
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *    AveragingResultProducer.java *    Copyright (C) 1999 Len Trigg * */package weka.experiment;import java.io.*;import java.util.*;import java.sql.*;import java.net.*;import weka.core.OptionHandler;import weka.core.Instances;import weka.core.FastVector;import weka.core.Utils;import weka.core.Option;import weka.core.AdditionalMeasureProducer;/** * AveragingResultProducer takes the results from a ResultProducer * and submits the average to the result listener. For non-numeric * result fields, the first value is used. 
* * @author Len Trigg (trigg@cs.waikato.ac.nz) * @version $Revision: 1.1.1.1 $ */public class AveragingResultProducer   implements ResultListener, ResultProducer, OptionHandler,	     AdditionalMeasureProducer {  /** The dataset of interest */  protected Instances m_Instances;  /** The ResultListener to send results to */  protected ResultListener m_ResultListener = new CSVResultListener();  /** The ResultProducer used to generate results */  protected ResultProducer m_ResultProducer    = new CrossValidationResultProducer();  /** The names of any additional measures to look for in SplitEvaluators */  protected String [] m_AdditionalMeasures = null;    /** The number of results expected to average over for each run */  protected int m_ExpectedResultsPerAverage = 10;  /** True if standard deviation fields should be produced */  protected boolean m_CalculateStdDevs;      /**   * The name of the field that will contain the number of results   * averaged over.   */  protected String m_CountFieldName = "Num_" + CrossValidationResultProducer    .FOLD_FIELD_NAME;  /** The name of the key field to average over */  protected String m_KeyFieldName = CrossValidationResultProducer    .FOLD_FIELD_NAME;  /** The index of the field to average over in the resultproducers key */  protected int m_KeyIndex = -1;  /** Collects the keys from a single run */  protected FastVector m_Keys = new FastVector();    /** Collects the results from a single run */  protected FastVector m_Results = new FastVector();  /**   * Returns a string describing this result producer   * @return a description of the result producer suitable for   * displaying in the explorer/experimenter gui   */  public String globalInfo() {    return "Takes the results from a ResultProducer "      +"and submits the average to the result listener. 
Normally used with "      +"a CrossValidationResultProducer to perform n x m fold cross "      +"validation.";  }  /**   * Scans through the key field names of the result producer to find   * the index of the key field to average over. Sets the value of   * m_KeyIndex to the index, or -1 if no matching key field was found.   *   * @return the index of the key field to average over   */  protected int findKeyIndex() {    m_KeyIndex = -1;    try {      if (m_ResultProducer != null) {	String [] keyNames = m_ResultProducer.getKeyNames();	for (int i = 0; i < keyNames.length; i++) {	  if (keyNames[i].equals(m_KeyFieldName)) {	    m_KeyIndex = i;	    break;	  }	}      }    } catch (Exception ex) {    }    return m_KeyIndex;  }  /**   * Determines if there are any constraints (imposed by the   * destination) on the result columns to be produced by   * resultProducers. Null should be returned if there are NO   * constraints, otherwise a list of column names should be   * returned as an array of Strings.   * @param rp the ResultProducer to which the constraints will apply   * @return an array of column names to which resutltProducer's   * results will be restricted.   * @exception Exception if constraints can't be determined   */  public String [] determineColumnConstraints(ResultProducer rp)     throws Exception {    return null;  }  /**   * Simulates a run to collect the keys the sub-resultproducer could   * generate. Does some checking on the keys and determines the    * template key.   
*   * @param run the run number   * @return a template key (null for the field being averaged)   * @exception Exception if an error occurs   */  protected Object [] determineTemplate(int run) throws Exception {    if (m_Instances == null) {      throw new Exception("No Instances set");    }    m_ResultProducer.setInstances(m_Instances);    // Clear the collected results    m_Keys.removeAllElements();    m_Results.removeAllElements();        m_ResultProducer.doRunKeys(run);    checkForMultipleDifferences();    Object [] template = (Object [])((Object [])m_Keys.elementAt(0)).clone();    template[m_KeyIndex] = null;    // Check for duplicate keys    checkForDuplicateKeys(template);    return template;  }  /**   * Gets the keys for a specified run number. Different run   * numbers correspond to different randomizations of the data. Keys   * produced should be sent to the current ResultListener   *   * @param run the run number to get keys for.   * @exception Exception if a problem occurs while getting the keys   */  public void doRunKeys(int run) throws Exception {    // Generate the template    Object [] template = determineTemplate(run);    String [] newKey = new String [template.length - 1];    System.arraycopy(template, 0, newKey, 0, m_KeyIndex);    System.arraycopy(template, m_KeyIndex + 1,		     newKey, m_KeyIndex,		     template.length - m_KeyIndex - 1);    m_ResultListener.acceptResult(this, newKey, null);        }  /**   * Gets the results for a specified run number. Different run   * numbers correspond to different randomizations of the data. Results   * produced should be sent to the current ResultListener   *   * @param run the run number to get results for.   
* @exception Exception if a problem occurs while getting the results   */  public void doRun(int run) throws Exception {    // Generate the key and ask whether the result is required    Object [] template = determineTemplate(run);    String [] newKey = new String [template.length - 1];    System.arraycopy(template, 0, newKey, 0, m_KeyIndex);    System.arraycopy(template, m_KeyIndex + 1,		     newKey, m_KeyIndex,		     template.length - m_KeyIndex - 1);    if (m_ResultListener.isResultRequired(this, newKey)) {      // Clear the collected keys      m_Keys.removeAllElements();      m_Results.removeAllElements();            m_ResultProducer.doRun(run);            // Average the results collected      //System.err.println("Number of results collected: " + m_Keys.size());            // Check that the keys only differ on the selected key field      checkForMultipleDifferences();            template = (Object [])((Object [])m_Keys.elementAt(0)).clone();      template[m_KeyIndex] = null;      // Check for duplicate keys      checkForDuplicateKeys(template);      // Calculate the average and submit it if necessary      doAverageResult(template);    }  }    /**   * Compares a key to a template to see whether they match. Null   * fields in the template are ignored in the matching.   *   * @param template the template to match against   * @param test the key to test   * @return true if the test key matches the template on all non-null template   * fields   */  protected boolean matchesTemplate(Object [] template, Object [] test) {        if (template.length != test.length) {      return false;    }    for (int i = 0; i < test.length; i++) {      if ((template[i] != null) && (!template[i].equals(test[i]))) {	return false;      }    }    return true;  }    /**   * Asks the resultlistener whether an average result is required, and   * if so, calculates it.   
*   * @param template the template to match keys against when calculating the   * average   * @exception Exception if an error occurs   */  protected void doAverageResult(Object [] template) throws Exception {    // Generate the key and ask whether the result is required    String [] newKey = new String [template.length - 1];    System.arraycopy(template, 0, newKey, 0, m_KeyIndex);    System.arraycopy(template, m_KeyIndex + 1,		     newKey, m_KeyIndex,		     template.length - m_KeyIndex - 1);    if (m_ResultListener.isResultRequired(this, newKey)) {      Object [] resultTypes = m_ResultProducer.getResultTypes();      Stats [] stats = new Stats [resultTypes.length];      for (int i = 0; i < stats.length; i++) {	stats[i] = new Stats();      }      Object [] result = getResultTypes();      int numMatches = 0;      for (int i = 0; i < m_Keys.size(); i++) {	Object [] currentKey = (Object [])m_Keys.elementAt(i);	// Skip non-matching keys	if (!matchesTemplate(template, currentKey)) {	  continue;	}	// Add the results to the stats accumulator	Object [] currentResult = (Object [])m_Results.elementAt(i);	numMatches++;	for (int j = 0; j < resultTypes.length; j++) {	  if (resultTypes[j] instanceof Double) {	    if (currentResult[j] == null) {	      // set the stats object for this result to null---	      // more than likely this is an additional measure field	      // not supported by the low level split evaluator	      if (stats[j] != null) {		stats[j] = null;	      }	      	      /* throw new Exception("Null numeric result field found:\n"		 + DatabaseUtils.arrayToString(currentKey)		 + " -- "		 + DatabaseUtils		 .arrayToString(currentResult)); */	    }	    if (stats[j] != null) {	      double currentVal = ((Double)currentResult[j]).doubleValue();	      stats[j].add(currentVal);	    }	  }	}      }      if (numMatches != m_ExpectedResultsPerAverage) {	throw new Exception("Expected " + m_ExpectedResultsPerAverage			    + " results matching key \""			    + 
DatabaseUtils.arrayToString(template)			    + "\" but got "			    + numMatches);      }      result[0] = new Double(numMatches);      Object [] currentResult = (Object [])m_Results.elementAt(0);      int k = 1;      for (int j = 0; j < resultTypes.length; j++) {	if (resultTypes[j] instanceof Double) {	  if (stats[j] != null) {	    stats[j].calculateDerived();	    result[k++] = new Double(stats[j].mean);	  } else {	    result[k++] = null;	  }	  if (getCalculateStdDevs()) {	    if (stats[j] != null) {	      result[k++] = new Double(stats[j].stdDev);	    } else {	      result[k++] = null;	    }	  }	} else {	  result[k++] = currentResult[j];	}      }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -