gridsearch.java

来自「Weka」· Java 代码 · 共 2,258 行 · 第 1/5 页

JAVA
2,258
字号
     */    public void add(int cv, Performance p) {      m_Cache.put(getID(cv, p.getValues()), p);    }        /**     * returns a string representation of the cache     *      * @return		the string representation of the cache     */    public String toString() {      return m_Cache.toString();    }  }    /** for serialization */  private static final long serialVersionUID = -3034773968581595348L;  /** evaluation via: Correlation coefficient */  public static final int EVALUATION_CC = 0;  /** evaluation via: Root mean squared error */  public static final int EVALUATION_RMSE = 1;  /** evaluation via: Root relative squared error */  public static final int EVALUATION_RRSE = 2;  /** evaluation via: Mean absolute error */  public static final int EVALUATION_MAE = 3;  /** evaluation via: Relative absolute error */  public static final int EVALUATION_RAE = 4;  /** evaluation via: Combined = (1-CC) + RRSE + RAE */  public static final int EVALUATION_COMBINED = 5;  /** evaluation via: Accuracy */  public static final int EVALUATION_ACC = 6;  /** evaluation */  public static final Tag[] TAGS_EVALUATION = {    new Tag(EVALUATION_CC, "CC", "Correlation coefficient"),    new Tag(EVALUATION_RMSE, "RMSE", "Root mean squared error"),    new Tag(EVALUATION_RRSE, "RRSE", "Root relative squared error"),    new Tag(EVALUATION_MAE, "MAE", "Mean absolute error"),    new Tag(EVALUATION_RAE, "RAE", "Root absolute error"),    new Tag(EVALUATION_COMBINED, "COMB", "Combined = (1-abs(CC)) + RRSE + RAE"),    new Tag(EVALUATION_ACC, "ACC", "Accuracy")  };    /** row-wise grid traversal */  public static final int TRAVERSAL_BY_ROW = 0;  /** column-wise grid traversal */  public static final int TRAVERSAL_BY_COLUMN = 1;  /** traversal */  public static final Tag[] TAGS_TRAVERSAL = {    new Tag(TRAVERSAL_BY_ROW, "row-wise", "row-wise"),    new Tag(TRAVERSAL_BY_COLUMN, "column-wise", "column-wise")  };  /** the prefix to indicate that the option is for the classifier */  public final static String PREFIX_CLASSIFIER = "classifier.";  /** the prefix to indicate that the option is for the filter */  public final static String PREFIX_FILTER = "filter.";    /** the Filter */  protected Filter m_Filter;    /** the Filter with the best setup */  protected Filter m_BestFilter;    /** the Classifier with the best setup */  protected Classifier m_BestClassifier;  /** the best values */  protected PointDouble m_Values = null;    /** the type of evaluation */  protected int m_Evaluation = EVALUATION_CC;  /** the Y option to work on (without leading dash, preceding 'classifier.'    * means to set the option for the classifier 'filter.' for the filter) */  protected String m_Y_Property = PREFIX_CLASSIFIER + "ridge";    /** the minimum of Y */  protected double m_Y_Min = -10;    /** the maximum of Y */  protected double m_Y_Max = +5;    /** the step size of Y */  protected double m_Y_Step = 1;    /** the base for Y */  protected double m_Y_Base = 10;    /**    * The expression for the Y property. Available parameters for the   * expression:   * <ul>   *   <li>BASE</li>   *   <li>FROM (= min)</li>   *   <li>TO (= max)</li>   *   <li>STEP</li>   *   <li>I - the current value (from 'from' to 'to' with stepsize 'step')</li>   * </ul>   *    * @see MathematicalExpression   * @see MathExpression   */  protected String m_Y_Expression = "pow(BASE,I)";  /** the X option to work on (without leading dash, preceding 'classifier.'    * means to set the option for the classifier 'filter.' for the filter) */  protected String m_X_Property = PREFIX_FILTER + "numComponents";    /** the minimum of X */  protected double m_X_Min = +5;    /** the maximum of X */  protected double m_X_Max = +20;    /** the step size of  */  protected double m_X_Step = 1;    /** the base for  */  protected double m_X_Base = 10;    /**    * The expression for the X property. Available parameters for the   * expression:   * <ul>   *   <li>BASE</li>   *   <li>FROM (= min)</li>   *   <li>TO (= max)</li>   *   <li>STEP</li>   *   <li>I - the current value (from 'from' to 'to' with stepsize 'step')</li>   * </ul>   *    * @see MathematicalExpression   * @see MathExpression   */  protected String m_X_Expression = "I";  /** whether the grid can be extended */  protected boolean m_GridIsExtendable = false;    /** maximum number of grid extensions (-1 means unlimited) */  protected int m_MaxGridExtensions = 3;    /** the number of extensions performed */  protected int m_GridExtensionsPerformed = 0;  /** the sample size to search the initial grid with */  protected double m_SampleSize = 100;    /** the traversal */  protected int m_Traversal = TRAVERSAL_BY_COLUMN;  /** the log file to use */  protected File m_LogFile = new File(System.getProperty("user.dir"));    /** the value-pairs grid */  protected Grid m_Grid;  /** the training data */  protected Instances m_Data;  /** the cache for points in the grid that got calculated */  protected PerformanceCache m_Cache;  /** whether all performances in the grid are the same */  protected boolean m_UniformPerformance = false;    /**   * the default constructor   */  public GridSearch() {    super();        // classifier    m_Classifier = new LinearRegression();    ((LinearRegression) m_Classifier).setAttributeSelectionMethod(new SelectedTag(LinearRegression.SELECTION_NONE, LinearRegression.TAGS_SELECTION));    ((LinearRegression) m_Classifier).setEliminateColinearAttributes(false);        // filter    m_Filter = new PLSFilter();    PLSFilter filter = new PLSFilter();    filter.setPreprocessing(new SelectedTag(PLSFilter.PREPROCESSING_STANDARDIZE, PLSFilter.TAGS_PREPROCESSING));    filter.setReplaceMissing(true);        try {      m_BestClassifier = Classifier.makeCopy(m_Classifier);    }    catch (Exception e) {      e.printStackTrace();    }    try {      m_BestFilter = Filter.makeCopy(filter);    }    catch (Exception e) {      e.printStackTrace();    }  }    /**   * Returns a string describing classifier   *    * @return a description suitable for displaying in the   *         explorer/experimenter gui   */  public String globalInfo() {    return         "Performs a grid search of parameter pairs for the a classifier "      + "(Y-axis, default is LinearRegression with the \"Ridge\" parameter) "      + "and the PLSFilter (X-axis, \"# of Components\") and chooses the best "      + "pair found for the actual predicting.\n\n"      + "The initial grid is worked on with 2-fold CV to determine the values "      + "of the parameter pairs for the selected type of evaluation (e.g., "      + "accuracy). The best point in the grid is then taken and a 10-fold CV "      + "is performed with the adjacent parameter pairs. If a better pair is "      + "found, then this will act as new center and another 10-fold CV will "      + "be performed (kind of hill-climbing). This process is repeated until "      + "no better pair is found or the best pair is on the border of the grid.\n"      + "In case the best pair is on the border, one can let GridSearch "      + "automatically extend the grid and continue the search. Check out the "      + "properties 'gridIsExtendable' (option '-extend-grid') and "      + "'maxGridExtensions' (option '-max-grid-extensions <num>').\n\n"      + "GridSearch can handle doubles, integers (values are just cast to int) "      + "and booleans (0 is false, otherwise true). float, char and long are "      + "supported as well.\n\n"      + "The best filter/classifier setup can be accessed after the buildClassifier "      + "call via the getBestFilter/getBestClassifier methods.\n"      + "Note on the implementation: after the data has been passed through "      + "the filter, a default NumericCleaner filter is applied to the data in "      + "order to avoid numbers that are getting too small and might produce "      + "NaNs in other schemes.";  }  /**   * String describing default classifier.   *    * @return		the classname of the default classifier   */  protected String defaultClassifierString() {    return LinearRegression.class.getName();  }  /**   * Gets an enumeration describing the available options.   *   * @return an enumeration of all the available options.   */  public Enumeration listOptions(){    Vector        	result;    Enumeration   	en;    String		desc;    SelectedTag		tag;    int			i;    result = new Vector();    desc  = "";    for (i = 0; i < TAGS_EVALUATION.length; i++) {      tag = new SelectedTag(TAGS_EVALUATION[i].getID(), TAGS_EVALUATION);      desc  +=   "\t" + tag.getSelectedTag().getIDStr()       	       + " = " + tag.getSelectedTag().getReadable()      	       + "\n";    }    result.addElement(new Option(	"\tDetermines the parameter used for evaluation:\n"	+ desc	+ "\t(default: " + new SelectedTag(EVALUATION_CC, TAGS_EVALUATION) + ")",	"E", 1, "-E " + Tag.toOptionList(TAGS_EVALUATION)));    result.addElement(new Option(	"\tThe Y option to test (without leading dash).\n"	+ "\t(default: " + PREFIX_CLASSIFIER + "ridge)",	"y-property", 1, "-y-property <option>"));    result.addElement(new Option(	"\tThe minimum for Y.\n"	+ "\t(default: -10)",	"y-min", 1, "-y-min <num>"));    result.addElement(new Option(	"\tThe maximum for Y.\n"	+ "\t(default: +5)",	"y-max", 1, "-y-max <num>"));    result.addElement(new Option(	"\tThe step size for Y.\n"	+ "\t(default: 1)",	"y-step", 1, "-y-step <num>"));    result.addElement(new Option(	"\tThe base for Y.\n"	+ "\t(default: 10)",	"y-base", 1, "-y-base <num>"));    result.addElement(new Option(	"\tThe expression for Y.\n"	+ "\tAvailable parameters:\n"	+ "\t\tBASE\n"	+ "\t\tFROM\n"	+ "\t\tTO\n"	+ "\t\tSTEP\n"	+ "\t\tI - the current iteration value\n"	+ "\t\t(from 'FROM' to 'TO' with stepsize 'STEP')\n"	+ "\t(default: 'pow(BASE,I)')",	"y-expression", 1, "-y-expression <expr>"));    result.addElement(new Option(	"\tThe filter to use (on X axis). Full classname of filter to include, \n"	+ "\tfollowed by scheme options.\n"	+ "\t(default: weka.filters.supervised.attribute.PLSFilter)",	"filter", 1, "-filter <filter specification>"));    result.addElement(new Option(	"\tThe X option to test (without leading dash).\n"	+ "\t(default: " + PREFIX_FILTER + "numComponents)",	"x-property", 1, "-x-property <option>"));    result.addElement(new Option(	"\tThe minimum for X.\n"	+ "\t(default: +5)",	"x-min", 1, "-x-min <num>"));    result.addElement(new Option(	"\tThe maximum for X.\n"	+ "\t(default: +20)",	"x-max", 1, "-x-max <num>"));    result.addElement(new Option(	"\tThe step size for X.\n"	+ "\t(default: 1)",	"x-step", 1, "-x-step <num>"));    result.addElement(new Option(	"\tThe base for X.\n"	+ "\t(default: 10)",	"x-base", 1, "-x-base <num>"));    result.addElement(new Option(	"\tThe expression for the X value.\n"	+ "\tAvailable parameters:\n"	+ "\t\tBASE\n"	+ "\t\tMIN\n"	+ "\t\tMAX\n"	+ "\t\tSTEP\n"	+ "\t\tI - the current iteration value\n"	+ "\t\t(from 'FROM' to 'TO' with stepsize 'STEP')\n"	+ "\t(default: 'pow(BASE,I)')",	"x-expression", 1, "-x-expression <expr>"));    result.addElement(new Option(	"\tWhether the grid can be extended.\n"	+ "\t(default: no)",	"extend-grid", 0, "-extend-grid"));    result.addElement(new Option(	"\tThe maximum number of grid extensions (-1 is unlimited).\n"	+ "\t(default: 3)",	"max-grid-extensions", 1, "-max-grid-extensions <num>"));    result.addElement(new Option(	"\tThe size (in percent) of the sample to search the inital grid with.\n"	+ "\t(default: 100)",	"sample-size", 1, "-sample-size <num>"));    result.addElement(new Option(	"\tThe type of traversal for the grid.\n"	+ "\t(default: " + new SelectedTag(TRAVERSAL_BY_COLUMN, TAGS_TRAVERSAL) + ")",	"traversal", 1, "-traversal " + Tag.toOptionList(TAGS_TRAVERSAL)));    result.addElement(new Option(	"\tThe log file to log the messages to.\n"	+ "\t(default: none)",	"log-file", 1, "-log-file <filename>"));    en = super.listOptions();    while (en.hasMoreElements())      result.addElement(en.nextElement());    if (getFilter() instanceof OptionHandler) {      result.addElement(new Option(	  "",	  "", 0, "\nOptions specific to filter "	  + getFilter().getClass().getName() + " ('-filter'):"));            en = ((OptionHandler) getFilter()).listOptions();      while (en.hasMoreElements())	result.addElement(en.nextElement());    }    return result.elements();  }    /**   * returns the options of the current setup   *   * @return		the current options   */  public String[] getOptions(){    int       	i;    Vector    	result;    String[]  	options;    result = new Vector();    result.add("-E");    result.add("" + getEvaluation());    result.add("-y-property");    result.add("" + getYProperty());    result.add("-y-min");    result.add("" + getYMin());    result.add("-y-max");    result.add("" + getYMax());    result.add("-y-step");    result.add("" + getYStep());    result.add("-y-base");    result.add("" + getYBase());    result.add("-y-expression");    result.add("" + getYExpression());    result.add("-filter");    if (getFilter() instanceof OptionHandler)      result.add(  	    getFilter().getClass().getName() 	  + " " 	  + Utils.joinOptions(((OptionHandler) getFilter()).getOptions()));    else      result.add(	  getFilter().getClass().getName());    result.add("-x-property");    result.add("" + getXProperty());    result.add("-x-min");    result.add("" + getXMin());    result.add("-x-max");    result.add("" + getXMax());    result.add("-x-step");    result.add("" + getXStep());    result.add("-x-base");    result.add("" + getXBase());    result.add("-x-expression");

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?