⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 seqcvparameterselection.java

📁 把 sequential 有导师学习问题转化为传统的有导师学习问题
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
      }    }    return newVector.elements();  }  /**   * Parses a given list of options. Valid options are:<p>   *   * -D <br>   * Turn on debugging output.<p>   *   * -W classname <br>   * Specify the full class name of classifier to perform cross-validation   * selection on.<p>   *   * -X num <br>   * Number of folds used for cross validation (default 10). <p>   *   * -S seed <br>   * Random number seed (default 1).<p>   *   * -P "N 1 5 10" <br>   * Sets an optimisation parameter for the classifier with name -N,   * lower bound 1, upper bound 5, and 10 optimisation steps.   * The upper bound may be the character 'A' or 'I' to substitute    * the number of attributes or instances in the training data,   * respectively.   * This parameter may be supplied more than once to optimise over   * several classifier options simultaneously. <p>   *   * Options after -- are passed to the designated sub-classifier. <p>   *   * @param options the list of options as an array of strings   * @exception Exception if an option is not supported   */  public void setOptions(String[] options) throws Exception {        setDebug(Utils.getFlag('D', options));    String foldsString = Utils.getOption('X', options);    if (foldsString.length() != 0) {      setNumFolds(Integer.parseInt(foldsString));    } else {      setNumFolds(10);    }    String randomString = Utils.getOption('S', options);    if (randomString.length() != 0) {      setSeed(Integer.parseInt(randomString));    } else {      setSeed(1);    }    String cvParam;    m_CVParams = new FastVector();    do {      cvParam = Utils.getOption('P', options);      if (cvParam.length() != 0) {	addCVParameter(cvParam);      }    } while (cvParam.length() != 0);    if (m_CVParams.size() == 0) {      throw new Exception("A parameter specifier must be given with"			  + " the -P option.");    }    String classifierName = Utils.getOption('W', options);    if (classifierName.length() == 0) {      throw new Exception("A classifier must be 
specified with"			  + " the -W option.");    }    setClassifier((SequentialClassifier)Classifier.forName(classifierName,				     Utils.partitionOptions(options)));    if (!(m_Classifier instanceof OptionHandler)) {      throw new Exception("Base classifier must accept options");    }  }  /**   * Gets the current settings of the Classifier.   *   * @return an array of strings suitable for passing to setOptions   */  public String [] getOptions() {    String [] classifierOptions = new String [0];    if ((m_Classifier != null) && 	(m_Classifier instanceof OptionHandler)) {      classifierOptions = ((OptionHandler)m_Classifier).getOptions();    }    int current = 0;    String [] options = new String [classifierOptions.length + 8];    if (m_CVParams != null) {      options = new String [m_CVParams.size() * 2 + options.length];      for (int i = 0; i < m_CVParams.size(); i++) {	options[current++] = "-P"; options[current++] = "" + getCVParameter(i);      }    }    if (getDebug()) {      options[current++] = "-D";    }    options[current++] = "-X"; options[current++] = "" + getNumFolds();    options[current++] = "-S"; options[current++] = "" + getSeed();    if (getClassifier() != null) {      options[current++] = "-W";      options[current++] = getClassifier().getClass().getName();    }    options[current++] = "--";    System.arraycopy(classifierOptions, 0, options, current, 		     classifierOptions.length);    current += classifierOptions.length;    while (current < options.length) {      options[current++] = "";    }    return options;  }  /**   * Generates the classifier.   
*   * @param instances set of instances serving as training data    * @exception Exception if the classifier has not been generated successfully   */  public void buildClassifier(Instances instances) throws Exception {    if (instances.checkForStringAttributes()) {      throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");    }    Instances trainData = new Instances(instances);    trainData.deleteWithMissingClass();    if (trainData.numInstances() == 0) {      throw new Exception("No training instances without missing class.");    }    if (trainData.numInstances() < m_NumFolds) {      throw new Exception("Number of training instances smaller than number of folds.");    }    // Check whether there are any parameters to optimize    if (m_CVParams == null) {       m_Classifier.buildClassifier(trainData);       return;    }    if(!(m_Classifier instanceof SequentialClassifier)) {	      trainData.randomize(new Random(m_Seed));      if (trainData.classAttribute().isNominal()) {        trainData.stratify(m_NumFolds);      }      m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances();    } else {      m_TrainFoldSize = trainData.seqTrainCV(m_NumFolds, 0).numInstances();     }      m_BestPerformance = -99;    m_BestClassifierOptions = null;    m_NumAttributes = trainData.numAttributes();            // Set up m_ClassifierOptions -- take getOptions() and remove    // those being optimised.    m_ClassifierOptions = ((OptionHandler)m_Classifier).getOptions();    for (int i = 0; i < m_CVParams.size(); i++) {      Utils.getOption(((CVParameter)m_CVParams.elementAt(i)).m_ParamChar,		      m_ClassifierOptions);    }    findParamsByCrossValidation(0, trainData);    String [] options = (String [])m_BestClassifierOptions.clone();    ((OptionHandler)m_Classifier).setOptions(options);    m_Classifier.buildClassifier(trainData);  }  /**   * Predicts the class value for the given test instance.   
*   * @param instance the instance to be classified   * @return the predicted class value   * @exception Exception if an error occurred during the prediction   *//*  public double classifyInstance(Instance instance) throws Exception {     return m_Classifier.classifyInstance(instance);  }*/  /**   * Predicts the class value for the given test sequence.   *   * @param instances the sequence to be classified   * @return the predicted class value   * @exception Exception if an error occurred during the prediction   */  public double [] classifySequence(Instances instances) throws Exception {     return m_Classifier.classifySequence(instances);  }  /**   * Sets the seed for random number generation.   *   * @param seed the random number seed   */  public void setSeed(int seed) {        m_Seed = seed;  }  /**   * Gets the random number seed.   *    * @return the random number seed   */  public int getSeed() {    return m_Seed;  }  /**   * Adds a scheme parameter to the list of parameters to be set   * by cross-validation   *   * @param cvParam the string representation of a scheme parameter. The   * format is: <br>   * param_char lower_bound upper_bound increment <br>   * eg to search a parameter -P from 1 to 10 by increments of 2: <br>   * P 1 10 2 <br>   * @exception Exception if the parameter specifier is of the wrong format   */  public void addCVParameter(String cvParam) throws Exception {        CVParameter newCV = new CVParameter(cvParam);        m_CVParams.addElement(newCV);  }  /**   * Gets the scheme paramter with the given index.   
*/
  public String getCVParameter(int index) {
    // Out-of-range indices yield an empty string rather than an exception.
    if (index >= m_CVParams.size()) {
      return "";
    }
    return ((CVParameter) m_CVParams.elementAt(index)).toString();
  }

  /**
   * Sets debugging mode.
   *
   * @param debug true if debug output should be printed
   */
  public void setDebug(boolean debug) {
    m_Debug = debug;
  }

  /**
   * Gets whether debugging is turned on.
   *
   * @return true if debugging output is on
   */
  public boolean getDebug() {
    return m_Debug;
  }

  /**
   * Gets the number of folds used for cross-validation.
   *
   * @return the number of folds used for cross-validation
   */
  public int getNumFolds() {
    return m_NumFolds;
  }

  /**
   * Sets the number of folds used for cross-validation.
   *
   * @param newNumFolds the number of folds used for cross-validation
   */
  public void setNumFolds(int newNumFolds) {
    m_NumFolds = newNumFolds;
  }

  /**
   * Sets the base sequential classifier whose parameters will be selected
   * by cross-validation.
   *
   * @param newClassifier the Classifier to use.
   */
  public void setClassifier(SequentialClassifier newClassifier) {
    m_Classifier = newClassifier;
  }

  /**
   * Gets the classifier used as the base classifier.
   *
   * @return the classifier used as the base classifier
   */
  public Classifier getClassifier() {
    return m_Classifier;
  }

  /**
   * Returns description of the cross-validated classifier.
* @return description of the cross-validated classifier as a string
*/
  public String toString() {

    if (m_BestClassifierOptions == null)
      return "CVParameterSelection: No model built yet.";

    String result = "Cross-validated Parameter selection.\n"
      + "Classifier: " + m_Classifier.getClass().getName() + "\n";
    try {
      for (int i = 0; i < m_CVParams.size(); i++) {
        CVParameter cvParam = (CVParameter) m_CVParams.elementAt(i);
        result += "Cross-validation Parameter: '-"
          + cvParam.m_ParamChar + "'"
          + " ranged from " + cvParam.m_Lower
          + " to ";
        // An upper bound given as 'A' or 'I' (see setOptions) is
        // presumably encoded as m_Lower - 1 / m_Lower - 2 respectively,
        // so this difference selects the substituted value -- TODO
        // confirm against the CVParameter class (not in this chunk).
        switch ((int) (cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
        case 1:
          result += m_NumAttributes;
          break;
        case 2:
          result += m_TrainFoldSize;
          break;
        default:
          result += cvParam.m_Upper;
          break;
        }
        result += " with " + cvParam.m_Steps + " steps\n";
      }
    } catch (Exception ex) {
      result += ex.getMessage();
    }
    result += "Classifier Options: "
      + Utils.joinOptions(m_BestClassifierOptions)
      + "\n\n" + m_Classifier.toString();
    return result;
  }

  /**
   * Returns a one-line summary of the selected parameter values.
   *
   * @return the selected options as a string, or a "no model" message if
   * buildClassifier() has not been run yet
   */
  public String toSummaryString() {
    // Guard against calls before buildClassifier(), mirroring toString():
    // Utils.joinOptions(null) would otherwise throw a NullPointerException.
    if (m_BestClassifierOptions == null) {
      return "CVParameterSelection: No model built yet.\n";
    }
    String result = "Selected values: "
      + Utils.joinOptions(m_BestClassifierOptions);
    return result + '\n';
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String[] argv) {
    try {
      System.out.println(SequentialEvaluation.evaluateModel(
          new SeqCVParameterSelection(), argv));
    } catch (Exception e) {
      System.err.println(e.getMessage());
      e.printStackTrace();
    }
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -