⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 evaluation.java

📁 一个数据挖掘系统的源码
💻 JAVA
📖 第 1 页 / 共 5 页
字号:

    // Save the classifier if an object output file is provided
    if (objectOutputFileName.length() != 0) {
      OutputStream os = new FileOutputStream(objectOutputFileName);

      if (objectOutputFileName.endsWith(".gz")) {
        os = new GZIPOutputStream(os);
      }
      ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);

      objectOutputStream.writeObject(classifier);
      objectOutputStream.flush();
      objectOutputStream.close();

    }

    // If classifier is drawable output string describing graph
    if ((classifier instanceof Drawable)
	&& (printGraph)){
      return ((Drawable)classifier).graph();
    }

    // Output the classifier as equivalent source
    if ((classifier instanceof Sourcable)
	&& (printSource)){
      return wekaStaticWrapper((Sourcable) classifier, sourceClass);
    }

    // Output test instance predictions only
    if (printClassifications) {
      return printClassifications(classifier, new Instances(template, 0),
				  testFileName, classIndex, attributesToOutput);
    }

    // Output model
    if (!(noOutput || printMargins)) {
      if (classifier instanceof OptionHandler) {
	if (schemeOptionsText != null) {
	  text.append("\nOptions: "+schemeOptionsText);
	  text.append("\n");
	}
      }
      text.append("\n" + classifier.toString() + "\n");
    }

    if (!printMargins && (costMatrix != null)) {
      text.append("\n=== Evaluation Cost Matrix ===\n\n")
        .append(costMatrix.toString());
    }

    // Compute error estimate from training data
    if ((trainStatistics) &&
	(trainFileName.length() != 0)) {

      if ((classifier instanceof UpdateableClassifier) &&
	  (testFileName.length() != 0) &&
	  (costMatrix == null)) {

	// Classifier was trained incrementally, so we have to
	// reopen the training data in order to test on it.
	trainReader = new BufferedReader(new FileReader(trainFileName));

	// Incremental testing
	train = new Instances(trainReader, 1);
	if (classIndex != -1) {
	  train.setClassIndex(classIndex - 1);
	} else {
	  train.setClassIndex(train.numAttributes() - 1);
	}
	testTimeStart = System.currentTimeMillis();
	while (train.readInstance(trainReader)) {

	  trainingEvaluation.
	  evaluateModelOnce((Classifier)classifier,
			    train.instance(0));
	  train.delete(0);
	}
	testTimeElapsed = System.currentTimeMillis() - testTimeStart;
	trainReader.close();
      } else {
	testTimeStart = System.currentTimeMillis();
	trainingEvaluation.evaluateModel(classifier,
					 train);
	testTimeElapsed = System.currentTimeMillis() - testTimeStart;
      }

      // Print the results of the training evaluation
      if (printMargins) {
	return trainingEvaluation.toCumulativeMarginDistributionString();
      } else {
	text.append("\nTime taken to build model: " +
		    Utils.doubleToString(trainTimeElapsed / 1000.0,2) +
		    " seconds");
	text.append("\nTime taken to test model on training data: " +
		    Utils.doubleToString(testTimeElapsed / 1000.0,2) +
		    " seconds");
	text.append(trainingEvaluation.
		    toSummaryString("\n\n=== Error on training" +
				    " data ===\n", printComplexityStatistics));
	if (template.classAttribute().isNominal()) {
	  if (classStatistics) {
	    text.append("\n\n" + trainingEvaluation.toClassDetailsString());
	  }
	  text.append("\n\n" + trainingEvaluation.toMatrixString());
	}

      }
    }

    // Compute proper error estimates
    if (testFileName.length() != 0) {

      // Testing is on the supplied test data
      while (test.readInstance(testReader)) {

	testingEvaluation.evaluateModelOnce((Classifier)classifier,
                                            test.instance(0));
	test.delete(0);
      }
      testReader.close();

      text.append("\n\n" + testingEvaluation.
		  toSummaryString("=== Error on test data ===\n",
				  printComplexityStatistics));
    } else if (trainFileName.length() != 0) {

      // Testing is via cross-validation on training data
      random = new Random(seed);
      random.setSeed(seed);
      train.randomize(random);
      testingEvaluation.
	crossValidateModel(classifier, train, folds);
      if (template.classAttribute().isNumeric()) {
	text.append("\n\n\n" + testingEvaluation.
		    toSummaryString("=== Cross-validation ===\n",
				    printComplexityStatistics));
      } else {
	text.append("\n\n\n" + testingEvaluation.
		    toSummaryString("=== Stratified " +
				    "cross-validation ===\n",
				    printComplexityStatistics));
      }
    }
    if (template.classAttribute().isNominal()) {
      if (classStatistics) {
	text.append("\n\n" + testingEvaluation.toClassDetailsString());
      }
      text.append("\n\n" + testingEvaluation.toMatrixString());
    }
    return text.toString();
  }


  public static String evaluateModelPercentage(Classifier classifier,
				     String [] options, int percentage) throws Exception {

    Instances train = null, tempTrain, test = null, template = null, trainTemp = null;
    int seed = 1, folds = 10, classIndex = -1, percent = 0, sizeOfTrainFile = 0, sizeOfTestFile = 0;
    String trainFileName, sourceClass,
      classIndexString, seedString, foldsString, objectInputFileName,
      objectOutputFileName, attributeRangeString;
    boolean IRstatistics = false, noOutput = false,
      printClassifications = false, trainStatistics = true,
      printMargins = false, printComplexityStatistics = false,
      printGraph = false, classStatistics = false, printSource = false;
    StringBuffer text = new StringBuffer();
    BufferedReader trainReader = null;
    ObjectInputStream objectInputStream = null;
    Random random;
    CostMatrix costMatrix = null;
    StringBuffer schemeOptionsText = null;
    Range attributesToOutput = null;
    long trainTimeStart = 0, trainTimeElapsed = 0,
      testTimeStart = 0, testTimeElapsed = 0;

    try {
      percent = percentage;
      // Get basic options (options the same for all schemes)
      classIndexString = Utils.getOption('c', options);
      if (classIndexString.length() != 0) {
	classIndex = Integer.parseInt(classIndexString);
      }
      trainFileName = Utils.getOption('t', options);
      objectInputFileName = Utils.getOption('l', options);
      objectOutputFileName = Utils.getOption('d', options);
      if (trainFileName.length() == 0) {
	if (objectInputFileName.length() == 0) {
	  throw new Exception("No training file and no object "+
			      "input file given.");
	}
      } else if ((objectInputFileName.length() != 0) &&
		 ((!(classifier instanceof UpdateableClassifier)) )) {
	throw new Exception("Classifier not incremental, or no " +
			    "test file provided: can't "+
			    "use both train and model file.");
      }
      try {
	if (trainFileName.length() != 0) {
	  trainReader = new BufferedReader(new FileReader(trainFileName));
	}
	if (objectInputFileName.length() != 0) {
          InputStream is = new FileInputStream(objectInputFileName);
          if (objectInputFileName.endsWith(".gz")) {
            is = new GZIPInputStream(is);
          }
	  objectInputStream = new ObjectInputStream(is);
	}
      }
      catch (Exception e) {
	throw new Exception("Can't open file " + e.getMessage() + '.');
      }

      if (trainFileName.length() != 0) {
	if ((classifier instanceof UpdateableClassifier)) {
          trainTemp = new Instances (trainReader);
          sizeOfTrainFile = trainTemp.numInstances() * percent / 100;
          sizeOfTestFile = trainTemp.numInstances() - sizeOfTrainFile;
	  train = new Instances(trainTemp, 1, sizeOfTrainFile);
          test = new Instances (trainTemp, sizeOfTrainFile,sizeOfTestFile );
	}
        else {
          trainTemp = new Instances (trainReader);
          sizeOfTrainFile = trainTemp.numInstances() * percent / 100;
          sizeOfTestFile = trainTemp.numInstances() - sizeOfTrainFile;
	  train = new Instances(trainTemp, 0, sizeOfTrainFile);
          test = new Instances (trainTemp, sizeOfTrainFile,sizeOfTestFile );
	}
        template = train;
	if (classIndex != -1) {
	  train.setClassIndex(classIndex - 1);
	}
        else {
	  train.setClassIndex(train.numAttributes() - 1);
	}
	if (classIndex > train.numAttributes()) {
	  throw new Exception("Index of class attribute too large.");
	}
	if (classIndex != -1) {
	  test.setClassIndex(classIndex - 1);
	}
        else {
	  test.setClassIndex(test.numAttributes() - 1);
	}
	if (classIndex > test.numAttributes()) {
	  throw new Exception("Index of class attribute too large.");
	}
        //train = new Instances(train);
      }
      if (template == null) {
        throw new Exception("No actual dataset provided to use as template");
      }
      seedString = Utils.getOption('s', options);
      if (seedString.length() != 0) {
	seed = Integer.parseInt(seedString);
      }
      foldsString = Utils.getOption('x', options);
      if (foldsString.length() != 0) {
	folds = Integer.parseInt(foldsString);
      }
      costMatrix = handleCostOption(Utils.getOption('m', options), template.numClasses());

      classStatistics = Utils.getFlag('i', options);
      noOutput = Utils.getFlag('o', options);
      trainStatistics = !Utils.getFlag('v', options);
      printComplexityStatistics = Utils.getFlag('k', options);
      printMargins = Utils.getFlag('r', options);
      printGraph = Utils.getFlag('g', options);
      sourceClass = Utils.getOption('z', options);
      printSource = (sourceClass.length() != 0);

      // Check -p option
      try {
	attributeRangeString = Utils.getOption('p', options);
      }
      catch (Exception e) {
	throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. " +
			    "It now expects a parameter specifying a range of attributes " +
			    "to list with the predictions. Use '-p 0' for none.");
      }
      if (attributeRangeString.length() != 0) {
	printClassifications = true;
	if (!attributeRangeString.equals("0"))
	  attributesToOutput = new Range(attributeRangeString);
      }

      // If a model file is given, we can't process
      // scheme-specific options
      if (objectInputFileName.length() != 0) {
	Utils.checkForRemainingOptions(options);
      } else {

	// Set options for classifier
	if (classifier instanceof OptionHandler) {
	  for (int i = 0; i < options.length; i++) {
	    if (options[i].length() != 0) {
	      if (schemeOptionsText == null) {
		schemeOptionsText = new StringBuffer();
	      }
	      if (options[i].indexOf(' ') != -1) {
		schemeOptionsText.append('"' + options[i] + "\" ");
	      } else {
		schemeOptionsText.append(options[i] + " ");
	      }
	    }
	  }
	  ((OptionHandler)classifier).setOptions(options);
	}
      }
      Utils.checkForRemainingOptions(options);
    } catch (Exception e) {
      throw new Exception("\nWeka exception: " + e.getMessage()
			   + makeOptionString(classifier));
    }


    // Setup up evaluation objects
    Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
    Evaluation testingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);

    if (objectInputFileName.length() != 0) {

      // Load classifier from file
      classifier = (Classifier) objectInputStream.readObject();
      objectInputStream.close();
    }

    // Build the classifier if no object file provided
    if ((classifier instanceof UpdateableClassifier) &&
	(costMatrix == null) &&
	(trainFileName.length() != 0)) {

      // Build classifier incrementally
      trainingEvaluation.setPriors(train);
      testingEvaluation.setPriors(train);
      trainTimeStart = System.currentTimeMillis();
      if (objectInputFileName.length() == 0) {
	classifier.buildClassifier(train);
      }
      while (train.readInstance(trainReader)) {

	trainingEvaluation.updatePriors(train.instance(0));
	testingEvaluation.updatePriors(train.instance(0));
	((UpdateableClassifier)classifier).
	  updateClassifier(train.instance(0));
	train.delete(0);
      }
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
      trainReader.close();
    } else if (objectInputFileName.length() == 0) {

      // Build classifier in one go
      tempTrain = new Instances(train);
      trainingEvaluation.setPriors(tempTrain);
      testingEvaluation.setPriors(tempTrain);
      trainTimeStart = System.currentTimeMillis();
      classifier.buildClassifier(tempTrain);
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    }

    // Save the classifier if an object output file is provided
    if (objectOutputFileName.length() != 0) {
      OutputStream os = new FileOutputStream(objectOutputFileName);
      if (objectOutputFileName.endsWith(".gz")) {
        os = new GZIPOutputStream(os);
      }
      ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
      objectOutputStream.writeObject(classifier);
      objectOutputStream.flush();
      objectOutputStream.close();
    }

    // If classifier is drawable output string describing graph
    if ((classifier instanceof Drawable)
	&& (printGraph)){
      return ((Drawable)classifier).graph();

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -