📄 evaluation.java
字号:
{ EvaluationClient ec; // Some preliminary error checking if(numFolds < 2) { throw new IllegalArgumentException( "Number of folds must be at least 2!"); } if (numFolds > data.numInstances()) { throw new IllegalArgumentException( "Can't have more folds than instances!"); } ec = new EvaluationClient(numFolds, data, classifier, otherComputers, this); ec.start(); m_NumFolds = numFolds; } /** * Performs a (stratified if class is nominal) cross-validation * for a classifier on a set of instances. * * @param classifier a string naming the class of the classifier * @param data the data on which the cross-validation is to be * performed * @param numFolds the number of folds for the cross-validation * @param options the options to the classifier. Any options * accepted by the classifier will be removed from this array. * @exception Exception if a classifier could not be generated * successfully or the class is not defined */ public void crossValidateModel(String classifierString, Instances data, int numFolds, String[] options) throws Exception { boolean runInParallel = Utils.getFlag('a', options); if(runInParallel) crossValidateModelParallel(Classifier.forName(classifierString, options), data, numFolds, m_OtherComputers); else crossValidateModel(Classifier.forName(classifierString, options), data, numFolds); } /** * Evaluates a classifier with the options given in an array of * strings. <p> * * Valid options are: <p> * * -t filename <br> * Name of the file with the training data. (required) <p> * * -T filename <br> * Name of the file with the test data. If missing a cross-validation * is performed. <p> * * -c index <br> * Index of the class attribute (1, 2, ...; default: last). <p> * * -x number <br> * The number of folds for the cross-validation (default: 10). <p> * * -s seed <br> * Random number seed for the cross-validation (default: 1). <p> * * -m filename <br> * The name of a file containing a cost matrix. <p> * * -l filename <br> * Loads classifier from the given file. <p> * * -d filename <br> * Saves classifier built from the training data into the given file. <p> * * -v <br> * Outputs no statistics for the training data. <p> * * -o <br> * Outputs statistics only, not the classifier. <p> * * -i <br> * Outputs detailed information-retrieval statistics per class. <p> * * -k <br> * Outputs information-theoretic statistics. <p> * * -p range <br> * Outputs predictions for test instances, along with the attributes in * the specified range (and nothing else). Use '-p 0' if no attributes are * desired. <p> * * -r <br> * Outputs cumulative margin distribution (and nothing else). <p> * * -g <br> * Only for classifiers that implement "Graphable." Outputs * the graph representation of the classifier (and nothing * else). <p> * * @param classifierString class of machine learning classifier as a string * @param options the array of string containing the options * @exception Exception if model could not be evaluated successfully * @return a string describing the results */ public static String evaluateModel(String classifierString, String [] options) throws Exception { Classifier classifier; // Create classifier try { classifier = (Classifier)Class.forName(classifierString).newInstance(); } catch (Exception e) { throw new Exception("Can't find class with name " + classifierString + '.'); } return evaluateModel(classifier, options); } /** * A test method for this class. Just extracts the first command line * argument as a classifier class name and calls evaluateModel. * @param args an array of command line arguments, the first of which * must be the class name of a classifier. */ public static void main(String [] args) { try { if (args.length == 0) { throw new Exception("The first argument must be the class name" + " of a classifier"); } String classifier = args[0]; args[0] = ""; System.out.println(evaluateModel(classifier, args)); } catch (Exception ex) { ex.printStackTrace(); System.err.println(ex.getMessage()); } } /** * Evaluates a classifier with the options given in an array of * strings. <p> * * Valid options are: <p> * * -t name of training file <br> * Name of the file with the training data. (required) <p> * * -T name of test file <br> * Name of the file with the test data. If missing a cross-validation * is performed. <p> * * -c class index <br> * Index of the class attribute (1, 2, ...; default: last). <p> * * -x number of folds <br> * The number of folds for the cross-validation (default: 10). <p> * * -s random number seed <br> * Random number seed for the cross-validation (default: 1). <p> * * -m file with cost matrix <br> * The name of a file containing a cost matrix. <p> * * -l name of model input file <br> * Loads classifier from the given file. <p> * * -d name of model output file <br> * Saves classifier built from the training data into the given file. <p> * * -v <br> * Outputs no statistics for the training data. <p> * * -o <br> * Outputs statistics only, not the classifier. <p> * * -i <br> * Outputs detailed information-retrieval statistics per class. <p> * * -k <br> * Outputs information-theoretic statistics. <p> * * -p <br> * Outputs predictions for test instances (and nothing else). <p> * * -r <br> * Outputs cumulative margin distribution (and nothing else). <p> * * -g <br> * Only for classifiers that implement "Graphable." Outputs * the graph representation of the classifier (and nothing * else). <p> * * -a <br> * Runs the cross-validation in parallel. <p> * * @param classifier machine learning classifier * @param options the array of string containing the options * @exception Exception if model could not be evaluated successfully * @return a string describing the results */ public static String evaluateModel(Classifier classifier, String [] options) throws Exception { Instances train = null, tempTrain, test = null, template = null; int seed = 1, folds = 10, classIndex = -1; String trainFileName, testFileName, sourceClass, classIndexString, seedString, foldsString, objectInputFileName, objectOutputFileName, attributeRangeString; boolean IRstatistics = false, noOutput = false, printClassifications = false, trainStatistics = true, printMargins = false, printComplexityStatistics = false, printGraph = false, classStatistics = false, printSource = false; StringBuffer text = new StringBuffer(); BufferedReader trainReader = null, testReader = null; ObjectInputStream objectInputStream = null; Random random; CostMatrix costMatrix = null; StringBuffer schemeOptionsText = null; Range attributesToOutput = null; long trainTimeStart = 0, trainTimeElapsed = 0, testTimeStart = 0, testTimeElapsed = 0; // Added by Sebastian Celis for parallelization boolean runInParallel; String parallelConfigLocation; File parallelConfigFile; StringBuffer otherComputers = new StringBuffer(); try { // Added by Sebastian Celis for parallelization runInParallel = Utils.getFlag('a', options); if(runInParallel) { parallelConfigLocation = System.getProperty("user.home"); // If the user is running windows... if(System.getProperty("os.name").charAt(0) == 'W') { parallelConfigLocation = parallelConfigLocation.concat("\\.weka-parallel"); } // If the user is running anything else else { parallelConfigLocation = parallelConfigLocation.concat("/.weka-parallel"); } parallelConfigFile = new File(parallelConfigLocation); if(!parallelConfigFile.exists()) { throw new Exception("Config file for parallelization does "+ "not exist."); } } // Get basic options (options the same for all schemes) classIndexString = Utils.getOption('c', options); if (classIndexString.length() != 0) { classIndex = Integer.parseInt(classIndexString); } trainFileName = Utils.getOption('t', options); objectInputFileName = Utils.getOption('l', options); objectOutputFileName = Utils.getOption('d', options); testFileName = Utils.getOption('T', options); if (trainFileName.length() == 0) { if (objectInputFileName.length() == 0) { throw new Exception("No training file and no object "+ "input file given."); } if (testFileName.length() == 0) { throw new Exception("No training file and no test "+ "file given."); } } else if ((objectInputFileName.length() != 0) && ((!(classifier instanceof UpdateableClassifier)) || (testFileName.length() == 0))) { throw new Exception("Classifier not incremental, or no " + "test file provided: can't "+ "use both train and model file."); } try { if (trainFileName.length() != 0) { trainReader = new BufferedReader(new FileReader(trainFileName)); } if (testFileName.length() != 0) { testReader = new BufferedReader(new FileReader(testFileName)); } if (objectInputFileName.length() != 0) { InputStream is = new FileInputStream(objectInputFileName); if (objectInputFileName.endsWith(".gz")) { is = new GZIPInputStream(is); } objectInputStream = new ObjectInputStream(is); } } catch (Exception e) { throw new Exception("Can't open file " + e.getMessage() + '.'); } if (testFileName.length() != 0) { template = test = new Instances(testReader, 1); if (classIndex != -1) { test.setClassIndex(classIndex - 1); } else { test.setClassIndex(test.numAttributes() - 1); } if (classIndex > test.numAttributes()) { throw new Exception("Index of class attribute too large."); } } if (trainFileName.length() != 0) { if ((classifier instanceof UpdateableClassifier) && (testFileName.length() != 0)) { train = new Instances(trainReader, 1); } else { train = new Instances(trainReader); } template = train; if (classIndex != -1) { train.setClassIndex(classIndex - 1); } else { train.setClassIndex(train.numAttributes() - 1); } if (classIndex > train.numAttributes()) { throw new Exception("Index of class attribute too large."); } //train = new Instances(train); } if (template == null) { throw new Exception("No actual dataset provided to use as template"); } seedString = Utils.getOption('s', options); if (seedString.length() != 0) { seed = Integer.parseInt(seedString); } foldsString = Utils.getOption('x', options); if (foldsString.length() != 0) { folds = Integer.parseInt(foldsString); } costMatrix = handleCostOption(Utils.getOption('m', options), template.numClasses()); classStatistics = Utils.getFlag('i', options); noOutput = Utils.getFlag('o', options); trainStatistics = !Utils.getFlag('v', options); printComplexityStatistics = Utils.getFlag('k', options); printMargins = Utils.getFlag('r', options); printGraph = Utils.getFlag('g', options); sourceClass = Utils.getOption('z', options); printSource = (sourceClass.length() != 0); // Check -p option try { attributeRangeString = Utils.getOption('p', options); } catch (Exception e) { throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. " + "It now expects a parameter specifying a range of attributes " + "to list with the predictions. Use '-p 0' for none."); } if (attributeRangeString.length() != 0) { printClassifications = true; if (!attributeRangeString.equals("0")) attributesToOutput = new Range(attributeRangeString); } // If a model file is given, we can't process // scheme-specific options if (objectInputFileName.length() != 0) { Utils.checkForRemainingOptions(options); } else { // Set options for classifier if (classifier instanceof OptionHandler) { for (int i = 0; i < options.length; i++) { if (options[i].length() != 0) { if (schemeOptionsText == null) { schemeOptionsText = new StringBuffer(); } if (options[i].indexOf(' ') != -1) { schemeOptionsText.append('"' + options[i] + "\" "); } else { schemeOptionsText.append(options[i] + " "); } } } ((OptionHandler)classifier).setOptions(options); } } Utils.checkForRemainingOptions(options); } catch (Exception e) { throw new Exception("\nWeka exception: " + e.getMessage() + makeOptionString(classifier)); } // Setup up evaluation objects Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix); Evaluation testingEvaluation = new Evaluation(new Instances(template, 0), costMatrix); if (objectInputFileName.length() != 0) { // Load classifier from file classifier = (Classifier) objectInputStream.readObject(); objectInputStream.close();
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -