⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 segment.java

📁 CRF1.2
💻 JAVA
字号:
package iitb.Segment;

import java.io.*;
import java.util.*;
import iitb.CRF.*;
import iitb.Model.*;
import iitb.Utils.*;

/**
 * Command-line driver that trains a CRF on tagged token sequences, applies the
 * learnt model to test data and prints per-label precision / recall / F1.
 *
 * <p>Files are located relative to the {@code basedir} option:
 * <ul>
 *   <li>{@code <basedir>/data/<inname>/<inname>.train} and {@code .test} - input data</li>
 *   <li>{@code <basedir>/learntModels/<outdir>/crf} and {@code features} - learnt model</li>
 *   <li>{@code <basedir>/out/<outdir>/<inname>.test} - automatically tagged output</li>
 * </ul>
 *
 * @author Sunita Sarawagi
 */
public class Segment {
    private static final String USAGE = "Usage: java Tagger all|train|test|calc -f <conf-file>";

    String inName;
    String outDir;
    String baseDir = "";
    int nlabels;

    String delimit = " \t";   // used to define token boundaries
    String tagDelimit = "|";  // separator between tokens and tag number
    String impDelimit = "";   // delimiters to be retained for tagging
    String groupDelimit = null;

    boolean confuseSet[] = null;
    boolean validate = false;
    String mapTagString = null;
    String smoothType = "";

    String modelArgs = "";
    String featureArgs = "";
    String modelGraphType = "naive";

    LabelMap labelMap;
    Options options;

    CRF crfModel;
    FeatureGenImpl featureGen;

    /** Returns the feature generator created by {@link #allocModel()} (null before that). */
    public FeatureGenerator featureGenerator() {
        return featureGen;
    }

    /**
     * Entry point. {@code argv[0]} selects the mode (case-insensitive):
     * {@code all} (train, test, then calc), {@code train}, {@code test} or {@code calc}.
     * The remaining arguments are handed to {@link #parseConf(String[])}.
     *
     * @param argv command-line arguments; at least three are required
     * @throws Exception if configuration parsing or the selected mode fails
     */
    public static void main(String argv[]) throws Exception {
        if (argv.length < 3) {
            System.out.println(USAGE);
            return;
        }
        Segment segment = new Segment();
        segment.parseConf(argv);

        // equalsIgnoreCase avoids the locale-sensitive toLowerCase() (e.g. Turkish dotless i).
        String mode = argv[0];
        if (mode.equalsIgnoreCase("all")) {
            segment.train();
            // The model is already in memory after train(), so skip test()'s reload.
            segment.doTest();
            segment.calc();
        } else if (mode.equalsIgnoreCase("train")) {
            segment.train();
        } else if (mode.equalsIgnoreCase("test")) {
            segment.test();
        } else if (mode.equalsIgnoreCase("calc")) {
            segment.calc();
        } else {
            // Previously an unknown mode silently did nothing.
            System.out.println(USAGE);
        }
    }

    /**
     * Builds {@link #options} from the command line and applies them via
     * {@link #processArgs()}. When {@code argv[1]} is {@code -f} and a file name
     * follows, that file is loaded first; then {@code options.add(3, argv)} is
     * called with the full argument array.
     *
     * @param argv command-line arguments as received by {@code main}
     * @throws Exception if the configuration file cannot be read or an option is invalid
     */
    public void parseConf(String argv[]) throws Exception {
        options = new Options();
        // Require argv[2] to exist before reading it (the old check allowed length == 2).
        if ((argv.length >= 3) && argv[1].equalsIgnoreCase("-f")) {
            FileInputStream confStream = new FileInputStream(argv[2]);
            try {
                options.load(confStream);
            } finally {
                confStream.close(); // the stream used to be leaked
            }
        }
        options.add(3, argv);
        processArgs();
    }

    /**
     * Copies values from {@link #options} into the fields of this object.
     * {@code numlabels}, {@code inname} and {@code outdir} are fetched with
     * {@code getMandatoryProperty}; all other options are optional and keep
     * their field defaults when absent.
     *
     * @throws Exception propagated from option lookup or number parsing
     */
    public void processArgs() throws Exception {
        String value = null;
        if ((value = options.getMandatoryProperty("numlabels")) != null) {
            nlabels = Integer.parseInt(value);
        }
        if ((value = options.getProperty("binary")) != null) {
            // Binary mode overrides numlabels.
            nlabels = 2;
            labelMap = new BinaryLabelMap(options.getInt("binary"));
        } else {
            labelMap = new LabelMap();
        }
        if ((value = options.getMandatoryProperty("inname")) != null) {
            inName = value;
        }
        if ((value = options.getMandatoryProperty("outdir")) != null) {
            outDir = value;
        }
        if ((value = options.getProperty("basedir")) != null) {
            baseDir = value;
        }
        if ((value = options.getProperty("tagdelimit")) != null) {
            tagDelimit = value;
        }
        // delimiters that will be ignored.
        if ((value = options.getProperty("delimit")) != null) {
            delimit = value;
        }
        if ((value = options.getProperty("impdelimit")) != null) {
            impDelimit = value;
        }
        if ((value = options.getProperty("groupdelimit")) != null) {
            groupDelimit = value;
        }
        if ((value = options.getProperty("confusion")) != null) {
            // Comma/space separated label numbers to show in the confusion matrix.
            StringTokenizer confuse = new StringTokenizer(value, ", ");
            confuseSet = new boolean[nlabels + 1];
            while (confuse.hasMoreTokens()) {
                confuseSet[Integer.parseInt(confuse.nextToken())] = true;
            }
        }
        if ((value = options.getProperty("map-tags")) != null) {
            mapTagString = value;
        }
        if ((value = options.getProperty("validate")) != null) {
            validate = true;
        }
        if ((value = options.getProperty("model-args")) != null) {
            modelArgs = value;
            System.out.println(modelArgs);
        }
        if ((value = options.getProperty("feature-args")) != null) {
            featureArgs = value;
        }
        if ((value = options.getProperty("modelGraph")) != null) {
            modelGraphType = value;
        }
    }

    /**
     * Creates {@link #featureGen} and {@link #crfModel} according to
     * {@link #modelGraphType}: {@code "semi-markov"} selects the nested
     * feature generator and {@code NestedCRF}; anything else a plain
     * {@code FeatureGenImpl} and {@code CRF}.
     *
     * @throws Exception propagated from model construction
     */
    void allocModel() throws Exception {
        // add any code related to dependency/consistency amongst parameter
        // values here..
        if (modelGraphType.equals("semi-markov")) {
            if (options.getInt("debugLvl") > 1) {
                Util.printDbg("Creating semi-markov model");
            }
            NestedFeatureGenImpl nfgen = new NestedFeatureGenImpl(nlabels, options);
            featureGen = nfgen;
            crfModel = new NestedCRF(featureGen.numStates(), nfgen, options);
        } else {
            featureGen = new FeatureGenImpl(modelGraphType, nlabels);
            crfModel = new CRF(featureGen.numStates(), featureGen, options);
        }
    }

    /**
     * A reusable token sequence together with the label assigned to each token.
     * Only the first {@code seq.length} entries of {@code path} are meaningful;
     * the array may be longer because it is reused across records.
     */
    class TestRecord implements SegmentDataSequence {
        String seq[];
        int path[];

        TestRecord(String s[]) {
            seq = s;
            path = new int[seq.length];
        }

        /** Points this record at a new token array, growing {@code path} only when needed. */
        void init(String s[]) {
            seq = s;
            if ((path == null) || (path.length < seq.length)) {
                path = new int[seq.length];
            }
        }

        public void set_y(int i, int l) { path[i] = l; } // not applicable for training data.
        public int y(int i) { return path[i]; }
        public int length() { return seq.length; }
        public Object x(int i) { return seq[i]; }

        /**
         * Returns the last index of the run of equal labels beginning at
         * {@code segmentStart}, or -1 when {@code segmentStart} carries the same
         * label as the preceding position (i.e. it is not the start of a run).
         *
         * @see iitb.CRF.SegmentDataSequence#getSegmentEnd(int)
         */
        public int getSegmentEnd(int segmentStart) {
            if ((segmentStart > 0) && (y(segmentStart) == y(segmentStart - 1))) {
                return -1;
            }
            for (int i = segmentStart + 1; i < length(); i++) {
                if (y(i) != y(segmentStart)) {
                    return i - 1;
                }
            }
            return length() - 1;
        }

