📄 vectors2classify.java
字号:
" which currently defaults to INFO level (3)", null);

	/**
	 * Whether to suppress in-place rewriting of progress messages on the terminal.
	 * When this option's value is false, {@link #main} installs a
	 * ProgressMessageLogFormatter on every ConsoleHandler of the root logger.
	 */
	static CommandOption.Boolean noOverwriteProgressMessagesOption = new CommandOption.Boolean
		(Vectors2Classify.class, "noOverwriteProgressMessages", "true|false", false, false,
		 "Suppress writing-in-place on terminal for progess messages - repetitive messages "
		 +"of which only the latest is generally of interest", null);

	/**
	 * Command-line entry point.
	 *
	 * <p>Steps, in order:
	 * <ul>
	 *   <li>Load the instance lists. This is either one list that is randomly split
	 *       {training, testing, validation} on every trial, or separate training, testing
	 *       and validation files.</li>
	 *   <li>Train every requested ClassifierTrainer on each trial's training portion.</li>
	 *   <li>Optionally serialize each classifier.</li>
	 *   <li>Print the diagnostics selected with the report option.</li>
	 *   <li>Print per-trainer accuracy summaries across all trials.</li>
	 * </ul>
	 *
	 * @param args command-line arguments, parsed by CommandOption.process
	 * @throws bsh.EvalError declared by the signature; presumably raised by BeanShell
	 *         evaluation while processing options -- not thrown directly here
	 * @throws java.io.IOException declared by the signature; presumably from loading
	 *         the instance lists
	 * @throws IllegalArgumentException if a trained classifier cannot be written to the
	 *         output file
	 */
	public static void main (String[] args) throws bsh.EvalError, java.io.IOException
	{
		// Process the command-line options
		CommandOption.setSummary (Vectors2Classify.class,
				"A tool for training, saving and printing diagnostics from a classifier on vectors.");
		CommandOption.process (Vectors2Classify.class, args);

		// handle default trainer here for now; default argument processing doesn't work
		if (!trainerConstructor.wasInvoked()) {
			classifierTrainers.add (new NaiveBayesTrainer());
		}
		if (!report.wasInvoked()) {
			report.postParsing(null);  // force postprocessing of default value
		}

		configureLogging();

		boolean separateIlists = testFile.wasInvoked() || trainingFile.wasInvoked() || validationFile.wasInvoked();
		InstanceList ilist = null;
		InstanceList trainingFileIlist = null;
		InstanceList testFileIlist = null;
		InstanceList validationFileIlist = null;

		if (!separateIlists) {
			// Normal case: one input file that is randomly split on every trial.
			ilist = InstanceList.load (new File(inputFile.value));
		} else {
			// The user specified separate files for the training, testing and validation sets.
			// NOTE(review): this assumes the training-file option is given whenever any separate
			// file is. Otherwise trainingFile.value is just its default value -- confirm.
			trainingFileIlist = InstanceList.load (new File(trainingFile.value));
			logger.info("Training vectors loaded from " + trainingFile.value);
			if (testFile.wasInvoked()) {
				testFileIlist = InstanceList.load (new File(testFile.value));
				logger.info("Testing vectors loaded from " + testFile.value);
			} else {
				// A missing file yields an empty split, the same as a zero proportion in the
				// single-file case. A null list would throw at ilists[1].size() below.
				testFileIlist = new InstanceList (trainingFileIlist.getPipe());
			}
			if (validationFile.wasInvoked()) {
				validationFileIlist = InstanceList.load (new File(validationFile.value));
				logger.info("validation vectors loaded from " + validationFile.value);
			} else {
				validationFileIlist = new InstanceList (trainingFileIlist.getPipe());
			}
		}

		int numTrials = numTrialsOption.value;
		Random r = randomSeedOption.wasInvoked() ? new Random (randomSeedOption.value) : new Random ();

		ClassifierTrainer[] trainers = new ClassifierTrainer[classifierTrainers.size()];
		for (int i = 0; i < trainers.length; i++) {
			trainers[i] = (ClassifierTrainer) classifierTrainers.get(i);
			logger.fine ("Trainer specified = " + trainers[i].toString());
		}

		// Accuracies are kept across trials for the final summary: [trainer][trial].
		double[][] trainAccuracy = new double[trainers.length][numTrials];
		double[][] testAccuracy = new double[trainers.length][numTrials];
		double[][] validationAccuracy = new double[trainers.length][numTrials];

		double t = trainingProportionOption.value;
		double v = validationProportionOption.value;
		if (!separateIlists) {
			logger.info("Training portion = " + t);
			logger.info("Validation portion = " + v);
			logger.info("Testing portion = " + (1 - v - t));
		}

		for (int trialIndex = 0; trialIndex < numTrials; trialIndex++) {
			System.out.println("\n-------------------- Trial " + trialIndex + " --------------------\n");

			// ilists is always ordered {training, testing, validation}.
			InstanceList[] ilists;
			if (!separateIlists) {
				ilists = ilist.split (r, new double[] {t, 1-t-v, v});
			} else {
				ilists = new InstanceList[] {trainingFileIlist, testFileIlist, validationFileIlist};
			}

			for (int c = 0; c < trainers.length; c++) {
				System.out.println ("Trial " + trialIndex + " Training " + trainers[c].toString()
						+ " with " + ilists[0].size() + " instances");
				Classifier classifier = trainers[c].train (ilists[0]);
				String trainerName = trainers[c].toString();
				System.out.println ("Trial " + trialIndex + " Training " + trainerName + " finished");

				Trial trainTrial = new Trial (classifier, ilists[0]);
				Trial testTrial = new Trial (classifier, ilists[1]);
				Trial validationTrial = new Trial (classifier, ilists[2]);

				// Confusion matrices are only built for non-empty splits (null otherwise).
				String trainConfusion = ilists[0].size() > 0 ? new ConfusionMatrix (trainTrial).toString() : null;
				String testConfusion = ilists[1].size() > 0 ? new ConfusionMatrix (testTrial).toString() : null;
				String validationConfusion = ilists[2].size() > 0 ? new ConfusionMatrix (validationTrial).toString() : null;

				trainAccuracy[c][trialIndex] = trainTrial.accuracy();
				testAccuracy[c][trialIndex] = testTrial.accuracy();
				validationAccuracy[c][trialIndex] = validationTrial.accuracy();

				if (outputFile.wasInvoked()) {
					String filename = outputFile.value;
					if (trainers.length > 1)
						filename = filename + trainerName;
					if (numTrials > 1)
						filename = filename + ".trial" + trialIndex;
					writeClassifier (classifier, filename);
				}

				// Raw per-instance classifications are printed first for every split...
				reportRaw (trialIndex, trainerName, ReportOption.train, "Training", trainTrial);
				reportRaw (trialIndex, trainerName, ReportOption.test, "Testing", testTrial);
				reportRaw (trialIndex, trainerName, ReportOption.validation, "Validation", validationTrial);

				// ...then confusion matrix / accuracy / F1 in train, validation, test order.
				reportSplit (trialIndex, trainerName, ReportOption.train, "Training", "training",
						trainTrial, trainConfusion, trainAccuracy[c][trialIndex]);
				reportSplit (trialIndex, trainerName, ReportOption.validation, "Validation", "validation",
						validationTrial, validationConfusion, validationAccuracy[c][trialIndex]);
				reportSplit (trialIndex, trainerName, ReportOption.test, "Test", "test",
						testTrial, testConfusion, testAccuracy[c][trialIndex]);
			} // end for each trainer
		} // end for each trial

		// Summary across trials; report syntax is "[train|test|validation]:[accuracy|f1|confusion|raw]"
		for (int c = 0; c < trainers.length; c++) {
			System.out.println ("\n" + trainers[c].toString());
			printAccuracySummary (ReportOption.train, "train", trainAccuracy[c]);
			printAccuracySummary (ReportOption.validation, "validation", validationAccuracy[c]);
			printAccuracySummary (ReportOption.test, "test", testAccuracy[c]);
		}
	}

	/**
	 * Applies the logging command-line options to MALLET's root logger.
	 *
	 * <p>The verbosity level is set only if that option was given. The progress-message
	 * formatter is installed on console handlers unless noOverwriteProgressMessages is true.
	 */
	private static void configureLogging ()
	{
		Logger rootLogger = ((MalletLogger)progressLogger).getRootLogger();
		if (verbosityOption.wasInvoked()) {
			rootLogger.setLevel (MalletLogger.LoggingLevels[verbosityOption.value]);
		}
		if (!noOverwriteProgressMessagesOption.value) {
			// Swap the console handlers' formatter for one that knows about progress messages.
			Handler[] handlers = rootLogger.getHandlers();
			for (int i = 0; i < handlers.length; i++) {
				if (handlers[i] instanceof ConsoleHandler) {
					handlers[i].setFormatter (new ProgressMessageLogFormatter());
				}
			}
		}
	}

	/**
	 * Serializes a classifier to a file with Java object serialization.
	 *
	 * <p>The underlying file stream is closed on every path.
	 *
	 * @param classifier the trained classifier to write
	 * @param filename   destination path
	 * @throws IllegalArgumentException if the file cannot be written; the I/O error is
	 *         attached as the cause
	 */
	private static void writeClassifier (Classifier classifier, String filename)
	{
		FileOutputStream fos = null;
		try {
			fos = new FileOutputStream (filename);
			ObjectOutputStream oos = new ObjectOutputStream (fos);
			oos.writeObject (classifier);
			// close() flushes buffered object data, so its failure must be reported.
			// It also closes fos.
			oos.close();
			fos = null;
		} catch (java.io.IOException e) {
			IllegalArgumentException failure =
				new IllegalArgumentException ("Couldn't write classifier to filename " + filename);
			failure.initCause (e);
			throw failure;
		} finally {
			if (fos != null) {
				try {
					fos.close();
				} catch (java.io.IOException ignored) {
					// Already failing with the original error; this close is best-effort cleanup.
				}
			}
		}
	}

	/**
	 * Prints the raw per-instance classifications of one split, if the "raw" report is
	 * enabled for that split.
	 *
	 * @param split    row index into ReportOptions (ReportOption.train/test/validation)
	 * @param dataName word used in the heading, e.g. "Training"
	 */
	private static void reportRaw (int trialIndex, String trainerName, int split, String dataName, Trial trial)
	{
		if (!ReportOptions[split][ReportOption.raw])
			return;
		System.out.println("Trial " + trialIndex + " Trainer " + trainerName);
		System.out.println(" Raw " + dataName + " Data");
		printTrialClassification(trial);
	}

	/**
	 * Prints the confusion-matrix, accuracy and F1 reports for one split.
	 *
	 * <p>Each report is printed only if it is enabled in ReportOptions for that split.
	 *
	 * @param split           row index into ReportOptions / ReportOptionArgs
	 * @param headingName     capitalised split name for the confusion-matrix heading
	 * @param lowerName       lower-case split name for the accuracy and F1 lines
	 * @param confusionMatrix rendered matrix, or null when the split was empty; in that
	 *                        case only the heading is printed
	 */
	private static void reportSplit (int trialIndex, String trainerName, int split,
			String headingName, String lowerName, Trial trial, String confusionMatrix, double accuracy)
	{
		String prefix = "Trial " + trialIndex + " Trainer " + trainerName;
		if (ReportOptions[split][ReportOption.confusion]) {
			System.out.println(prefix + " " + headingName + " Data Confusion Matrix");
			if (confusionMatrix != null)
				System.out.println (confusionMatrix);
		}
		if (ReportOptions[split][ReportOption.accuracy]) {
			System.out.println (prefix + " " + lowerName + " data accuracy= " + accuracy);
		}
		if (ReportOptions[split][ReportOption.f1]) {
			String label = ReportOptionArgs[split][ReportOption.f1];
			System.out.println (prefix + " " + lowerName + " data F1(" + label + ") = " + trial.labelF1(label));
		}
	}

	/**
	 * Prints the mean, standard deviation and standard error of one split's accuracies
	 * across trials, if the "accuracy" report is enabled for that split.
	 */
	private static void printAccuracySummary (int split, String name, double[] accuracies)
	{
		if (!ReportOptions[split][ReportOption.accuracy])
			return;
		System.out.println ("Summary. " + name + " accuracy mean = " + MatrixOps.mean (accuracies)
				+ " stddev = " + MatrixOps.stddev (accuracies)
				+ " stderr = " + MatrixOps.stderr (accuracies));
	}

	/**
	 * Prints one line per classified instance.
	 *
	 * <p>Each line has the form {@code name target label:value label:value ...}, with the
	 * labels in rank order.
	 */
	private static void printTrialClassification (Trial trial)
	{
		ArrayList classifications = trial.toArrayList();
		for (int i = 0; i < classifications.size(); i++) {
			Instance instance = trial.getClassification(i).getInstance();
			System.out.print(instance.getName() + " " + instance.getTarget() + " ");
			Labeling labeling = trial.getClassification(i).getLabeling();
			for (int j = 0; j < labeling.numLocations(); j++) {
				System.out.print(labeling.getLabelAtRank(j).toString() + ":" + labeling.getValueAtRank(j) + " ");
			}
			System.out.println();
		}
	}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -