📄 ensembleselection.java
字号:
/** * We return true for basically everything except for Missing class values, * because we can't really answer for all the models in our library. If any of * them don't work with the supplied data then we just trap the exception. * * @return the capabilities of this classifier */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); // returns the object // from // weka.classifiers.Classifier // attributes result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.DATE_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); result.enable(Capability.BINARY_ATTRIBUTES); // class result.enable(Capability.NOMINAL_CLASS); result.enable(Capability.NUMERIC_CLASS); result.enable(Capability.BINARY_CLASS); return result; } /** <!-- options-start --> * Valid options are: <p/> * * <pre> -L </path/to/modelLibrary> * Specifies the Model Library File, continuing the list of all models.</pre> * * <pre> -W </path/to/working/directory> * Specifies the Working Directory, where all models will be stored.</pre> * * <pre> -B <numModelBags> * Set the number of bags, i.e., number of iterations to run * the ensemble selection algorithm.</pre> * * <pre> -E <modelRatio> * Set the ratio of library models that will be randomly chosen * to populate each bag of models.</pre> * * <pre> -V <validationRatio> * Set the ratio of the training data set that will be reserved * for validation.</pre> * * <pre> -H <hillClimbIterations> * Set the number of hillclimbing iterations to be performed * on each model bag.</pre> * * <pre> -I <sortInitialization> * Set the the ratio of the ensemble library that the sort * initialization algorithm will be able to choose from while * initializing the ensemble for each model bag</pre> * * <pre> -X <numFolds> * Sets the number of cross-validation folds.</pre> * * <pre> -P <hillclimbMettric> * Specify the metric that will be used for model selection * during the hillclimbing algorithm. * Valid metrics are: * accuracy, rmse, roc, precision, recall, fscore, all</pre> * * <pre> -A <algorithm> * Specifies the algorithm to be used for ensemble selection. * Valid algorithms are: * "forward" (default) for forward selection. * "backward" for backward elimination. * "both" for both forward and backward elimination. * "best" to simply print out top performer from the * ensemble library * "library" to only train the models in the ensemble * library</pre> * * <pre> -R * Flag whether or not models can be selected more than once * for an ensemble.</pre> * * <pre> -G * Whether sort initialization greedily stops adding models * when performance degrades.</pre> * * <pre> -O * Flag for verbose output. Prints out performance of all * selected models.</pre> * * <pre> -S <num> * Random number seed. * (default 1)</pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * <!-- options-end --> * * @param options * the list of options as an array of strings * @throws Exception * if an option is not supported */ public void setOptions(String[] options) throws Exception { String tmpStr; tmpStr = Utils.getOption('L', options); if (tmpStr.length() != 0) { m_modelLibraryFileName = tmpStr; m_library = new EnsembleSelectionLibrary(m_modelLibraryFileName); } else { setLibrary(new EnsembleSelectionLibrary()); // setLibrary(new Library(super.m_Classifiers)); } tmpStr = Utils.getOption('W', options); if (tmpStr.length() != 0 && validWorkingDirectory(tmpStr)) { m_workingDirectory = new File(tmpStr); } else { m_workingDirectory = new File(getDefaultWorkingDirectory()); } m_library.setWorkingDirectory(m_workingDirectory); tmpStr = Utils.getOption('E', options); if (tmpStr.length() != 0) { setModelRatio(Double.parseDouble(tmpStr)); } else { setModelRatio(1.0); } tmpStr = Utils.getOption('V', options); if (tmpStr.length() != 0) { setValidationRatio(Double.parseDouble(tmpStr)); } else { setValidationRatio(0.25); } tmpStr = Utils.getOption('B', options); if (tmpStr.length() != 0) { setNumModelBags(Integer.parseInt(tmpStr)); } else { setNumModelBags(10); } tmpStr = Utils.getOption('H', options); if (tmpStr.length() != 0) { setHillclimbIterations(Integer.parseInt(tmpStr)); } else { setHillclimbIterations(100); } tmpStr = Utils.getOption('I', options); if (tmpStr.length() != 0) { setSortInitializationRatio(Double.parseDouble(tmpStr)); } else { setSortInitializationRatio(1.0); } tmpStr = Utils.getOption('X', options); if (tmpStr.length() != 0) { setNumFolds(Integer.parseInt(tmpStr)); } else { setNumFolds(10); } setReplacement(Utils.getFlag('R', options)); setGreedySortInitialization(Utils.getFlag('G', options)); setVerboseOutput(Utils.getFlag('O', options)); tmpStr = Utils.getOption('P', options); // if (hillclimbMetricString.length() != 0) { if (tmpStr.toLowerCase().equals("accuracy")) { setHillclimbMetric(new SelectedTag( EnsembleMetricHelper.METRIC_ACCURACY, TAGS_METRIC)); } else if (tmpStr.toLowerCase().equals("rmse")) { setHillclimbMetric(new SelectedTag( EnsembleMetricHelper.METRIC_RMSE, TAGS_METRIC)); } else if (tmpStr.toLowerCase().equals("roc")) { setHillclimbMetric(new SelectedTag( EnsembleMetricHelper.METRIC_ROC, TAGS_METRIC)); } else if (tmpStr.toLowerCase().equals("precision")) { setHillclimbMetric(new SelectedTag( EnsembleMetricHelper.METRIC_PRECISION, TAGS_METRIC)); } else if (tmpStr.toLowerCase().equals("recall")) { setHillclimbMetric(new SelectedTag( EnsembleMetricHelper.METRIC_RECALL, TAGS_METRIC)); } else if (tmpStr.toLowerCase().equals("fscore")) { setHillclimbMetric(new SelectedTag( EnsembleMetricHelper.METRIC_FSCORE, TAGS_METRIC)); } else if (tmpStr.toLowerCase().equals("all")) { setHillclimbMetric(new SelectedTag( EnsembleMetricHelper.METRIC_ALL, TAGS_METRIC)); } else { setHillclimbMetric(new SelectedTag( EnsembleMetricHelper.METRIC_RMSE, TAGS_METRIC)); } tmpStr = Utils.getOption('A', options); if (tmpStr.toLowerCase().equals("forward")) { setAlgorithm(new SelectedTag(ALGORITHM_FORWARD, TAGS_ALGORITHM)); } else if (tmpStr.toLowerCase().equals("backward")) { setAlgorithm(new SelectedTag(ALGORITHM_BACKWARD, TAGS_ALGORITHM)); } else if (tmpStr.toLowerCase().equals("both")) { setAlgorithm(new SelectedTag(ALGORITHM_FORWARD_BACKWARD, TAGS_ALGORITHM)); } else if (tmpStr.toLowerCase().equals("forward")) { setAlgorithm(new SelectedTag(ALGORITHM_FORWARD, TAGS_ALGORITHM)); } else if (tmpStr.toLowerCase().equals("best")) { setAlgorithm(new SelectedTag(ALGORITHM_BEST, TAGS_ALGORITHM)); } else if (tmpStr.toLowerCase().equals("library")) { setAlgorithm(new SelectedTag(ALGORITHM_BUILD_LIBRARY, TAGS_ALGORITHM)); } else { setAlgorithm(new SelectedTag(ALGORITHM_FORWARD, TAGS_ALGORITHM)); } super.setOptions(options); m_library.setDebug(m_Debug); } /** * Gets the current settings of the Classifier. * * @return an array of strings suitable for passing to setOptions */ public String[] getOptions() { Vector result; String[] options; int i; result = new Vector(); if (m_library.getModelListFile() != null) { result.add("-L"); result.add("" + m_library.getModelListFile()); } if (!m_workingDirectory.equals("")) { result.add("-W"); result.add("" + getWorkingDirectory()); } result.add("-P"); switch (getHillclimbMetric().getSelectedTag().getID()) { case (EnsembleMetricHelper.METRIC_ACCURACY): result.add("accuracy"); break; case (EnsembleMetricHelper.METRIC_RMSE): result.add("rmse"); break; case (EnsembleMetricHelper.METRIC_ROC): result.add("roc"); break; case (EnsembleMetricHelper.METRIC_PRECISION): result.add("precision"); break; case (EnsembleMetricHelper.METRIC_RECALL): result.add("recall"); break; case (EnsembleMetricHelper.METRIC_FSCORE): result.add("fscore"); break; case (EnsembleMetricHelper.METRIC_ALL): result.add("all"); break; } result.add("-A"); switch (getAlgorithm().getSelectedTag().getID()) { case (ALGORITHM_FORWARD): result.add("forward"); break; case (ALGORITHM_BACKWARD): result.add("backward"); break; case (ALGORITHM_FORWARD_BACKWARD): result.add("both"); break; case (ALGORITHM_BEST): result.add("best"); break; case (ALGORITHM_BUILD_LIBRARY): result.add("library"); break; } result.add("-B"); result.add("" + getNumModelBags()); result.add("-V"); result.add("" + getValidationRatio()); result.add("-E"); result.add("" + getModelRatio()); result.add("-H"); result.add("" + getHillclimbIterations()); result.add("-I"); result.add("" + getSortInitializationRatio()); result.add("-X"); result.add("" + getNumFolds()); if (m_replacement) result.add("-R"); if (m_greedySortInitialization) result.add("-G"); if (m_verboseOutput) result.add("-O"); options = super.getOptions(); for (i = 0; i < options.length; i++) result.add(options[i]); return (String[]) result.toArray(new String[result.size()]); } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String numFoldsTipText() { return "The number of folds used for cross-validation."; } /** * Gets the number of folds for the cross-validation. * * @return the number of folds for the cross-validation */ public int getNumFolds() { return m_NumFolds; } /** * Sets the number of folds for the cross-validation. * * @param numFolds * the number of folds for the cross-validation * @throws Exception * if parameter illegal */ public void setNumFolds(int numFolds) throws Exception { if (numFolds < 0) { throw new IllegalArgumentException( "EnsembleSelection: Number of cross-validation " + "folds must be positive."); } m_NumFolds = numFolds; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String libraryTipText() { return "An ensemble library."; } /** * Gets the ensemble library. * * @return the ensemble library */ public EnsembleSelectionLibrary getLibrary() { return m_library; } /** * Sets the ensemble library. * * @param newLibrary * the ensemble library */ public void setLibrary(EnsembleSelectionLibrary newLibrary) { m_library = newLibrary; m_library.setDebug(m_Debug); } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String modelRatioTipText() { return "The ratio of library models that will be randomly chosen to be used for each iteration."; } /** * Get the value of modelRatio. * * @return Value of modelRatio. */ public double getModelRatio() { return m_modelRatio; } /** * Set the value of modelRatio. * * @param v * Value to assign to modelRatio. */ public void setModelRatio(double v) { m_modelRatio = v; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String validationRatioTipText() { return "The ratio of the training data set that will be reserved for validation."; } /** * Get the value of validationRatio. * * @return Value of validationRatio. */ public double getValidationRatio() { return m_validationRatio; } /** * Set the value of validationRatio. * * @param v * Value to assign to validationRatio. */ public void setValidationRatio(double v) { m_validationRatio = v; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -