// CVParameterSelection.java (Weka) — code-viewer page banner removed from this scraped copy.
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * CVParameterSelection.java * Copyright (C) 1999 Len Trigg * */package weka.classifiers.meta;import weka.classifiers.Evaluation;import weka.classifiers.RandomizableSingleClassifierEnhancer;import weka.core.Capabilities;import weka.core.Drawable;import weka.core.FastVector;import weka.core.Instance;import weka.core.Instances;import weka.core.Option;import weka.core.OptionHandler;import weka.core.Summarizable;import weka.core.TechnicalInformation;import weka.core.TechnicalInformationHandler;import weka.core.Utils;import weka.core.TechnicalInformation.Field;import weka.core.TechnicalInformation.Type;import java.io.Serializable;import java.io.StreamTokenizer;import java.io.StringReader;import java.util.Enumeration;import java.util.Random;import java.util.Vector;/** <!-- globalinfo-start --> * Class for performing parameter selection by cross-validation for any classifier.<br/> * <br/> * For more information, see:<br/> * <br/> * R. Kohavi (1995). Wrappers for Performance Enhancement and Oblivious Decision Graphs. Department of Computer Science, Stanford University. * <p/> <!-- globalinfo-end --> * <!-- technical-bibtex-start --> * BibTeX: * <pre> * @phdthesis{Kohavi1995, * address = {Department of Computer Science, Stanford University}, * author = {R. 
Kohavi}, * school = {Stanford University}, * title = {Wrappers for Performance Enhancement and Oblivious Decision Graphs}, * year = {1995} * } * </pre> * <p/> <!-- technical-bibtex-end --> * <!-- options-start --> * Valid options are: <p/> * * <pre> -X <number of folds> * Number of folds used for cross validation (default 10).</pre> * * <pre> -P <classifier parameter> * Classifier parameter options. * eg: "N 1 5 10" Sets an optimisation parameter for the * classifier with name -N, with lower bound 1, upper bound * 5, and 10 optimisation steps. The upper bound may be the * character 'A' or 'I' to substitute the number of * attributes or instances in the training data, * respectively. This parameter may be supplied more than * once to optimise over several classifier options * simultaneously.</pre> * * <pre> -S <num> * Random number seed. * (default 1)</pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * * <pre> -W * Full name of base classifier. * (default: weka.classifiers.rules.ZeroR)</pre> * * <pre> * Options specific to classifier weka.classifiers.rules.ZeroR: * </pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * <!-- options-end --> * * Options after -- are passed to the designated sub-classifier. 
<p> * * @author Len Trigg (trigg@cs.waikato.ac.nz) * @version $Revision: 1.33 $ */public class CVParameterSelection extends RandomizableSingleClassifierEnhancer implements Drawable, Summarizable, TechnicalInformationHandler { /** for serialization */ static final long serialVersionUID = -6529603380876641265L; /** * A data structure to hold values associated with a single * cross-validation search parameter */ protected class CVParameter implements Serializable { /** for serialization */ static final long serialVersionUID = -4668812017709421953L; /** Char used to identify the option of interest */ private char m_ParamChar; /** Lower bound for the CV search */ private double m_Lower; /** Upper bound for the CV search */ private double m_Upper; /** Number of steps during the search */ private double m_Steps; /** The parameter value with the best performance */ private double m_ParamValue; /** True if the parameter should be added at the end of the argument list */ private boolean m_AddAtEnd; /** True if the parameter should be rounded to an integer */ private boolean m_RoundParam; /** * Constructs a CVParameter. 
* * @param param the parameter definition * @throws Exception if construction of CVParameter fails */ public CVParameter(String param) throws Exception { // Tokenize the string into it's parts StreamTokenizer st = new StreamTokenizer(new StringReader(param)); if (st.nextToken() != StreamTokenizer.TT_WORD) { throw new Exception("CVParameter " + param + ": Character parameter identifier expected"); } m_ParamChar = st.sval.charAt(0); if (st.nextToken() != StreamTokenizer.TT_NUMBER) { throw new Exception("CVParameter " + param + ": Numeric lower bound expected"); } m_Lower = st.nval; if (st.nextToken() == StreamTokenizer.TT_NUMBER) { m_Upper = st.nval; if (m_Upper < m_Lower) { throw new Exception("CVParameter " + param + ": Upper bound is less than lower bound"); } } else if (st.ttype == StreamTokenizer.TT_WORD) { if (st.sval.toUpperCase().charAt(0) == 'A') { m_Upper = m_Lower - 1; } else if (st.sval.toUpperCase().charAt(0) == 'I') { m_Upper = m_Lower - 2; } else { throw new Exception("CVParameter " + param + ": Upper bound must be numeric, or 'A' or 'N'"); } } else { throw new Exception("CVParameter " + param + ": Upper bound must be numeric, or 'A' or 'N'"); } if (st.nextToken() != StreamTokenizer.TT_NUMBER) { throw new Exception("CVParameter " + param + ": Numeric number of steps expected"); } m_Steps = st.nval; if (st.nextToken() == StreamTokenizer.TT_WORD) { if (st.sval.toUpperCase().charAt(0) == 'R') { m_RoundParam = true; } } } /** * Returns a CVParameter as a string. 
* * @return the CVParameter as string */ public String toString() { String result = m_ParamChar + " " + m_Lower + " "; switch ((int)(m_Lower - m_Upper + 0.5)) { case 1: result += "A"; break; case 2: result += "I"; break; default: result += m_Upper; break; } result += " " + m_Steps; if (m_RoundParam) { result += " R"; } return result; } } /** * The base classifier options (not including those being set * by cross-validation) */ protected String [] m_ClassifierOptions; /** The set of all classifier options as determined by cross-validation */ protected String [] m_BestClassifierOptions; /** The set of all options at initialization time. So that getOptions can return this. */ protected String [] m_InitOptions; /** The cross-validated performance of the best options */ protected double m_BestPerformance; /** The set of parameters to cross-validate over */ protected FastVector m_CVParams = new FastVector(); /** The number of attributes in the data */ protected int m_NumAttributes; /** The number of instances in a training fold */ protected int m_TrainFoldSize; /** The number of folds used in cross-validation */ protected int m_NumFolds = 10; /** * Create the options array to pass to the classifier. The parameter * values and positions are taken from m_ClassifierOptions and * m_CVParams. 
   *
   * @return the options array
   */
  protected String [] createOptions() {
    
    // Two write cursors: normal options grow from the front (start),
    // -AddAtEnd parameters are written backwards from the end of the array.
    String [] options = new String [m_ClassifierOptions.length 
				   + 2 * m_CVParams.size()];
    int start = 0, end = options.length;

    // Add the cross-validation parameters and their values
    for (int i = 0; i < m_CVParams.size(); i++) {
      CVParameter cvParam = (CVParameter)m_CVParams.elementAt(i);
      double paramValue = cvParam.m_ParamValue;
      if (cvParam.m_RoundParam) {
	// Round to nearest integer when the definition carried the 'R' flag
	paramValue = (double)((int) (paramValue + 0.5));
      }
      if (cvParam.m_AddAtEnd) {
	// Written back-to-front: value first, then its flag, so the final
	// order at the tail of the array is "-X value".
	options[--end] = "" + Utils.doubleToString(paramValue,4);
	options[--end] = "-" + cvParam.m_ParamChar;
      } else {
	options[start++] = "-" + cvParam.m_ParamChar;
	options[start++] = "" + Utils.doubleToString(paramValue,4);
      }
    }
    // Add the static parameters (the base classifier options) in the middle
    System.arraycopy(m_ClassifierOptions, 0,
		     options, start,
		     m_ClassifierOptions.length);

    return options;
  }

  /**
   * Finds the best parameter combination. (recursive for each parameter
   * being optimised).
   *
   * @param depth the index of the parameter to be optimised at this level
   * @param trainData the data the search is based on
   * @param random a random number generator
   * @throws Exception if an error occurs
   */
  protected void findParamsByCrossValidation(int depth, Instances trainData,
					     Random random)
    throws Exception {

    if (depth < m_CVParams.size()) {
      // Recursive case: sweep this parameter's range, recursing on the
      // remaining parameters for every step value.
      CVParameter cvParam = (CVParameter)m_CVParams.elementAt(depth);

      double upper;
      // Decode the 'A'/'I' sentinels stored by the CVParameter constructor:
      // m_Lower - m_Upper == 1 means 'A' (number of attributes),
      // == 2 means 'I' (training fold size); otherwise a literal bound.
      switch ((int)(cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
      case 1:
	upper = m_NumAttributes;
	break;
      case 2:
	upper = m_TrainFoldSize;
	break;
      default:
	upper = cvParam.m_Upper;
	break;
      }
      // NOTE(review): with m_Steps == 1 this divides by zero, and the
      // accumulating floating-point increment may skip the last step due
      // to rounding — longstanding behavior, kept as-is.
      double increment = (upper - cvParam.m_Lower) / (cvParam.m_Steps - 1);
      for(cvParam.m_ParamValue = cvParam.m_Lower; 
	  cvParam.m_ParamValue <= upper; 
	  cvParam.m_ParamValue += increment) {
	findParamsByCrossValidation(depth + 1, trainData, random);
      }
    } else {
      // Base case: all parameter values are fixed; measure this combination
      // by m_NumFolds-fold cross-validation on trainData.
      Evaluation evaluation = new Evaluation(trainData);

      // Set the classifier options
      String [] options = createOptions();
      if (m_Debug) {
	System.err.print("Setting options for " 
			 + m_Classifier.getClass().getName() + ":");
	for (int i = 0; i < options.length; i++) {
	  System.err.print(" " + options[i]);
	}
	System.err.println("");
      }
      ((OptionHandler)m_Classifier).setOptions(options);
      for (int j = 0; j < m_NumFolds; j++) {

	// We want to randomize the data the same way for every 
	// learning scheme. (fixed seed 1, not the supplied random —
	// keeps folds identical across parameter combinations)
	Instances train = trainData.trainCV(m_NumFolds, j, new Random(1));
	Instances test = trainData.testCV(m_NumFolds, j);
	m_Classifier.buildClassifier(train);
	evaluation.setPriors(train);
	evaluation.evaluateModel(m_Classifier, test);
      }
      double error = evaluation.errorRate();
      if (m_Debug) {
	System.err.println("Cross-validated error rate: " 
			   + Utils.doubleToString(error, 6, 4));
      }
      // -99 is the "no result yet" sentinel; presumably m_BestPerformance
      // is initialized to -99 before the search starts (outside this
      // excerpt) — TODO confirm in buildClassifier.
      if ((m_BestPerformance == -99) || (error < m_BestPerformance)) {
	
	m_BestPerformance = error;
	m_BestClassifierOptions = createOptions();
      }
    }
  }

  /**
   * Returns a string describing this classifier
   * @return a description of the classifier suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Class for performing parameter selection by cross-validation "
      + "for any classifier.\n\n"
      + "For more information, see:\n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing 
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   * 
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation 	result;
    
    result = new TechnicalInformation(Type.PHDTHESIS);
    result.setValue(Field.AUTHOR, "R. Kohavi");
    result.setValue(Field.YEAR, "1995");
    result.setValue(Field.TITLE, "Wrappers for Performance Enhancement and Oblivious Decision Graphs");
    result.setValue(Field.SCHOOL, "Stanford University");
    result.setValue(Field.ADDRESS, "Department of Computer Science, Stanford University");
    
    return result;
  }

  /**
   * Returns an enumeration describing the available options.
* * @return an enumeration of all the available options. */ public Enumeration listOptions() { Vector newVector = new Vector(2); newVector.addElement(new Option( "\tNumber of folds used for cross validation (default 10).",
// NOTE(review): code-viewer UI residue (keyboard-shortcut help) removed here.
// The original source continues beyond this excerpt — listOptions() above is
// truncated mid-method.