📄 cvparameterselection.java
字号:
"X", 1, "-X <number of folds>")); newVector.addElement(new Option( "\tClassifier parameter options.\n" + "\teg: \"N 1 5 10\" Sets an optimisation parameter for the\n" + "\tclassifier with name -N, with lower bound 1, upper bound\n" + "\t5, and 10 optimisation steps. The upper bound may be the\n" + "\tcharacter 'A' or 'I' to substitute the number of\n" + "\tattributes or instances in the training data,\n" + "\trespectively. This parameter may be supplied more than\n" + "\tonce to optimise over several classifier options\n" + "\tsimultaneously.", "P", 1, "-P <classifier parameter>")); Enumeration enu = super.listOptions(); while (enu.hasMoreElements()) { newVector.addElement(enu.nextElement()); } return newVector.elements(); } /** * Parses a given list of options. <p/> * <!-- options-start --> * Valid options are: <p/> * * <pre> -X <number of folds> * Number of folds used for cross validation (default 10).</pre> * * <pre> -P <classifier parameter> * Classifier parameter options. * eg: "N 1 5 10" Sets an optimisation parameter for the * classifier with name -N, with lower bound 1, upper bound * 5, and 10 optimisation steps. The upper bound may be the * character 'A' or 'I' to substitute the number of * attributes or instances in the training data, * respectively. This parameter may be supplied more than * once to optimise over several classifier options * simultaneously.</pre> * * <pre> -S <num> * Random number seed. * (default 1)</pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * * <pre> -W * Full name of base classifier. * (default: weka.classifiers.rules.ZeroR)</pre> * * <pre> * Options specific to classifier weka.classifiers.rules.ZeroR: * </pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * <!-- options-end --> * * Options after -- are passed to the designated sub-classifier. 
<p>
   *
   * If no -X option is given, the number of folds is reset to 10. Any
   * previously registered -P parameters are discarded before the new ones
   * are parsed; all remaining options are handed to the superclass.
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String foldsString = Utils.getOption('X', options);
    if (foldsString.length() != 0) {
      setNumFolds(Integer.parseInt(foldsString));
    } else {
      setNumFolds(10);
    }

    // -P may be supplied several times. This loop relies on Utils.getOption
    // removing each occurrence it returns, so it stops once none remain.
    String cvParam;
    m_CVParams = new FastVector();
    do {
      cvParam = Utils.getOption('P', options);
      if (cvParam.length() != 0) {
        addCVParameter(cvParam);
      }
    } while (cvParam.length() != 0);

    super.setOptions(options);
  }

  /**
   * Gets the current settings of the Classifier.
   *
   * The -P entries (one per registered CV parameter) and -X come first,
   * followed by the superclass options. Once a model has been built, the
   * base classifier's options are temporarily switched back to the values
   * captured at the start of buildClassifier (m_InitOptions) while the
   * superclass options are collected, and the tuned options are restored
   * afterwards.
   *
   * @return an array of strings suitable for passing to setOptions
   * @throws RuntimeException if the base classifier's options cannot be
   * swapped; the underlying exception is attached as the cause
   */
  public String [] getOptions() {

    String[] superOptions;

    if (m_InitOptions != null) {
      try {
        // Report the pre-tuning configuration (presumably super.getOptions()
        // embeds the base classifier's options), then put the tuned ones back.
        m_Classifier.setOptions((String[])m_InitOptions.clone());
        superOptions = super.getOptions();
        m_Classifier.setOptions((String[])m_BestClassifierOptions.clone());
      } catch (Exception e) {
        // Keep the original failure as the cause instead of discarding it.
        throw new RuntimeException("CVParameterSelection: could not set options "
                                   + "in getOptions().", e);
      }
    } else {
      superOptions = super.getOptions();
    }

    // Two slots per CV parameter, two for -X <folds>, then the superclass options.
    String [] options = new String [superOptions.length
                                    + m_CVParams.size() * 2 + 2];

    int current = 0;
    for (int i = 0; i < m_CVParams.size(); i++) {
      options[current++] = "-P";
      options[current++] = "" + getCVParameter(i);
    }
    options[current++] = "-X";
    options[current++] = "" + getNumFolds();

    System.arraycopy(superOptions, 0, options, current,
                     superOptions.length);

    return options;
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * The minimum number of training instances is set to the current number
   * of cross-validation folds.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    result.setMinimumNumberInstances(m_NumFolds);

    return result;
  }

  /**
   * Generates the classifier.
   *
   * The training data is copied, instances with a missing class are removed
   * and the copy is shuffled using the configured seed. If no CV parameters
   * are registered, the base classifier is trained with its current options.
   * Otherwise the option values are chosen by findParamsByCrossValidation,
   * and the base classifier is trained on the whole training copy with the
   * best options found.
   *
   * @param instances set of instances serving as training data
   * @throws IllegalArgumentException if the base classifier is not an
   * OptionHandler
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    Instances trainData = new Instances(instances);
    trainData.deleteWithMissingClass();

    if (!(m_Classifier instanceof OptionHandler)) {
      throw new IllegalArgumentException("Base classifier should be OptionHandler.");
    }
    // Options before tuning; getOptions() reports these back to the user.
    m_InitOptions = ((OptionHandler)m_Classifier).getOptions();
    // NOTE(review): -99 looks like a "nothing evaluated yet" sentinel for the
    // search in findParamsByCrossValidation -- confirm there.
    m_BestPerformance = -99;
    m_NumAttributes = trainData.numAttributes();
    Random random = new Random(m_Seed);
    trainData.randomize(random);
    // Size of one training fold; toString() prints it for an upper bound
    // encoded as m_Lower - 2.
    m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances();

    // Check whether there are any parameters to optimize
    if (m_CVParams.size() == 0) {
      m_Classifier.buildClassifier(trainData);
      m_BestClassifierOptions = m_InitOptions;
      return;
    }

    if (trainData.classAttribute().isNominal()) {
      trainData.stratify(m_NumFolds);
    }
    m_BestClassifierOptions = null;

    // Set up m_ClassifierOptions -- take getOptions() and remove
    // those being optimised.
    m_ClassifierOptions = ((OptionHandler)m_Classifier).getOptions();
    for (int i = 0; i < m_CVParams.size(); i++) {
      Utils.getOption(((CVParameter)m_CVParams.elementAt(i)).m_ParamChar,
                      m_ClassifierOptions);
    }
    findParamsByCrossValidation(0, trainData, random);

    String [] options = (String [])m_BestClassifierOptions.clone();
    ((OptionHandler)m_Classifier).setOptions(options);
    m_Classifier.buildClassifier(trainData);
  }

  /**
   * Predicts the class distribution for the given test instance by
   * delegating to the base classifier (built with the selected options).
   *
   * @param instance the instance to be classified
   * @return the predicted class distribution
   * @throws Exception if an error occurred during the prediction
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    return m_Classifier.distributionForInstance(instance);
  }

  /**
   * Adds a scheme parameter to the list of parameters to be set
   * by cross-validation.
   *
   * @param cvParam the string representation of a scheme parameter. The
   * format is: <br>
   * param_char lower_bound upper_bound number_of_steps <br>
   * eg to search a parameter -P from 1 to 10 by increments of 1: <br>
   * P 1 10 11 <br>
   * @throws Exception if the parameter specifier is of the wrong format
   */
  public void addCVParameter(String cvParam) throws Exception {

    CVParameter newCV = new CVParameter(cvParam);

    m_CVParams.addElement(newCV);
  }

  /**
   * Gets the scheme parameter with the given index.
   *
   * @param index the index for the parameter
   * @return the string form of the parameter, or an empty string if
   * index is not less than the number of registered parameters
   */
  public String getCVParameter(int index) {

    if (m_CVParams.size() <= index) {
      return "";
    }
    return ((CVParameter)m_CVParams.elementAt(index)).toString();
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String CVParametersTipText() {
    return "Sets the scheme parameters which are to be set "+
           "by cross-validation.\n"+
           "The format for each string should be:\n"+
           "param_char lower_bound upper_bound number_of_steps\n"+
           "eg to search a parameter -P from 1 to 10 by increments of 1:\n"+
           "    \"P 1 10 11\" ";
  }

  /**
   * Get method for CVParameters.
   *
   * @return the string form of every registered CV parameter, in
   * registration order (the array elements are Strings)
   */
  public Object[] getCVParameters() {

    Object[] CVParams = m_CVParams.toArray();

    String params[] = new String[CVParams.length];

    for (int i = 0; i < CVParams.length; i++) {
      params[i] = CVParams[i].toString();
    }

    return params;
  }

  /**
   * Set method for CVParameters.
   *
   * Replaces all registered CV parameters. If any element cannot be added,
   * the previous list is restored and the exception is rethrown, so a failed
   * call leaves the parameters unchanged.
   *
   * @param params the CVParameters to use (each element is cast to String)
   * @throws Exception if the setting of the CVParameters fails
   */
  public void setCVParameters(Object[] params) throws Exception {

    FastVector backup = m_CVParams;
    m_CVParams = new FastVector();

    for (int i = 0; i < params.length; i++) {
      try {
        addCVParameter((String)params[i]);
      }
      catch (Exception ex) {
        m_CVParams = backup;
        throw ex;
      }
    }
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numFoldsTipText() {
    return "Get the number of folds used for cross-validation.";
  }

  /**
   * Gets the number of folds for the cross-validation.
   *
   * @return the number of folds for the cross-validation
   */
  public int getNumFolds() {

    return m_NumFolds;
  }

  /**
   * Sets the number of folds for the cross-validation.
   *
   * @param numFolds the number of folds for the cross-validation; must be
   * positive
   * @throws IllegalArgumentException if numFolds is zero or negative
   * @throws Exception if parameter illegal
   */
  public void setNumFolds(int numFolds) throws Exception {

    // Reject 0 as well, so the check agrees with the "must be positive" message.
    if (numFolds <= 0) {
      throw new IllegalArgumentException("CVParameterSelection: Number of cross-validation " +
                                         "folds must be positive.");
    }
    m_NumFolds = numFolds;
  }

  /**
   * Returns the type of graph this classifier
   * represents.
   *
   * @return the base classifier's graph type if it is Drawable,
   * otherwise Drawable.NOT_DRAWABLE
   */
  public int graphType() {

    if (m_Classifier instanceof Drawable) {
      return ((Drawable)m_Classifier).graphType();
    } else {
      return Drawable.NOT_DRAWABLE;
    }
  }

  /**
   * Returns graph describing the classifier (if possible).
   *
   * @return the graph of the classifier in dotty format
   * @throws Exception if the classifier cannot be graphed
   */
  public String graph() throws Exception {

    if (m_Classifier instanceof Drawable) {
      return ((Drawable)m_Classifier).graph();
    } else {
      throw new Exception("Classifier: " +
                          m_Classifier.getClass().getName() + " " +
                          Utils.joinOptions(m_BestClassifierOptions)
                          + " cannot be graphed");
    }
  }

  /**
   * Returns description of the cross-validated classifier.
   *
   * @return description of the cross-validated classifier as a string
   */
  public String toString() {

    if (m_InitOptions == null) {
      return "CVParameterSelection: No model built yet.";
    }

    String result = "Cross-validated Parameter selection.\n"
      + "Classifier: " + m_Classifier.getClass().getName() + "\n";
    try {
      for (int i = 0; i < m_CVParams.size(); i++) {
        CVParameter cvParam = (CVParameter)m_CVParams.elementAt(i);
        result += "Cross-validation Parameter: '-"
          + cvParam.m_ParamChar + "'"
          + " ranged from " + cvParam.m_Lower
          + " to ";
        // An upper bound of m_Lower - 1 or m_Lower - 2 appears to encode the
        // symbolic 'A' / 'I' bounds (number of attributes / training-fold
        // size); the encoding lives in CVParameter -- confirm there.
        switch ((int)(cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
        case 1:
          result += m_NumAttributes;
          break;
        case 2:
          result += m_TrainFoldSize;
          break;
        default:
          result += cvParam.m_Upper;
          break;
        }
        result += " with " + cvParam.m_Steps + " steps\n";
      }
    } catch (Exception ex) {
      result += ex.getMessage();
    }
    result += "Classifier Options: "
      + Utils.joinOptions(m_BestClassifierOptions)
      + "\n\n" + m_Classifier.toString();
    return result;
  }

  /**
   * A concise description of the model.
   *
   * @return the selected base-classifier options followed by a newline
   */
  public String toSummaryString() {

    String result = "Selected values: "
      + Utils.joinOptions(m_BestClassifierOptions);
    return result + '\n';
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {
    runClassifier(new CVParameterSelection(), argv);
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -