gridsearch.java
来自「Weka」· Java 代码 · 共 2,258 行 · 第 1/5 页
JAVA
2,258 行
*/ public void add(int cv, Performance p) { m_Cache.put(getID(cv, p.getValues()), p); } /** * returns a string representation of the cache * * @return the string representation of the cache */ public String toString() { return m_Cache.toString(); } } /** for serialization */ private static final long serialVersionUID = -3034773968581595348L; /** evaluation via: Correlation coefficient */ public static final int EVALUATION_CC = 0; /** evaluation via: Root mean squared error */ public static final int EVALUATION_RMSE = 1; /** evaluation via: Root relative squared error */ public static final int EVALUATION_RRSE = 2; /** evaluation via: Mean absolute error */ public static final int EVALUATION_MAE = 3; /** evaluation via: Relative absolute error */ public static final int EVALUATION_RAE = 4; /** evaluation via: Combined = (1-CC) + RRSE + RAE */ public static final int EVALUATION_COMBINED = 5; /** evaluation via: Accuracy */ public static final int EVALUATION_ACC = 6; /** evaluation */ public static final Tag[] TAGS_EVALUATION = { new Tag(EVALUATION_CC, "CC", "Correlation coefficient"), new Tag(EVALUATION_RMSE, "RMSE", "Root mean squared error"), new Tag(EVALUATION_RRSE, "RRSE", "Root relative squared error"), new Tag(EVALUATION_MAE, "MAE", "Mean absolute error"), new Tag(EVALUATION_RAE, "RAE", "Root absolute error"), new Tag(EVALUATION_COMBINED, "COMB", "Combined = (1-abs(CC)) + RRSE + RAE"), new Tag(EVALUATION_ACC, "ACC", "Accuracy") }; /** row-wise grid traversal */ public static final int TRAVERSAL_BY_ROW = 0; /** column-wise grid traversal */ public static final int TRAVERSAL_BY_COLUMN = 1; /** traversal */ public static final Tag[] TAGS_TRAVERSAL = { new Tag(TRAVERSAL_BY_ROW, "row-wise", "row-wise"), new Tag(TRAVERSAL_BY_COLUMN, "column-wise", "column-wise") }; /** the prefix to indicate that the option is for the classifier */ public final static String PREFIX_CLASSIFIER = "classifier."; /** the prefix to indicate that the option is for the filter */ public final static String PREFIX_FILTER = "filter."; /** the Filter */ protected Filter m_Filter; /** the Filter with the best setup */ protected Filter m_BestFilter; /** the Classifier with the best setup */ protected Classifier m_BestClassifier; /** the best values */ protected PointDouble m_Values = null; /** the type of evaluation */ protected int m_Evaluation = EVALUATION_CC; /** the Y option to work on (without leading dash, preceding 'classifier.' * means to set the option for the classifier 'filter.' for the filter) */ protected String m_Y_Property = PREFIX_CLASSIFIER + "ridge"; /** the minimum of Y */ protected double m_Y_Min = -10; /** the maximum of Y */ protected double m_Y_Max = +5; /** the step size of Y */ protected double m_Y_Step = 1; /** the base for Y */ protected double m_Y_Base = 10; /** * The expression for the Y property. Available parameters for the * expression: * <ul> * <li>BASE</li> * <li>FROM (= min)</li> * <li>TO (= max)</li> * <li>STEP</li> * <li>I - the current value (from 'from' to 'to' with stepsize 'step')</li> * </ul> * * @see MathematicalExpression * @see MathExpression */ protected String m_Y_Expression = "pow(BASE,I)"; /** the X option to work on (without leading dash, preceding 'classifier.' * means to set the option for the classifier 'filter.' for the filter) */ protected String m_X_Property = PREFIX_FILTER + "numComponents"; /** the minimum of X */ protected double m_X_Min = +5; /** the maximum of X */ protected double m_X_Max = +20; /** the step size of */ protected double m_X_Step = 1; /** the base for */ protected double m_X_Base = 10; /** * The expression for the X property. Available parameters for the * expression: * <ul> * <li>BASE</li> * <li>FROM (= min)</li> * <li>TO (= max)</li> * <li>STEP</li> * <li>I - the current value (from 'from' to 'to' with stepsize 'step')</li> * </ul> * * @see MathematicalExpression * @see MathExpression */ protected String m_X_Expression = "I"; /** whether the grid can be extended */ protected boolean m_GridIsExtendable = false; /** maximum number of grid extensions (-1 means unlimited) */ protected int m_MaxGridExtensions = 3; /** the number of extensions performed */ protected int m_GridExtensionsPerformed = 0; /** the sample size to search the initial grid with */ protected double m_SampleSize = 100; /** the traversal */ protected int m_Traversal = TRAVERSAL_BY_COLUMN; /** the log file to use */ protected File m_LogFile = new File(System.getProperty("user.dir")); /** the value-pairs grid */ protected Grid m_Grid; /** the training data */ protected Instances m_Data; /** the cache for points in the grid that got calculated */ protected PerformanceCache m_Cache; /** whether all performances in the grid are the same */ protected boolean m_UniformPerformance = false; /** * the default constructor */ public GridSearch() { super(); // classifier m_Classifier = new LinearRegression(); ((LinearRegression) m_Classifier).setAttributeSelectionMethod(new SelectedTag(LinearRegression.SELECTION_NONE, LinearRegression.TAGS_SELECTION)); ((LinearRegression) m_Classifier).setEliminateColinearAttributes(false); // filter m_Filter = new PLSFilter(); PLSFilter filter = new PLSFilter(); filter.setPreprocessing(new SelectedTag(PLSFilter.PREPROCESSING_STANDARDIZE, PLSFilter.TAGS_PREPROCESSING)); filter.setReplaceMissing(true); try { m_BestClassifier = Classifier.makeCopy(m_Classifier); } catch (Exception e) { e.printStackTrace(); } try { m_BestFilter = Filter.makeCopy(filter); } catch (Exception e) { e.printStackTrace(); } } /** * Returns a string describing classifier * * @return a description suitable for displaying in the * explorer/experimenter gui */ public String globalInfo() { return "Performs a grid search of parameter pairs for the a classifier " + "(Y-axis, default is LinearRegression with the \"Ridge\" parameter) " + "and the PLSFilter (X-axis, \"# of Components\") and chooses the best " + "pair found for the actual predicting.\n\n" + "The initial grid is worked on with 2-fold CV to determine the values " + "of the parameter pairs for the selected type of evaluation (e.g., " + "accuracy). The best point in the grid is then taken and a 10-fold CV " + "is performed with the adjacent parameter pairs. If a better pair is " + "found, then this will act as new center and another 10-fold CV will " + "be performed (kind of hill-climbing). This process is repeated until " + "no better pair is found or the best pair is on the border of the grid.\n" + "In case the best pair is on the border, one can let GridSearch " + "automatically extend the grid and continue the search. Check out the " + "properties 'gridIsExtendable' (option '-extend-grid') and " + "'maxGridExtensions' (option '-max-grid-extensions <num>').\n\n" + "GridSearch can handle doubles, integers (values are just cast to int) " + "and booleans (0 is false, otherwise true). float, char and long are " + "supported as well.\n\n" + "The best filter/classifier setup can be accessed after the buildClassifier " + "call via the getBestFilter/getBestClassifier methods.\n" + "Note on the implementation: after the data has been passed through " + "the filter, a default NumericCleaner filter is applied to the data in " + "order to avoid numbers that are getting too small and might produce " + "NaNs in other schemes."; } /** * String describing default classifier. * * @return the classname of the default classifier */ protected String defaultClassifierString() { return LinearRegression.class.getName(); } /** * Gets an enumeration describing the available options. * * @return an enumeration of all the available options. */ public Enumeration listOptions(){ Vector result; Enumeration en; String desc; SelectedTag tag; int i; result = new Vector(); desc = ""; for (i = 0; i < TAGS_EVALUATION.length; i++) { tag = new SelectedTag(TAGS_EVALUATION[i].getID(), TAGS_EVALUATION); desc += "\t" + tag.getSelectedTag().getIDStr() + " = " + tag.getSelectedTag().getReadable() + "\n"; } result.addElement(new Option( "\tDetermines the parameter used for evaluation:\n" + desc + "\t(default: " + new SelectedTag(EVALUATION_CC, TAGS_EVALUATION) + ")", "E", 1, "-E " + Tag.toOptionList(TAGS_EVALUATION))); result.addElement(new Option( "\tThe Y option to test (without leading dash).\n" + "\t(default: " + PREFIX_CLASSIFIER + "ridge)", "y-property", 1, "-y-property <option>")); result.addElement(new Option( "\tThe minimum for Y.\n" + "\t(default: -10)", "y-min", 1, "-y-min <num>")); result.addElement(new Option( "\tThe maximum for Y.\n" + "\t(default: +5)", "y-max", 1, "-y-max <num>")); result.addElement(new Option( "\tThe step size for Y.\n" + "\t(default: 1)", "y-step", 1, "-y-step <num>")); result.addElement(new Option( "\tThe base for Y.\n" + "\t(default: 10)", "y-base", 1, "-y-base <num>")); result.addElement(new Option( "\tThe expression for Y.\n" + "\tAvailable parameters:\n" + "\t\tBASE\n" + "\t\tFROM\n" + "\t\tTO\n" + "\t\tSTEP\n" + "\t\tI - the current iteration value\n" + "\t\t(from 'FROM' to 'TO' with stepsize 'STEP')\n" + "\t(default: 'pow(BASE,I)')", "y-expression", 1, "-y-expression <expr>")); result.addElement(new Option( "\tThe filter to use (on X axis). Full classname of filter to include, \n" + "\tfollowed by scheme options.\n" + "\t(default: weka.filters.supervised.attribute.PLSFilter)", "filter", 1, "-filter <filter specification>")); result.addElement(new Option( "\tThe X option to test (without leading dash).\n" + "\t(default: " + PREFIX_FILTER + "numComponents)", "x-property", 1, "-x-property <option>")); result.addElement(new Option( "\tThe minimum for X.\n" + "\t(default: +5)", "x-min", 1, "-x-min <num>")); result.addElement(new Option( "\tThe maximum for X.\n" + "\t(default: +20)", "x-max", 1, "-x-max <num>")); result.addElement(new Option( "\tThe step size for X.\n" + "\t(default: 1)", "x-step", 1, "-x-step <num>")); result.addElement(new Option( "\tThe base for X.\n" + "\t(default: 10)", "x-base", 1, "-x-base <num>")); result.addElement(new Option( "\tThe expression for the X value.\n" + "\tAvailable parameters:\n" + "\t\tBASE\n" + "\t\tMIN\n" + "\t\tMAX\n" + "\t\tSTEP\n" + "\t\tI - the current iteration value\n" + "\t\t(from 'FROM' to 'TO' with stepsize 'STEP')\n" + "\t(default: 'pow(BASE,I)')", "x-expression", 1, "-x-expression <expr>")); result.addElement(new Option( "\tWhether the grid can be extended.\n" + "\t(default: no)", "extend-grid", 0, "-extend-grid")); result.addElement(new Option( "\tThe maximum number of grid extensions (-1 is unlimited).\n" + "\t(default: 3)", "max-grid-extensions", 1, "-max-grid-extensions <num>")); result.addElement(new Option( "\tThe size (in percent) of the sample to search the inital grid with.\n" + "\t(default: 100)", "sample-size", 1, "-sample-size <num>")); result.addElement(new Option( "\tThe type of traversal for the grid.\n" + "\t(default: " + new SelectedTag(TRAVERSAL_BY_COLUMN, TAGS_TRAVERSAL) + ")", "traversal", 1, "-traversal " + Tag.toOptionList(TAGS_TRAVERSAL))); result.addElement(new Option( "\tThe log file to log the messages to.\n" + "\t(default: none)", "log-file", 1, "-log-file <filename>")); en = super.listOptions(); while (en.hasMoreElements()) result.addElement(en.nextElement()); if (getFilter() instanceof OptionHandler) { result.addElement(new Option( "", "", 0, "\nOptions specific to filter " + getFilter().getClass().getName() + " ('-filter'):")); en = ((OptionHandler) getFilter()).listOptions(); while (en.hasMoreElements()) result.addElement(en.nextElement()); } return result.elements(); } /** * returns the options of the current setup * * @return the current options */ public String[] getOptions(){ int i; Vector result; String[] options; result = new Vector(); result.add("-E"); result.add("" + getEvaluation()); result.add("-y-property"); result.add("" + getYProperty()); result.add("-y-min"); result.add("" + getYMin()); result.add("-y-max"); result.add("" + getYMax()); result.add("-y-step"); result.add("" + getYStep()); result.add("-y-base"); result.add("" + getYBase()); result.add("-y-expression"); result.add("" + getYExpression()); result.add("-filter"); if (getFilter() instanceof OptionHandler) result.add( getFilter().getClass().getName() + " " + Utils.joinOptions(((OptionHandler) getFilter()).getOptions())); else result.add( getFilter().getClass().getName()); result.add("-x-property"); result.add("" + getXProperty()); result.add("-x-min"); result.add("" + getXMin()); result.add("-x-max"); result.add("" + getXMax()); result.add("-x-step"); result.add("" + getXStep()); result.add("-x-base"); result.add("" + getXBase()); result.add("-x-expression");
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?