⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 evaluation.java

📁 :<<数据挖掘--实用机器学习技术及java实现>>一书的配套源程序
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
   {       EvaluationClient ec;       // Some preliminary error checking       if(numFolds < 2)       {           throw new IllegalArgumentException(                     "Number of folds must be at least 2!");       }       if (numFolds > data.numInstances())       {           throw new IllegalArgumentException(                     "Can't have more folds than instances!");       }       ec = new EvaluationClient(numFolds, data, classifier,                                 otherComputers, this);       ec.start();       m_NumFolds = numFolds;   }  /**   * Performs a (stratified if class is nominal) cross-validation   * for a classifier on a set of instances.   *   * @param classifier a string naming the class of the classifier   * @param data the data on which the cross-validation is to be   * performed   * @param numFolds the number of folds for the cross-validation   * @param options the options to the classifier. Any options   * accepted by the classifier will be removed from this array.   * @exception Exception if a classifier could not be generated   * successfully or the class is not defined   */  public void crossValidateModel(String classifierString,                                 Instances data, int numFolds,                                 String[] options)    throws Exception  {    boolean runInParallel = Utils.getFlag('a', options);    if(runInParallel)      crossValidateModelParallel(Classifier.forName(classifierString, options),                                 data, numFolds, m_OtherComputers);    else      crossValidateModel(Classifier.forName(classifierString, options),                         data, numFolds);  }  /**   * Evaluates a classifier with the options given in an array of   * strings. <p>   *   * Valid options are: <p>   *   * -t filename <br>   * Name of the file with the training data. (required) <p>   *   * -T filename <br>   * Name of the file with the test data. If missing a cross-validation   * is performed. <p>   *   * -c index <br>   * Index of the class attribute (1, 2, ...; default: last). <p>   *   * -x number <br>   * The number of folds for the cross-validation (default: 10). <p>   *   * -s seed <br>   * Random number seed for the cross-validation (default: 1). <p>   *   * -m filename <br>   * The name of a file containing a cost matrix. <p>   *   * -l filename <br>   * Loads classifier from the given file. <p>   *   * -d filename <br>   * Saves classifier built from the training data into the given file. <p>   *   * -v <br>   * Outputs no statistics for the training data. <p>   *   * -o <br>   * Outputs statistics only, not the classifier. <p>   *   * -i <br>   * Outputs detailed information-retrieval statistics per class. <p>   *   * -k <br>   * Outputs information-theoretic statistics. <p>   *   * -p range <br>   * Outputs predictions for test instances, along with the attributes in   * the specified range (and nothing else). Use '-p 0' if no attributes are   * desired. <p>   *   * -r <br>   * Outputs cumulative margin distribution (and nothing else). <p>   *   * -g <br>   * Only for classifiers that implement "Graphable." Outputs   * the graph representation of the classifier (and nothing   * else). <p>   *   * @param classifierString class of machine learning classifier as a string   * @param options the array of string containing the options   * @exception Exception if model could not be evaluated successfully   * @return a string describing the results   */  public static String evaluateModel(String classifierString,                                     String [] options) throws Exception {    Classifier classifier;    // Create classifier    try {      classifier =      (Classifier)Class.forName(classifierString).newInstance();    } catch (Exception e) {      throw new Exception("Can't find class with name "                          + classifierString + '.');    }    return evaluateModel(classifier, options);  }  /**   * A test method for this class. Just extracts the first command line   * argument as a classifier class name and calls evaluateModel.   * @param args an array of command line arguments, the first of which   * must be the class name of a classifier.   */  public static void main(String [] args) {    try {      if (args.length == 0) {        throw new Exception("The first argument must be the class name"                            + " of a classifier");      }      String classifier = args[0];      args[0] = "";      System.out.println(evaluateModel(classifier, args));    } catch (Exception ex) {      ex.printStackTrace();      System.err.println(ex.getMessage());    }  }  /**   * Evaluates a classifier with the options given in an array of   * strings. <p>   *   * Valid options are: <p>   *   * -t name of training file <br>   * Name of the file with the training data. (required) <p>   *   * -T name of test file <br>   * Name of the file with the test data. If missing a cross-validation   * is performed. <p>   *   * -c class index <br>   * Index of the class attribute (1, 2, ...; default: last). <p>   *   * -x number of folds <br>   * The number of folds for the cross-validation (default: 10). <p>   *   * -s random number seed <br>   * Random number seed for the cross-validation (default: 1). <p>   *   * -m file with cost matrix <br>   * The name of a file containing a cost matrix. <p>   *   * -l name of model input file <br>   * Loads classifier from the given file. <p>   *   * -d name of model output file <br>   * Saves classifier built from the training data into the given file. <p>   *   * -v <br>   * Outputs no statistics for the training data. <p>   *   * -o <br>   * Outputs statistics only, not the classifier. <p>   *   * -i <br>   * Outputs detailed information-retrieval statistics per class. <p>   *   * -k <br>   * Outputs information-theoretic statistics. <p>   *   * -p <br>   * Outputs predictions for test instances (and nothing else). <p>   *   * -r <br>   * Outputs cumulative margin distribution (and nothing else). <p>   *   * -g <br>   * Only for classifiers that implement "Graphable." Outputs   * the graph representation of the classifier (and nothing   * else). <p>   *   * -a <br>   * Runs the cross-validation in parallel. <p>   *   * @param classifier machine learning classifier   * @param options the array of string containing the options   * @exception Exception if model could not be evaluated successfully   * @return a string describing the results */  public static String evaluateModel(Classifier classifier,                                     String [] options) throws Exception {    Instances train = null, tempTrain, test = null, template = null;    int seed = 1, folds = 10, classIndex = -1;    String trainFileName, testFileName, sourceClass,      classIndexString, seedString, foldsString, objectInputFileName,      objectOutputFileName, attributeRangeString;    boolean IRstatistics = false, noOutput = false,      printClassifications = false, trainStatistics = true,      printMargins = false, printComplexityStatistics = false,      printGraph = false, classStatistics = false, printSource = false;    StringBuffer text = new StringBuffer();    BufferedReader trainReader = null, testReader = null;    ObjectInputStream objectInputStream = null;    Random random;    CostMatrix costMatrix = null;    StringBuffer schemeOptionsText = null;    Range attributesToOutput = null;    long trainTimeStart = 0, trainTimeElapsed = 0,      testTimeStart = 0, testTimeElapsed = 0;    // Added by Sebastian Celis for parallelization    boolean runInParallel;    String parallelConfigLocation;    File parallelConfigFile;    StringBuffer otherComputers = new StringBuffer();    try    {      // Added by Sebastian Celis for parallelization      runInParallel = Utils.getFlag('a', options);      if(runInParallel)      {          parallelConfigLocation = System.getProperty("user.home");          // If the user is running windows...          if(System.getProperty("os.name").charAt(0) == 'W')          {              parallelConfigLocation                  = parallelConfigLocation.concat("\\.weka-parallel");          }          // If the user is running anything else          else          {              parallelConfigLocation                  = parallelConfigLocation.concat("/.weka-parallel");          }          parallelConfigFile = new File(parallelConfigLocation);          if(!parallelConfigFile.exists())          {              throw new Exception("Config file for parallelization does "+                                  "not exist.");          }      }      // Get basic options (options the same for all schemes)      classIndexString = Utils.getOption('c', options);      if (classIndexString.length() != 0) {        classIndex = Integer.parseInt(classIndexString);      }      trainFileName = Utils.getOption('t', options);      objectInputFileName = Utils.getOption('l', options);      objectOutputFileName = Utils.getOption('d', options);      testFileName = Utils.getOption('T', options);      if (trainFileName.length() == 0) {        if (objectInputFileName.length() == 0) {          throw new Exception("No training file and no object "+                              "input file given.");        }        if (testFileName.length() == 0) {          throw new Exception("No training file and no test "+                              "file given.");        }      } else if ((objectInputFileName.length() != 0) &&                 ((!(classifier instanceof UpdateableClassifier)) ||                 (testFileName.length() == 0))) {        throw new Exception("Classifier not incremental, or no " +                            "test file provided: can't "+                            "use both train and model file.");      }      try {        if (trainFileName.length() != 0) {          trainReader = new BufferedReader(new FileReader(trainFileName));        }        if (testFileName.length() != 0) {          testReader = new BufferedReader(new FileReader(testFileName));        }        if (objectInputFileName.length() != 0) {          InputStream is = new FileInputStream(objectInputFileName);          if (objectInputFileName.endsWith(".gz")) {            is = new GZIPInputStream(is);          }          objectInputStream = new ObjectInputStream(is);        }      } catch (Exception e) {        throw new Exception("Can't open file " + e.getMessage() + '.');      }      if (testFileName.length() != 0) {        template = test = new Instances(testReader, 1);        if (classIndex != -1) {          test.setClassIndex(classIndex - 1);        } else {          test.setClassIndex(test.numAttributes() - 1);        }        if (classIndex > test.numAttributes()) {          throw new Exception("Index of class attribute too large.");        }      }      if (trainFileName.length() != 0) {        if ((classifier instanceof UpdateableClassifier) &&            (testFileName.length() != 0)) {          train = new Instances(trainReader, 1);        } else {          train = new Instances(trainReader);        }        template = train;        if (classIndex != -1) {          train.setClassIndex(classIndex - 1);        } else {          train.setClassIndex(train.numAttributes() - 1);        }        if (classIndex > train.numAttributes()) {          throw new Exception("Index of class attribute too large.");        }        //train = new Instances(train);      }      if (template == null) {        throw new Exception("No actual dataset provided to use as template");      }      seedString = Utils.getOption('s', options);      if (seedString.length() != 0) {        seed = Integer.parseInt(seedString);      }      foldsString = Utils.getOption('x', options);      if (foldsString.length() != 0) {        folds = Integer.parseInt(foldsString);      }      costMatrix = handleCostOption(Utils.getOption('m', options), template.numClasses());      classStatistics = Utils.getFlag('i', options);      noOutput = Utils.getFlag('o', options);      trainStatistics = !Utils.getFlag('v', options);      printComplexityStatistics = Utils.getFlag('k', options);      printMargins = Utils.getFlag('r', options);      printGraph = Utils.getFlag('g', options);      sourceClass = Utils.getOption('z', options);      printSource = (sourceClass.length() != 0);      // Check -p option      try {        attributeRangeString = Utils.getOption('p', options);      }      catch (Exception e) {        throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. " +                            "It now expects a parameter specifying a range of attributes " +                            "to list with the predictions. Use '-p 0' for none.");      }      if (attributeRangeString.length() != 0) {        printClassifications = true;        if (!attributeRangeString.equals("0"))          attributesToOutput = new Range(attributeRangeString);      }      // If a model file is given, we can't process      // scheme-specific options      if (objectInputFileName.length() != 0) {        Utils.checkForRemainingOptions(options);      } else {        // Set options for classifier        if (classifier instanceof OptionHandler) {          for (int i = 0; i < options.length; i++) {            if (options[i].length() != 0) {              if (schemeOptionsText == null) {                schemeOptionsText = new StringBuffer();              }              if (options[i].indexOf(' ') != -1) {                schemeOptionsText.append('"' + options[i] + "\" ");              } else {                schemeOptionsText.append(options[i] + " ");              }            }          }          ((OptionHandler)classifier).setOptions(options);        }      }      Utils.checkForRemainingOptions(options);    } catch (Exception e) {      throw new Exception("\nWeka exception: " + e.getMessage()                           + makeOptionString(classifier));    }    // Setup up evaluation objects    Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);    Evaluation testingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);    if (objectInputFileName.length() != 0) {      // Load classifier from file      classifier = (Classifier) objectInputStream.readObject();      objectInputStream.close();

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -