📄 seqcvparameterselection.java
字号:
} } return newVector.elements(); } /** * Parses a given list of options. Valid options are:<p> * * -D <br> * Turn on debugging output.<p> * * -W classname <br> * Specify the full class name of classifier to perform cross-validation * selection on.<p> * * -X num <br> * Number of folds used for cross validation (default 10). <p> * * -S seed <br> * Random number seed (default 1).<p> * * -P "N 1 5 10" <br> * Sets an optimisation parameter for the classifier with name -N, * lower bound 1, upper bound 5, and 10 optimisation steps. * The upper bound may be the character 'A' or 'I' to substitute * the number of attributes or instances in the training data, * respectively. * This parameter may be supplied more than once to optimise over * several classifier options simultaneously. <p> * * Options after -- are passed to the designated sub-classifier. <p> * * @param options the list of options as an array of strings * @exception Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { setDebug(Utils.getFlag('D', options)); String foldsString = Utils.getOption('X', options); if (foldsString.length() != 0) { setNumFolds(Integer.parseInt(foldsString)); } else { setNumFolds(10); } String randomString = Utils.getOption('S', options); if (randomString.length() != 0) { setSeed(Integer.parseInt(randomString)); } else { setSeed(1); } String cvParam; m_CVParams = new FastVector(); do { cvParam = Utils.getOption('P', options); if (cvParam.length() != 0) { addCVParameter(cvParam); } } while (cvParam.length() != 0); if (m_CVParams.size() == 0) { throw new Exception("A parameter specifier must be given with" + " the -P option."); } String classifierName = Utils.getOption('W', options); if (classifierName.length() == 0) { throw new Exception("A classifier must be specified with" + " the -W option."); } setClassifier((SequentialClassifier)Classifier.forName(classifierName, Utils.partitionOptions(options))); if (!(m_Classifier instanceof OptionHandler)) { throw new Exception("Base classifier must accept options"); } } /** * Gets the current settings of the Classifier. * * @return an array of strings suitable for passing to setOptions */ public String [] getOptions() { String [] classifierOptions = new String [0]; if ((m_Classifier != null) && (m_Classifier instanceof OptionHandler)) { classifierOptions = ((OptionHandler)m_Classifier).getOptions(); } int current = 0; String [] options = new String [classifierOptions.length + 8]; if (m_CVParams != null) { options = new String [m_CVParams.size() * 2 + options.length]; for (int i = 0; i < m_CVParams.size(); i++) { options[current++] = "-P"; options[current++] = "" + getCVParameter(i); } } if (getDebug()) { options[current++] = "-D"; } options[current++] = "-X"; options[current++] = "" + getNumFolds(); options[current++] = "-S"; options[current++] = "" + getSeed(); if (getClassifier() != null) { options[current++] = "-W"; options[current++] = getClassifier().getClass().getName(); } options[current++] = "--"; System.arraycopy(classifierOptions, 0, options, current, classifierOptions.length); current += classifierOptions.length; while (current < options.length) { options[current++] = ""; } return options; } /** * Generates the classifier. * * @param instances set of instances serving as training data * @exception Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances instances) throws Exception { if (instances.checkForStringAttributes()) { throw new UnsupportedAttributeTypeException("Cannot handle string attributes!"); } Instances trainData = new Instances(instances); trainData.deleteWithMissingClass(); if (trainData.numInstances() == 0) { throw new Exception("No training instances without missing class."); } if (trainData.numInstances() < m_NumFolds) { throw new Exception("Number of training instances smaller than number of folds."); } // Check whether there are any parameters to optimize if (m_CVParams == null) { m_Classifier.buildClassifier(trainData); return; } if(!(m_Classifier instanceof SequentialClassifier)) { trainData.randomize(new Random(m_Seed)); if (trainData.classAttribute().isNominal()) { trainData.stratify(m_NumFolds); } m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances(); } else { m_TrainFoldSize = trainData.seqTrainCV(m_NumFolds, 0).numInstances(); } m_BestPerformance = -99; m_BestClassifierOptions = null; m_NumAttributes = trainData.numAttributes(); // Set up m_ClassifierOptions -- take getOptions() and remove // those being optimised. m_ClassifierOptions = ((OptionHandler)m_Classifier).getOptions(); for (int i = 0; i < m_CVParams.size(); i++) { Utils.getOption(((CVParameter)m_CVParams.elementAt(i)).m_ParamChar, m_ClassifierOptions); } findParamsByCrossValidation(0, trainData); String [] options = (String [])m_BestClassifierOptions.clone(); ((OptionHandler)m_Classifier).setOptions(options); m_Classifier.buildClassifier(trainData); } /** * Predicts the class value for the given test instance. * * @param instance the instance to be classified * @return the predicted class value * @exception Exception if an error occurred during the prediction *//* public double classifyInstance(Instance instance) throws Exception { return m_Classifier.classifyInstance(instance); }*/ /** * Predicts the class value for the given test sequence. * * @param instances the sequence to be classified * @return the predicted class value * @exception Exception if an error occurred during the prediction */ public double [] classifySequence(Instances instances) throws Exception { return m_Classifier.classifySequence(instances); } /** * Sets the seed for random number generation. * * @param seed the random number seed */ public void setSeed(int seed) { m_Seed = seed; } /** * Gets the random number seed. * * @return the random number seed */ public int getSeed() { return m_Seed; } /** * Adds a scheme parameter to the list of parameters to be set * by cross-validation * * @param cvParam the string representation of a scheme parameter. The * format is: <br> * param_char lower_bound upper_bound increment <br> * eg to search a parameter -P from 1 to 10 by increments of 2: <br> * P 1 10 2 <br> * @exception Exception if the parameter specifier is of the wrong format */ public void addCVParameter(String cvParam) throws Exception { CVParameter newCV = new CVParameter(cvParam); m_CVParams.addElement(newCV); } /** * Gets the scheme paramter with the given index. */ public String getCVParameter(int index) { if (m_CVParams.size() <= index) { return ""; } return ((CVParameter)m_CVParams.elementAt(index)).toString(); } /** * Sets debugging mode * * @param debug true if debug output should be printed */ public void setDebug(boolean debug) { m_Debug = debug; } /** * Gets whether debugging is turned on * * @return true if debugging output is on */ public boolean getDebug() { return m_Debug; } /** * Get the number of folds used for cross-validation. * * @return the number of folds used for cross-validation. */ public int getNumFolds() { return m_NumFolds; } /** * Set the number of folds used for cross-validation. * * @param newNumFolds the number of folds used for cross-validation. */ public void setNumFolds(int newNumFolds) { m_NumFolds = newNumFolds; } /** * Set the classifier for boosting. * * @param newClassifier the Classifier to use. */ public void setClassifier(SequentialClassifier newClassifier) { m_Classifier = newClassifier; } /** * Get the classifier used as the classifier * * @return the classifier used as the classifier */ public Classifier getClassifier() { return m_Classifier; } /** * Returns description of the cross-validated classifier. * * @return description of the cross-validated classifier as a string */ public String toString() { if (m_BestClassifierOptions == null) return "CVParameterSelection: No model built yet."; String result = "Cross-validated Parameter selection.\n" + "Classifier: " + m_Classifier.getClass().getName() + "\n"; try { for (int i = 0; i < m_CVParams.size(); i++) { CVParameter cvParam = (CVParameter)m_CVParams.elementAt(i); result += "Cross-validation Parameter: '-" + cvParam.m_ParamChar + "'" + " ranged from " + cvParam.m_Lower + " to "; switch ((int)(cvParam.m_Lower - cvParam.m_Upper + 0.5)) { case 1: result += m_NumAttributes; break; case 2: result += m_TrainFoldSize; break; default: result += cvParam.m_Upper; break; } result += " with " + cvParam.m_Steps + " steps\n"; } } catch (Exception ex) { result += ex.getMessage(); } result += "Classifier Options: " + Utils.joinOptions(m_BestClassifierOptions) + "\n\n" + m_Classifier.toString(); return result; } public String toSummaryString() { String result = "Selected values: " + Utils.joinOptions(m_BestClassifierOptions); return result + '\n'; } /** * Main method for testing this class. * * @param argv the options */ public static void main(String [] argv) { try { System.out.println(SequentialEvaluation.evaluateModel(new SeqCVParameterSelection(), argv)); } catch (Exception e) { System.err.println(e.getMessage()); e.printStackTrace(); } } }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -