⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 dependencyparser.java

📁 MSTParser是以最大生成树理论为基础的判别式依存句法分析器。它将一棵依存树的得分看作是所有依存关系的得分的总和。
💻 JAVA
字号:
package mstparser;import java.io.*;import java.util.ArrayList;import java.util.Arrays;import java.util.StringTokenizer;import gnu.trove.*;public class DependencyParser {    public static String trainfile = null;    public static String testfile = null;    public static String trainforest = null;    public static String testforest = null;    public static boolean train = false;    public static boolean eval = false;    public static boolean test = false;    public static String modelName = "dep.model";    public static String lossType = "punc";    public static boolean createForest = true;    public static String decodeType = "proj";    public static int numIters = 10;    public static String outfile = "out.txt";    public static String goldfile = null;    public static int trainK = 1;    public static int testK = 1;    public static boolean secondOrder = false;    private DependencyPipe pipe;    private DependencyDecoder decoder;    private Parameters params;    public DependencyParser(DependencyPipe pipe) {	this.pipe=pipe;	// Set up arrays	params = new Parameters(pipe.dataAlphabet.size());	decoder = secondOrder ? 
new DependencyDecoder2O(pipe) : new DependencyDecoder(pipe);    }    public void train(DependencyInstance[] il, String trainfile, String train_forest) throws IOException {			System.out.println("About to train");	System.out.println("Num Feats: " + pipe.dataAlphabet.size());			int i = 0;	for(i = 0; i < numIters; i++) {				    System.out.println("========================");	    System.out.println("Iteration: " + i);	    System.out.println("========================");	    System.out.print("Processed: ");	    long start = System.currentTimeMillis();	    trainingIter(il,trainfile,train_forest,i+1);	    long end = System.currentTimeMillis();	    System.out.println("Training iter took: " + (end-start));				}	params.averageParams(i*il.length);		    }    private void trainingIter(DependencyInstance[] il, String trainfile, String train_forest, int iter) throws IOException {	int numUpd = 0;	ObjectInputStream in = new ObjectInputStream(new FileInputStream(train_forest));	boolean evaluateI = true;	for(int i = 0; i < il.length; i++) {	    if((i+1) % 500 == 0)		System.out.println("  "+(i+1)+" instances");	    DependencyInstance inst = il[i];			    int length = inst.length;	    // Get production crap.	    
FeatureVector[][][] fvs = new FeatureVector[length][length][2];	    double[][][] probs = new double[length][length][2];	    FeatureVector[][][][] nt_fvs = new FeatureVector[length][pipe.types.length][2][2];	    double[][][][] nt_probs = new double[length][pipe.types.length][2][2];	    FeatureVector[][][] fvs_trips = new FeatureVector[length][length][length];	    double[][][] probs_trips = new double[length][length][length];	    FeatureVector[][][] fvs_sibs = new FeatureVector[length][length][2];	    double[][][] probs_sibs = new double[length][length][2];	    if(secondOrder)		inst = ((DependencyPipe2O)pipe).getFeatureVector(in,inst,fvs,probs,								 fvs_trips,probs_trips,								 fvs_sibs,probs_sibs,								 nt_fvs,nt_probs,params);	    else		inst = pipe.getFeatureVector(in,inst,fvs,probs,nt_fvs,nt_probs,params);	    double upd = (double)(numIters*il.length - (il.length*(iter-1)+(i+1)) + 1);	    int K=trainK;	    Object[][] d = null;	    if(decodeType.equals("proj")) {		if(secondOrder)		    d = ((DependencyDecoder2O)decoder).decodeProjective(inst,fvs,probs,									fvs_trips,probs_trips,									fvs_sibs,probs_sibs,									nt_fvs,nt_probs,K);		else		    d = decoder.decodeProjective(inst,fvs,probs,nt_fvs,nt_probs,K);	    }	    if(decodeType.equals("non-proj")) {		if(secondOrder)		    d = ((DependencyDecoder2O)decoder).decodeNonProjective(inst,fvs,probs,								       fvs_trips,probs_trips,								       fvs_sibs,probs_sibs,								       nt_fvs,nt_probs,K);		else		    d = decoder.decodeNonProjective(inst,fvs,probs,nt_fvs,nt_probs,K);	    }	    params.updateParamsMIRA(inst,d,upd);	}	System.out.println("");		System.out.println("  "+il.length+" instances");			in.close();    }    ///////////////////////////////////////////////////////    // Saving and loading models    ///////////////////////////////////////////////////////    public void saveModel(String file) throws IOException {	ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(file));	
out.writeObject(params.parameters);	out.writeObject(pipe.dataAlphabet);	out.writeObject(pipe.typeAlphabet);	out.close();    }    public void loadModel(String file) throws Exception {	ObjectInputStream in = new ObjectInputStream(new FileInputStream(file));	params.parameters = (double[])in.readObject();	pipe.dataAlphabet = (Alphabet)in.readObject();	pipe.typeAlphabet = (Alphabet)in.readObject();	in.close();	pipe.closeAlphabets();    }    //////////////////////////////////////////////////////    // Get Best Parses ///////////////////////////////////    //////////////////////////////////////////////////////    public void outputParses(String tFile, String file)	throws IOException {	long start = System.currentTimeMillis();	BufferedWriter pred = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file),"UTF8"));	BufferedReader in =	    new BufferedReader(new InputStreamReader(new FileInputStream(tFile),"UTF8"));	System.out.print("Processing Sentence: ");	DependencyInstance il = pipe.createInstance(in);	int cnt = 0;	while(il != null) {	    cnt++;	    System.out.print(cnt+" ");	    String[] toks = il.sentence;				    int length = toks.length;	    FeatureVector[][][] fvs = new FeatureVector[toks.length][toks.length][2];	    double[][][] probs = new double[toks.length][toks.length][2];	    FeatureVector[][][][] nt_fvs = new FeatureVector[toks.length][pipe.types.length][2][2];	    double[][][][] nt_probs = new double[toks.length][pipe.types.length][2][2];	    FeatureVector[][][] fvs_trips = new FeatureVector[length][length][length];	    double[][][] probs_trips = new double[length][length][length];	    FeatureVector[][][] fvs_sibs = new FeatureVector[length][length][2];	    double[][][] probs_sibs = new double[length][length][2];	    if(secondOrder)		((DependencyPipe2O)pipe).getFeatureVector(il,fvs,probs,							  fvs_trips,probs_trips,							  fvs_sibs,probs_sibs,							  nt_fvs,nt_probs,params);	    else		pipe.getFeatureVector(il,fvs,probs,nt_fvs,nt_probs,params);	
    int K = testK;	    Object[][] d = null;	    if(decodeType.equals("proj")) {		if(secondOrder)		    d = ((DependencyDecoder2O)decoder).decodeProjective(il,fvs,probs,									fvs_trips,probs_trips,									fvs_sibs,probs_sibs,									nt_fvs,nt_probs,K);		else		    d = decoder.decodeProjective(il,fvs,probs,nt_fvs,nt_probs,K);	    }	    if(decodeType.equals("non-proj")) {		if(secondOrder)		    d = ((DependencyDecoder2O)decoder).decodeNonProjective(il,fvs,probs,								       fvs_trips,probs_trips,								       fvs_sibs,probs_sibs,								       nt_fvs,nt_probs,K);		else		    d = decoder.decodeNonProjective(il,fvs,probs,nt_fvs,nt_probs,K);	    }	    String[] res = ((String)d[0][1]).split(" ");	    String[] sent = il.sentence;	    String[] pos = il.pos;	    String line1 = ""; String line2 = ""; String line3 = ""; String line4 = "";	    for(int j = 1; j < pos.length; j++) {		String[] trip = res[j-1].split("[\\|:]");		line1+= sent[j] + "\t"; line2 += pos[j] + "\t";		line4 += trip[0] + "\t"; line3 += pipe.types[Integer.parseInt(trip[2])] + "\t";	    }	    pred.write(line1.trim() + "\n" + line2.trim() + "\n"		       + (pipe.labeled ? line3.trim() + "\n" : "")		       + line4.trim() + "\n\n");	    	    il = pipe.createInstance(in);	}	System.out.println();			pred.close();	in.close();			long end = System.currentTimeMillis();	System.out.println("Took: " + (end-start));    }    /////////////////////////////////////////////////////    // RUNNING THE PARSER    ////////////////////////////////////////////////////    public static void main (String[] args) throws FileNotFoundException, Exception    {			processArguments(args);			if(train) {			    DependencyPipe pipe =		secondOrder ? 
new DependencyPipe2O (createForest) : new DependencyPipe (createForest);				    pipe.setLabeled(trainfile);	    DependencyInstance[] trainingData = pipe.createInstances(trainfile,trainforest);				    pipe.closeAlphabets();				    DependencyParser dp = new DependencyParser(pipe);	    int numFeats = pipe.dataAlphabet.size();	    int numTypes = pipe.typeAlphabet.size();	    System.out.println("Num Feats: " + numFeats);		    System.out.println("Num Edge Labels: " + numTypes);	    dp.train(trainingData,trainfile,trainforest);		    System.out.print("Saving model ... ");	    dp.saveModel(modelName);	    System.out.println("done.");				}			if (test) {	    DependencyPipe pipe =		secondOrder ? new DependencyPipe2O (true) : new DependencyPipe (true);	    pipe.setLabeled(testfile);	    DependencyParser dp = new DependencyParser(pipe);	    System.out.println("\nLoading model ... ");	    dp.loadModel(modelName);	    System.out.println("done.");	    pipe.closeAlphabets();	    dp.outputParses(testfile,outfile);	}			if(eval) {	    System.out.println("\nEVALUATION PERFORMANCE:");	    DependencyEvaluator.evaluate(goldfile,outfile);	}    }    public static void processArguments(String[] args) {	for(int i = 0; i < args.length; i++) {	    String[] pair = args[i].split(":");	    if(pair[0].equals("train")) {		train = true;	    }	    if(pair[0].equals("eval")) {		eval = true;	    }	    if(pair[0].equals("test")) {		test = true;	    }	    if(pair[0].equals("iters")) {		numIters = Integer.parseInt(pair[1]);	    }	    if(pair[0].equals("output-file")) {		outfile = pair[1];	    }	    if(pair[0].equals("gold-file")) {		goldfile = pair[1];	    }	    if(pair[0].equals("train-file")) {		trainfile = pair[1];	    }	    if(pair[0].equals("test-file")) {		testfile = pair[1];	    }	    if(pair[0].equals("model-name")) {		modelName = pair[1];	    }	    if(pair[0].equals("training-k")) {		trainK = Integer.parseInt(pair[1]);	    }	    if(pair[0].equals("loss-type")) {		lossType = pair[1];	    }				    
if(pair[0].equals("order") && pair[1].equals("2")) {		secondOrder = true;	    }				    if(pair[0].equals("create-forest")) {		createForest = pair[1].equals("true") ? true : false;	    }				    if(pair[0].equals("decode-type")) {		decodeType = pair[1];	    }				}	trainforest = trainfile == null ? null : trainfile+".forest";	testforest = testfile == null ? null : testfile+".forest";		System.out.println("------\nFLAGS\n------");	System.out.println("train-file: " + trainfile);	System.out.println("test-file: " + testfile);	System.out.println("gold-file: " + goldfile);	System.out.println("output-file: " + outfile);	System.out.println("model-name: " + modelName);	System.out.println("train: " + train);	System.out.println("test: " + test);	System.out.println("eval: " + eval);	System.out.println("loss-type: " + lossType);	System.out.println("second-order: " + secondOrder);	System.out.println("training-iterations: " + numIters);	System.out.println("training-k: " + trainK);	System.out.println("decode-type: " + decodeType);	System.out.println("create-forest: " + createForest);	System.out.println("------\n");    }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -