
📄 CVParameterSelection.java

📁 This code implements a classifier and reuses part of the Weka source code. The project can be imported into Eclipse and run.
💻 JAVA
📖 Page 1 of 2
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    CVParameterSelection.java
 *    Copyright (C) 1999 Len Trigg
 *
 */

package weka.classifiers.meta;

import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.core.Capabilities;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Summarizable;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;

import java.io.Serializable;
import java.io.StreamTokenizer;
import java.io.StringReader;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Class for performing parameter selection by cross-validation for any classifier.<br/>
 * <br/>
 * For more information, see:<br/>
 * <br/>
 * R. Kohavi (1995). Wrappers for Performance Enhancement and Oblivious Decision Graphs.
 * Department of Computer Science, Stanford University.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;phdthesis{Kohavi1995,
 *    address = {Department of Computer Science, Stanford University},
 *    author = {R. Kohavi},
 *    school = {Stanford University},
 *    title = {Wrappers for Performance Enhancement and Oblivious Decision Graphs},
 *    year = {1995}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -X &lt;number of folds&gt;
 *  Number of folds used for cross validation (default 10).</pre>
 *
 * <pre> -P &lt;classifier parameter&gt;
 *  Classifier parameter options.
 *  eg: "N 1 5 10" Sets an optimisation parameter for the
 *  classifier with name -N, with lower bound 1, upper bound
 *  5, and 10 optimisation steps. The upper bound may be the
 *  character 'A' or 'I' to substitute the number of
 *  attributes or instances in the training data,
 *  respectively. This parameter may be supplied more than
 *  once to optimise over several classifier options
 *  simultaneously.</pre>
 *
 * <pre> -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 * <pre> -W
 *  Full name of base classifier.
 *  (default: weka.classifiers.rules.ZeroR)</pre>
 *
 * <pre>
 * Options specific to classifier weka.classifiers.rules.ZeroR:
 * </pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 <!-- options-end -->
 *
 * Options after -- are passed to the designated sub-classifier. <p>
 *
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 * @version $Revision: 1.33 $
 */
public class CVParameterSelection
  extends RandomizableSingleClassifierEnhancer
  implements Drawable, Summarizable, TechnicalInformationHandler {

  /** for serialization */
  static final long serialVersionUID = -6529603380876641265L;

  /**
   * A data structure to hold values associated with a single
   * cross-validation search parameter
   */
  protected class CVParameter
    implements Serializable {

    /** for serialization */
    static final long serialVersionUID = -4668812017709421953L;

    /** Char used to identify the option of interest */
    private char m_ParamChar;

    /** Lower bound for the CV search */
    private double m_Lower;

    /** Upper bound for the CV search */
    private double m_Upper;

    /** Number of steps during the search */
    private double m_Steps;

    /** The parameter value with the best performance */
    private double m_ParamValue;

    /** True if the parameter should be added at the end of the argument list */
    private boolean m_AddAtEnd;

    /** True if the parameter should be rounded to an integer */
    private boolean m_RoundParam;

    /**
     * Constructs a CVParameter.
     *
     * @param param the parameter definition
     * @throws Exception if construction of CVParameter fails
     */
    public CVParameter(String param) throws Exception {

      // Tokenize the string into its parts
      StreamTokenizer st = new StreamTokenizer(new StringReader(param));
      if (st.nextToken() != StreamTokenizer.TT_WORD) {
        throw new Exception("CVParameter " + param
                            + ": Character parameter identifier expected");
      }
      m_ParamChar = st.sval.charAt(0);
      if (st.nextToken() != StreamTokenizer.TT_NUMBER) {
        throw new Exception("CVParameter " + param
                            + ": Numeric lower bound expected");
      }
      m_Lower = st.nval;
      if (st.nextToken() == StreamTokenizer.TT_NUMBER) {
        m_Upper = st.nval;
        if (m_Upper < m_Lower) {
          throw new Exception("CVParameter " + param
                              + ": Upper bound is less than lower bound");
        }
      } else if (st.ttype == StreamTokenizer.TT_WORD) {
        if (st.sval.toUpperCase().charAt(0) == 'A') {
          // Sentinel encoding: 'A' (number of attributes) is stored as lower - 1
          // and resolved in findParamsByCrossValidation().
          m_Upper = m_Lower - 1;
        } else if (st.sval.toUpperCase().charAt(0) == 'I') {
          // Sentinel encoding: 'I' (training fold size) is stored as lower - 2.
          m_Upper = m_Lower - 2;
        } else {
          throw new Exception("CVParameter " + param
              + ": Upper bound must be numeric, or 'A' or 'I'");
        }
      } else {
        throw new Exception("CVParameter " + param
            + ": Upper bound must be numeric, or 'A' or 'I'");
      }
      if (st.nextToken() != StreamTokenizer.TT_NUMBER) {
        throw new Exception("CVParameter " + param
                            + ": Numeric number of steps expected");
      }
      m_Steps = st.nval;
      if (st.nextToken() == StreamTokenizer.TT_WORD) {
        if (st.sval.toUpperCase().charAt(0) == 'R') {
          m_RoundParam = true;
        }
      }
    }

    /**
     * Returns a CVParameter as a string.
     *
     * @return the CVParameter as string
     */
    public String toString() {

      String result = m_ParamChar + " " + m_Lower + " ";
      switch ((int) (m_Lower - m_Upper + 0.5)) {
      case 1:
        result += "A";
        break;
      case 2:
        result += "I";
        break;
      default:
        result += m_Upper;
        break;
      }
      result += " " + m_Steps;
      if (m_RoundParam) {
        result += " R";
      }
      return result;
    }
  }

  /**
   * The base classifier options (not including those being set
   * by cross-validation)
   */
  protected String [] m_ClassifierOptions;

  /** The set of all classifier options as determined by cross-validation */
  protected String [] m_BestClassifierOptions;

  /** The set of all options at initialization time. So that getOptions
      can return this. */
  protected String [] m_InitOptions;

  /** The cross-validated performance of the best options */
  protected double m_BestPerformance;

  /** The set of parameters to cross-validate over */
  protected FastVector m_CVParams = new FastVector();

  /** The number of attributes in the data */
  protected int m_NumAttributes;

  /** The number of instances in a training fold */
  protected int m_TrainFoldSize;

  /** The number of folds used in cross-validation */
  protected int m_NumFolds = 10;

  /**
   * Create the options array to pass to the classifier. The parameter
   * values and positions are taken from m_ClassifierOptions and
   * m_CVParams.
   *
   * @return the options array
   */
  protected String [] createOptions() {

    String [] options = new String [m_ClassifierOptions.length
                                    + 2 * m_CVParams.size()];
    int start = 0, end = options.length;

    // Add the cross-validation parameters and their values
    for (int i = 0; i < m_CVParams.size(); i++) {
      CVParameter cvParam = (CVParameter) m_CVParams.elementAt(i);
      double paramValue = cvParam.m_ParamValue;
      if (cvParam.m_RoundParam) {
        paramValue = (double) ((int) (paramValue + 0.5));
      }
      if (cvParam.m_AddAtEnd) {
        options[--end] = "" + Utils.doubleToString(paramValue, 4);
        options[--end] = "-" + cvParam.m_ParamChar;
      } else {
        options[start++] = "-" + cvParam.m_ParamChar;
        options[start++] = "" + Utils.doubleToString(paramValue, 4);
      }
    }
    // Add the static parameters
    System.arraycopy(m_ClassifierOptions, 0,
                     options, start,
                     m_ClassifierOptions.length);

    return options;
  }

  /**
   * Finds the best parameter combination. (recursive for each parameter
   * being optimised).
   *
   * @param depth the index of the parameter to be optimised at this level
   * @param trainData the data the search is based on
   * @param random a random number generator
   * @throws Exception if an error occurs
   */
  protected void findParamsByCrossValidation(int depth, Instances trainData,
                                             Random random)
    throws Exception {

    if (depth < m_CVParams.size()) {
      CVParameter cvParam = (CVParameter) m_CVParams.elementAt(depth);

      double upper;
      switch ((int) (cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
      case 1:
        upper = m_NumAttributes;
        break;
      case 2:
        upper = m_TrainFoldSize;
        break;
      default:
        upper = cvParam.m_Upper;
        break;
      }
      double increment = (upper - cvParam.m_Lower) / (cvParam.m_Steps - 1);
      for (cvParam.m_ParamValue = cvParam.m_Lower;
           cvParam.m_ParamValue <= upper;
           cvParam.m_ParamValue += increment) {
        findParamsByCrossValidation(depth + 1, trainData, random);
      }
    } else {

      Evaluation evaluation = new Evaluation(trainData);

      // Set the classifier options
      String [] options = createOptions();
      if (m_Debug) {
        System.err.print("Setting options for "
                         + m_Classifier.getClass().getName() + ":");
        for (int i = 0; i < options.length; i++) {
          System.err.print(" " + options[i]);
        }
        System.err.println("");
      }
      ((OptionHandler) m_Classifier).setOptions(options);
      for (int j = 0; j < m_NumFolds; j++) {

        // We want to randomize the data the same way for every
        // learning scheme.
        Instances train = trainData.trainCV(m_NumFolds, j, new Random(1));
        Instances test = trainData.testCV(m_NumFolds, j);
        m_Classifier.buildClassifier(train);
        evaluation.setPriors(train);
        evaluation.evaluateModel(m_Classifier, test);
      }
      double error = evaluation.errorRate();
      if (m_Debug) {
        System.err.println("Cross-validated error rate: "
                           + Utils.doubleToString(error, 6, 4));
      }
      if ((m_BestPerformance == -99) || (error < m_BestPerformance)) {

        m_BestPerformance = error;
        m_BestClassifierOptions = createOptions();
      }
    }
  }

  /**
   * Returns a string describing this classifier
   * @return a description of the classifier suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Class for performing parameter selection by cross-validation "
      + "for any classifier.\n\n"
      + "For more information, see:\n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;

    result = new TechnicalInformation(Type.PHDTHESIS);
    result.setValue(Field.AUTHOR, "R. Kohavi");
    result.setValue(Field.YEAR, "1995");
    result.setValue(Field.TITLE, "Wrappers for Performance Enhancement and Oblivious Decision Graphs");
    result.setValue(Field.SCHOOL, "Stanford University");
    result.setValue(Field.ADDRESS, "Department of Computer Science, Stanford University");

    return result;
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(2);

    newVector.addElement(new Option(
              "\tNumber of folds used for cross validation (default 10).",
              // ... listing continues on page 2 of 2
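The parameter string parsed by the CVParameter constructor above ("<option char> <lower> <upper> <steps>", optionally followed by R to round values to integers) is normally supplied through the -P option or through the class's API. The following is a minimal usage sketch, not a definitive recipe: it assumes the public API of the full Weka class (addCVParameter, setNumFolds and getBestClassifierOptions are not shown on this page of the listing), and the dataset file name and the J48 base classifier are placeholders.

import java.io.BufferedReader;
import java.io.FileReader;

import weka.classifiers.meta.CVParameterSelection;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.Utils;

public class CVParameterSelectionDemo {

  public static void main(String[] args) throws Exception {
    // Load a dataset and mark the last attribute as the class
    // ("iris.arff" is a placeholder path).
    Instances data = new Instances(new BufferedReader(new FileReader("iris.arff")));
    data.setClassIndex(data.numAttributes() - 1);

    // Wrap a base classifier and declare the parameter to search over:
    // "C 0.1 0.5 5" means option -C, lower bound 0.1, upper bound 0.5,
    // 5 steps -- the same format the CVParameter constructor parses.
    CVParameterSelection ps = new CVParameterSelection();
    ps.setClassifier(new J48());
    ps.addCVParameter("C 0.1 0.5 5");
    ps.setNumFolds(10);  // corresponds to -X 10, the default

    // Runs the recursive search in findParamsByCrossValidation() and
    // keeps the option set with the lowest cross-validated error rate.
    ps.buildClassifier(data);

    System.out.println("Best options: "
        + Utils.joinOptions(ps.getBestClassifierOptions()));
  }
}

From the command line, the options documented in the javadoc map to an invocation along the lines of: java weka.classifiers.meta.CVParameterSelection -t iris.arff -X 10 -P "C 0.1 0.5 5" -W weka.classifiers.trees.J48 (the -t flag comes from Weka's standard evaluation harness, not from this file).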