        /**
         * Assigns label {@code y} to every position in
         * {@code [segmentStart, segmentEnd]} (both inclusive).
         *
         * @see iitb.CRF.SegmentDataSequence#setSegment(int, int, int)
         */
        public void setSegment(int segmentStart, int segmentEnd, int y) {
            for (int i = segmentStart; i <= segmentEnd; i++) {
                set_y(i, y);
            }
        }
    }

    /**
     * Labels one record with the current model. Tokens are preprocessed in
     * place, the CRF is applied, states are mapped to labels, and the tokens
     * of each non-negative label are gathered, space-separated, into
     * {@code collect[label]}.
     *
     * @param testRecord  record to label; its {@code seq} entries are overwritten
     *                    with their preprocessed form
     * @param groupedToks not used by this method
     * @param collect     output array with at least {@code nlabels} slots; the first
     *                    {@code nlabels} are reset to null before being filled
     * @return the record's label array ({@code testRecord.path}, not a copy)
     */
    public int[] segment(TestRecord testRecord, int[] groupedToks, String collect[]) {
        for (int i = 0; i < testRecord.length(); i++) {
            testRecord.seq[i] = AlphaNumericPreprocessor.preprocess(testRecord.seq[i]);
        }
        crfModel.apply(testRecord);
        featureGen.mapStatesToLabels(testRecord);

        int path[] = testRecord.path;
        for (int i = 0; i < nlabels; i++) {
            collect[i] = null;
        }
        for (int i = 0; i < testRecord.length(); i++) {
            int label = path[i];
            if (label < 0) {
                continue; // tokens with a negative label are not collected
            }
            if (collect[label] == null) {
                collect[label] = testRecord.seq[i];
            } else {
                collect[label] = collect[label] + " " + testRecord.seq[i];
            }
        }
        return path;
    }

    /**
     * Trains the model on {@code <inname>.train} and writes it under
     * {@code <basedir>/learntModels/<outdir>/} as {@code crf} and {@code features}.
     * When the {@code showModel} option is set, the learnt weights are displayed.
     *
     * @throws Exception propagated from data loading, training or model I/O
     */
    public void train() throws Exception {
        String trainFile = dataPath(".train");
        DataCruncher.createRaw(trainFile, tagDelimit);
        new File(modelDir()).mkdirs();

        TrainData trainData = DataCruncher.readTagged(nlabels, trainFile, trainFile,
                delimit, tagDelimit, impDelimit, labelMap);
        AlphaNumericPreprocessor.preprocess(trainData, nlabels);

        allocModel();
        featureGen.train(trainData);
        double featureWts[] = crfModel.train(trainData);
        if (options.getInt("debugLvl") > 1) {
            Util.printDbg("Training done");
        }

        crfModel.write(modelDir() + "/crf");
        featureGen.write(modelDir() + "/features");
        if (options.getInt("debugLvl") > 1) {
            Util.printDbg("Writing model to " + modelDir() + "/crf");
        }
        if (options.getProperty("showModel") != null) {
            featureGen.displayModel(featureWts);
        }
    }

    /**
     * Loads a previously trained model from disk and runs {@link #doTest()}.
     *
     * @throws Exception propagated from model I/O or tagging
     */
    public void test() throws Exception {
        allocModel();
        featureGen.read(modelDir() + "/features");
        crfModel.read(modelDir() + "/crf");
        doTest();
    }

    /**
     * Tags every record of {@code <inname>.test} with the model currently in
     * memory and writes the result to {@code <basedir>/out/<outdir>/<inname>.test}.
     *
     * @throws Exception propagated from data I/O or tagging
     */
    public void doTest() throws Exception {
        new File(baseDir + "/out/" + outDir).mkdirs();
        String testFile = dataPath(".test");
        TestData testData = new TestData(testFile, delimit, impDelimit, groupDelimit);
        TestDataWrite tdw = new TestDataWrite(taggedOutPath(), testFile,
                delimit, tagDelimit, impDelimit, labelMap);
        try {
            String collect[] = new String[nlabels];
            TestRecord testRecord = new TestRecord(collect);
            boolean debug = options.getInt("debugLvl") > 1;
            for (String seq[] = testData.nextRecord(); seq != null; seq = testData.nextRecord()) {
                testRecord.init(seq);
                if (debug) {
                    // Print the tokens; concatenating the array itself only showed its identity hash.
                    Util.printDbg("Invoking segment on " + arrayToString(seq));
                }
                int path[] = segment(testRecord, testData.groupedTokens(), collect);
                tdw.writeRecord(path, testRecord.length());
            }
        } finally {
            tdw.close(); // close the output even when tagging fails part-way
        }
    }

    TrainData taggedData = null;

    /** Returns the label of every position of {@code tr} as a new array. */
    int[] allLabels(TrainRecord tr) {
        int[] labs = new int[tr.length()];
        for (int i = 0; i < labs.length; i++) {
            labs[i] = tr.y(i);
        }
        return labs;
    }

    /** Joins the elements of {@code ar}, each followed by a single space. */
    String arrayToString(Object[] ar) {
        StringBuffer sb = new StringBuffer();
        for (int i = 0; i < ar.length; i++) {
            sb.append(ar[i]).append(' ');
        }
        return sb.toString();
    }

    /**
     * Compares the automatically tagged output against the manually tagged test
     * file and prints (optionally) a confusion matrix followed by token-level
     * true-positive / marked / actual counts, precision, recall and F1 for each
     * label and overall. Records in which the automatic tagging contains a
     * negative label are skipped. Per-record details go to {@code System.err}
     * when {@code debugLvl > 0}.
     *
     * @throws Exception propagated from data loading
     */
    public void calc() throws Exception {
        Vector s = new Vector();
        String testFile = dataPath(".test");
        TrainData tdMan = DataCruncher.readTagged(nlabels, testFile, testFile,
                delimit, tagDelimit, impDelimit, labelMap);
        TrainData tdAuto = DataCruncher.readTagged(nlabels, taggedOutPath(), testFile,
                delimit, tagDelimit, impDelimit, labelMap);
        DataCruncher.readRaw(s, testFile, "", "");

        // Slot [nlabels] of each counter holds the total over all labels.
        int truePos[] = new int[nlabels + 1];
        int totalMarkedPos[] = new int[nlabels + 1];
        int totalPos[] = new int[nlabels + 1];
        int confuseMatrix[][] = new int[nlabels][nlabels];
        boolean printDetails = (options.getInt("debugLvl") > 0);

        if (tdAuto.size() != tdMan.size()) {
            // Sanity Check
            System.out.println("Length Mismatch - Raw: " + s.size() + " Auto: " + tdAuto.size()
                    + " Man: " + tdMan.size());
        }
        // Only walk records present in all three sources so a mismatch cannot
        // run past the end of the shorter one.
        int len = Math.min(tdAuto.size(), Math.min(tdMan.size(), s.size()));

        for (int i = 0; i < len; i++) {
            String raw[] = (String[]) (s.get(i));
            TrainRecord trMan = tdMan.nextRecord();
            TrainRecord trAuto = tdAuto.nextRecord();
            int tokenMan[] = allLabels(trMan);
            int tokenAuto[] = allLabels(trAuto);
            if (tokenMan.length != tokenAuto.length) {
                // Sanity Check
                System.out.println("Length Mismatch - Manual: " + tokenMan.length
                        + " Auto: " + tokenAuto.length);
            }
            // Compare only positions both taggings have.
            int tlen = Math.min(tokenMan.length, tokenAuto.length);

            // remove invalid tagging.
            boolean invalidMatch = false;
            for (int j = 0; j < tlen; j++) {
                if (printDetails) System.err.println(tokenMan[j] + " " + tokenAuto[j]);
                if (tokenAuto[j] < 0) {
                    invalidMatch = true;
                    break;
                }
            }
            if (invalidMatch) {
                if (printDetails) System.err.println("No valid path");
                continue;
            }

            int correctTokens = 0;
            for (int j = 0; j < tlen; j++) {
                totalMarkedPos[tokenAuto[j]]++;
                totalMarkedPos[nlabels]++;
                totalPos[tokenMan[j]]++;
                totalPos[nlabels]++;
                confuseMatrix[tokenMan[j]][tokenAuto[j]]++;
                if (tokenAuto[j] == tokenMan[j]) {
                    correctTokens++;
                    truePos[tokenMan[j]]++;
                    truePos[nlabels]++;
                }
            }

            if (printDetails) {
                System.err.println("Stats: " + correctTokens + " " + (tlen));
                for (int j = 0; j < raw.length; j++) {
                    System.err.print(raw[j] + " ");
                }
                System.err.println();
            }
            for (int j = 0; j < nlabels; j++) {
                String mstr = "";
                for (int k = 0; k < trMan.numSegments(j); k++) {
                    mstr += arrayToString(trMan.tokens(j, k));
                }
                String astr = "";
                for (int k = 0; k < trAuto.numSegments(j); k++) {
                    astr += arrayToString(trAuto.tokens(j, k));
                }
                if (printDetails) {
                    // "W" flags labels whose manual and automatic text differ.
                    if (!mstr.equalsIgnoreCase(astr)) {
                        System.err.print("W");
                    }
                    System.err.println(j + ": " + mstr + " : " + astr);
                }
            }
            if (printDetails) System.err.println();
        }

        if (confuseSet != null) {
            System.out.println("Confusion Matrix:");
            System.out.print("M\\A");
            for (int i = 0; i < nlabels; i++) {
                if (confuseSet[i]) {
                    System.out.print("\t" + (i));
                }
            }
            System.out.println();
            for (int i = 0; i < nlabels; i++) {
                if (confuseSet[i]) {
                    System.out.print(i);
                    for (int j = 0; j < nlabels; j++) {
                        if (confuseSet[j]) {
                            System.out.print("\t" + confuseMatrix[i][j]);
                        }
                    }
                    System.out.println();
                }
            }
        }

        System.out.println("\n\nCalculations:");
        System.out.println();
        System.out.println("Label\tTrue+\tMarked+\tActual+\tPrec.\tRecall\tF1");
        for (int i = 0; i < nlabels; i++) {
            printStatsRow((i) + ":", truePos[i], totalMarkedPos[i], totalPos[i]);
        }
        System.out.println("---------------------------------------------------------");
        printStatsRow("Ov:", truePos[nlabels], totalMarkedPos[nlabels], totalPos[nlabels]);
    }

    /** Prints one tab-separated row of the statistics table. */
    private static void printStatsRow(String label, int truePositives, int marked, int actual) {
        double prec = percent(truePositives, marked);
        double recall = percent(truePositives, actual);
        System.out.println(label + "\t" + truePositives + "\t" + marked + "\t" + actual
                + "\t" + prec + "\t" + recall + "\t" + f1(prec, recall));
    }

    /**
     * Returns {@code 100 * part / whole} truncated to three decimals, or 0 when
     * {@code whole} is 0. The product is formed in {@code long}: in {@code int}
     * it overflowed as soon as {@code part} exceeded 21474.
     */
    private static double percent(int part, int whole) {
        if (whole == 0) {
            return 0;
        }
        return ((double) (part * 100000L / whole)) / 1000;
    }

    /** Harmonic mean of precision and recall; 0 (instead of NaN) when both are 0. */
    private static double f1(double prec, double recall) {
        double sum = prec + recall;
        return (sum == 0) ? 0 : 2 * prec * recall / sum;
    }

    /** Path of the input data file with the given extension (".train" or ".test"). */
    private String dataPath(String extension) {
        return baseDir + "/data/" + inName + "/" + inName + extension;
    }

    /** Directory holding the learnt model files. */
    private String modelDir() {
        return baseDir + "/learntModels/" + outDir;
    }

    /** Path of the automatically tagged test output. */
    private String taggedOutPath() {
        return baseDir + "/out/" + outDir + "/" + inName + ".test";
    }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -