📄 vectors2classify.java
字号:
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */package edu.umass.cs.mallet.base.classify.tui;import edu.umass.cs.mallet.base.types.*;import edu.umass.cs.mallet.base.classify.*;import edu.umass.cs.mallet.base.classify.evaluate.*;import edu.umass.cs.mallet.base.util.*;import edu.umass.cs.mallet.base.util.CommandOption;import java.io.*;import java.util.*;import java.util.Random;import java.util.logging.*;import java.lang.reflect.*;/** * Classify documents, run trials, print statistics from a vector file. @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a> */public abstract class Vectors2Classify{ private static Logger logger = MalletLogger.getLogger(Vectors2Classify.class.getName()); private static Logger progressLogger = MalletProgressMessageLogger.getLogger(Vectors2Classify.class.getName() + "-pl"); private static ArrayList classifierTrainers = new ArrayList(); private static boolean[][] ReportOptions = new boolean[3][4]; private static String[][] ReportOptionArgs = new String[3][4]; //arg in dataset:reportOption=arg // Essentially an enum mapping string names to enums to ints. private static class ReportOption { static final String[] dataOptions = {"train", "test", "validation"}; static final String[] reportOptions = {"accuracy", "f1", "confusion", "raw"}; static final int train=0; static final int test =1; static final int validation=2; static final int accuracy=0; static final int f1=1; static final int confusion=2; static final int raw=3; } static CommandOption.SpacedStrings report = new CommandOption.SpacedStrings (Vectors2Classify.class, "report", "[train|test|validation]:[accuracy|f1|confusion|raw]", true, new String[] {"test:accuracy", "test:confusion", "train:accuracy"}, "", null) { public void postParsing (CommandOption.List list) { java.lang.String defaultRawFormatting = "siw"; for (int argi=0; argi<this.value.length; argi++){ // convert options like --report train:accuracy --report test:f1=labelA to // boolean array of options. // first, split the argument at semicolon. //System.out.println(argi + " " + this.value[argi]); java.lang.String arg = this.value[argi]; java.lang.String fields[] = arg.split("[:=]"); java.lang.String dataSet = fields[0]; java.lang.String reportOption = fields[1]; java.lang.String reportOptionArg = null; if (fields.length >=3){ reportOptionArg = fields[2]; } //System.out.println("Report option arg " + reportOptionArg); //find the datasource (test,train,validation) boolean foundDataSource = false; int i=0; for (; i<ReportOption.dataOptions.length; i++){ if (dataSet.equals(ReportOption.dataOptions[i])){ foundDataSource = true; break; } } if (!foundDataSource){ throw new IllegalArgumentException("Unknown argument = " + dataSet + " in --report " + this.value[argi]); } //find the report option (accuracy, f1, confusion, raw) boolean foundReportOption = false; int j=0; for (; j<ReportOption.reportOptions.length; j++){ if (reportOption.equals(ReportOption.reportOptions[j])){ foundReportOption = true; break; } } if (!foundReportOption){ throw new IllegalArgumentException("Unknown argument = " + reportOption + " in --report " + this.value[argi]); } //Mark the (dataSet,reportionOption) pair as selected ReportOptions[i][j] = true; if (j == ReportOption.f1){ // make sure a label was specified for f1 if (reportOptionArg == null){ throw new IllegalArgumentException("F1 must have label argument in --report " + this.value[argi]); } // Pass through the string argument ReportOptionArgs[i][j]= reportOptionArg; }else if (reportOptionArg != null){ throw new IllegalArgumentException("No arguments after = allowed in --report " + this.value[argi]); } } } }; static CommandOption.Object trainerConstructor = new CommandOption.Object (Vectors2Classify.class, "trainer", "ClassifierTrainer constructor", true, new NaiveBayesTrainer(), "Java code for the constructor used to create a ClassifierTrainer. "+ "If no '(' appears, then \"new \" will be prepended and \"Trainer()\" will be appended."+ "You may use this option mutiple times to compare multiple classifiers.", null) { public void parseArg (java.lang.String arg) { // parse something like Maxent,gaussianPriorVariance=10,numIterations=20 //System.out.println("Arg = " + arg); // first, split the argument at commas. java.lang.String fields[] = arg.split(","); //Massage constructor name, so that MaxEnt, MaxEntTrainer, new MaxEntTrainer() // all call new MaxEntTrainer() java.lang.String constructorName = fields[0]; if (constructorName.indexOf('(') != -1) // if contains (), pass it though super.parseArg(arg); else { if (constructorName.endsWith("Trainer")){ super.parseArg("new " + constructorName + "()"); // add parens if they forgot }else{ super.parseArg("new "+constructorName+"Trainer()"); // make trainer name from classifier name } } // find methods associated with the class we just built Method methods[] = this.value.getClass().getMethods(); // find setters corresponding to parameter names. for (int i=1; i<fields.length; i++){ java.lang.String nameValuePair[] = fields[i].split("="); java.lang.String parameterName = nameValuePair[0]; java.lang.String parameterValue = nameValuePair[1]; //todo: check for val present! java.lang.Object parameterValueObject; try { parameterValueObject = getInterpreter().eval(parameterValue); } catch (bsh.EvalError e) { throw new IllegalArgumentException ("Java interpreter eval error on parameter "+ parameterName + "\n"+e); } boolean foundSetter = false; for (int j=0; j<methods.length; j++){// System.out.println("method " + j + " name is " + methods[j].getName());// System.out.println("set" + Character.toUpperCase(parameterName.charAt(0)) + parameterName.substring(1)); if ( ("set" + Character.toUpperCase(parameterName.charAt(0)) + parameterName.substring(1)).equals(methods[j].getName()) && methods[j].getParameterTypes().length == 1){// System.out.println("Matched method " + methods[j].getName());// Class[] ptypes = methods[j].getParameterTypes();// System.out.println("Parameter types:");// for (int k=0; k<ptypes.length; k++){// System.out.println("class " + k + " = " + ptypes[k].getName());// } try { java.lang.Object[] parameterList = new java.lang.Object[]{parameterValueObject};// System.out.println("Argument types:");// for (int k=0; k<parameterList.length; k++){// System.out.println("class " + k + " = " + parameterList[k].getClass().getName());// } methods[j].invoke(this.value, parameterList); } catch ( IllegalAccessException e) { System.out.println("IllegalAccessException " + e); throw new IllegalArgumentException ("Java access error calling setter\n"+e); } catch ( InvocationTargetException e) { System.out.println("IllegalTargetException " + e); throw new IllegalArgumentException ("Java target error calling setter\n"+e); } foundSetter = true; break; } } if (!foundSetter){ System.out.println("Parameter " + parameterName + " not found on trainer " + constructorName); System.out.println("Available parameters for " + constructorName); for (int j=0; j<methods.length; j++){ if ( methods[j].getName().startsWith("set") && methods[j].getParameterTypes().length == 1){ System.out.println(Character.toLowerCase(methods[j].getName().charAt(3)) + methods[j].getName().substring(4)); } } throw new IllegalArgumentException ("no setter found for parameter " + parameterName); } } } public void postParsing (CommandOption.List list) { assert (this.value instanceof ClassifierTrainer); //System.out.println("v2c PostParsing " + this.value); classifierTrainers.add (this.value); } }; static CommandOption.String outputFile = new CommandOption.String (Vectors2Classify.class, "output-classifier", "FILENAME", true, "classifier.mallet", "The filename in which to write the classifier after it has been trained.", null); static CommandOption.String inputFile = new CommandOption.String (Vectors2Classify.class, "input", "FILENAME", true, "text.vectors", "The filename from which to read the list of training instances. Use - for stdin.", null); static CommandOption.String trainingFile = new CommandOption.String (Vectors2Classify.class, "training-file", "FILENAME", true, "text.vectors", "Read the training set instance list from this file. " + "If this is specified, the input file parameter is ignored", null); static CommandOption.String testFile = new CommandOption.String (Vectors2Classify.class, "testing-file", "FILENAME", true, "text.vectors", "Read the test set instance list to this file. " + "If this option is specified, the training-file parameter must be specified and " + " the input-file parameter is ignored", null); static CommandOption.String validationFile = new CommandOption.String (Vectors2Classify.class, "validation-file", "FILENAME", true, "text.vectors", "Read the validation set instance list to this file." + "If this option is specified, the training-file parameter must be specified and " + "the input-file parameter is ignored", null); static CommandOption.Double trainingProportionOption = new CommandOption.Double (Vectors2Classify.class, "training-portion", "DECIMAL", true, 1.0, "The fraction of the instances that should be used for training.", null); static CommandOption.Double validationProportionOption = new CommandOption.Double (Vectors2Classify.class, "validation-portion", "DECIMAL", true, 0.0, "The fraction of the instances that should be used for validation.", null); static CommandOption.Integer randomSeedOption = new CommandOption.Integer (Vectors2Classify.class, "random-seed", "INTEGER", true, 0, "The random seed for randomly selecting a proportion of the instance list for training", null); static CommandOption.Integer numTrialsOption = new CommandOption.Integer (Vectors2Classify.class, "num-trials", "INTEGER", true, 1, "The number of random train/test splits to perform", null); static CommandOption.Object classifierEvaluatorOption = new CommandOption.Object (Vectors2Classify.class, "classifier-evaluator", "CONSTRUCTOR", true, null, "Java code for constructing a ClassifierEvaluating object", null);// static CommandOption.Boolean printTrainAccuracyOption = new CommandOption.Boolean// (Vectors2Classify.class, "print-train-accuracy", "true|false", true, true,// "After training, run the resulting classifier on the instances included in training, "// +"and print the accuracy", null);//// static CommandOption.Boolean printTestAccuracyOption = new CommandOption.Boolean// (Vectors2Classify.class, "print-test-accuracy", "true|false", true, true,// "After training, run the resulting classifier on the instances not included in training, "// +"and print the accuracy", null); static CommandOption.Integer verbosityOption = new CommandOption.Integer (Vectors2Classify.class, "verbosity", "INTEGER", true, -1, "The level of messages to print: 0 is silent, 8 is most verbose. " + "Levels 0-8 correspond to the java.logger predefined levels "+ "off, severe, warning, info, config, fine, finer, finest, all. " + "The default value is taken from the mallet logging.properties file," +
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -