📄 segment.java
字号:
package iitb.Segment;

import java.io.*;
import java.util.*;

import iitb.CRF.*;
import iitb.Model.*;
import iitb.Utils.*;

/**
 * Command-line driver that trains a CRF model on tagged token sequences,
 * applies it to a test file, and prints per-label precision/recall/F1.
 *
 * <p>All input and output locations are derived from the {@code basedir},
 * {@code inname} and {@code outdir} options:
 * <ul>
 *   <li>training data: {@code <basedir>/data/<inname>/<inname>.train}</li>
 *   <li>test data: {@code <basedir>/data/<inname>/<inname>.test}</li>
 *   <li>model files: {@code <basedir>/learntModels/<outdir>/crf} and {@code .../features}</li>
 *   <li>tagged output: {@code <basedir>/out/<outdir>/<inname>.test}</li>
 * </ul>
 *
 * @author Sunita Sarawagi
 */
public class Segment {
    String inName;
    String outDir;
    String baseDir = "";
    int nlabels;

    String delimit = " \t";      // used to define token boundaries
    String tagDelimit = "|";     // separator between tokens and tag number
    String impDelimit = "";      // delimiters to be retained for tagging
    String groupDelimit = null;

    /** Labels to show in the confusion matrix; {@code null} when the "confusion" option is absent. */
    boolean confuseSet[] = null;
    boolean validate = false;
    String mapTagString = null;
    String smoothType = "";

    String modelArgs = "";
    String featureArgs = "";
    String modelGraphType = "naive";

    LabelMap labelMap;
    Options options;

    CRF crfModel;
    FeatureGenImpl featureGen;

    /** Returns the feature generator created by {@link #allocModel()} (null before that call). */
    public FeatureGenerator featureGenerator() {
        return featureGen;
    }

    /**
     * Entry point. {@code argv[0]} selects the action, compared case-insensitively:
     * {@code all} (train, then test with the in-memory model, then calc),
     * {@code train}, {@code test} or {@code calc}. Any other value does nothing
     * after the configuration has been parsed.
     *
     * @param argv command, then optionally {@code -f <conf-file>}, then further options
     * @throws Exception if configuration, training, testing or evaluation fails
     */
    public static void main(String argv[]) throws Exception {
        if (argv.length < 3) {
            System.out.println("Usage: java Tagger train|test|calc -f <conf-file>");
            return;
        }
        Segment segment = new Segment();
        segment.parseConf(argv);

        // equalsIgnoreCase avoids the locale-dependent lowering of toLowerCase().
        String command = argv[0];
        if (command.equalsIgnoreCase("all")) {
            segment.train();
            segment.doTest();
            segment.calc();
        }
        if (command.equalsIgnoreCase("train")) {
            segment.train();
        }
        if (command.equalsIgnoreCase("test")) {
            segment.test();
        }
        if (command.equalsIgnoreCase("calc")) {
            segment.calc();
        }
    }

    /**
     * Builds {@link #options} from the command line and then applies them via
     * {@link #processArgs()}. When {@code argv[1]} is {@code -f} and a file name
     * follows, that file is loaded first; arguments from index 3 onwards are then
     * added on top.
     *
     * @param argv the full command-line argument array
     * @throws Exception if the configuration file cannot be read or an option is invalid
     */
    public void parseConf(String argv[]) throws Exception {
        options = new Options();
        // Require the file name to be present before reading argv[2].
        if ((argv.length > 2) && argv[1].equalsIgnoreCase("-f")) {
            FileInputStream conf = new FileInputStream(argv[2]);
            try {
                options.load(conf);
            } finally {
                conf.close(); // the stream was previously never released
            }
        }
        options.add(3, argv);
        processArgs();
    }

    /**
     * Copies the values held in {@link #options} into this object's fields.
     * {@code numlabels}, {@code inname} and {@code outdir} are requested as
     * mandatory properties; everything else keeps its default when absent.
     *
     * @throws Exception if a mandatory property is missing or a value cannot be parsed
     */
    public void processArgs() throws Exception {
        String value = null;
        if ((value = options.getMandatoryProperty("numlabels")) != null) {
            nlabels = Integer.parseInt(value);
        }
        if ((value = options.getProperty("binary")) != null) {
            // Binary mode overrides numlabels.
            nlabels = 2;
            labelMap = new BinaryLabelMap(options.getInt("binary"));
        } else {
            labelMap = new LabelMap();
        }
        if ((value = options.getMandatoryProperty("inname")) != null) {
            inName = value;
        }
        if ((value = options.getMandatoryProperty("outdir")) != null) {
            outDir = value;
        }
        if ((value = options.getProperty("basedir")) != null) {
            baseDir = value;
        }
        if ((value = options.getProperty("tagdelimit")) != null) {
            tagDelimit = value;
        }
        // delimiters that will be ignored.
        if ((value = options.getProperty("delimit")) != null) {
            delimit = value;
        }
        if ((value = options.getProperty("impdelimit")) != null) {
            impDelimit = value;
        }
        if ((value = options.getProperty("groupdelimit")) != null) {
            groupDelimit = value;
        }
        if ((value = options.getProperty("confusion")) != null) {
            // Comma/space separated list of label numbers to include in the confusion matrix.
            StringTokenizer confuse = new StringTokenizer(value, ", ");
            int confuseSize = confuse.countTokens();
            confuseSet = new boolean[nlabels + 1];
            for (int i = 0; i < confuseSize; i++) {
                confuseSet[Integer.parseInt(confuse.nextToken())] = true;
            }
        }
        if ((value = options.getProperty("map-tags")) != null) {
            mapTagString = value;
        }
        if ((value = options.getProperty("validate")) != null) {
            validate = true;
        }
        if ((value = options.getProperty("model-args")) != null) {
            modelArgs = value;
            System.out.println(modelArgs);
        }
        if ((value = options.getProperty("feature-args")) != null) {
            featureArgs = value;
        }
        if ((value = options.getProperty("modelGraph")) != null) {
            modelGraphType = value;
        }
    }

    /**
     * Creates {@link #featureGen} and {@link #crfModel} according to
     * {@link #modelGraphType}: a nested feature generator and {@code NestedCRF}
     * for {@code "semi-markov"}, a plain feature generator and {@code CRF} otherwise.
     *
     * @throws Exception if the model or feature generator cannot be constructed
     */
    void allocModel() throws Exception {
        // add any code related to dependency/consistency amongst parameter
        // values here..
        if (modelGraphType.equals("semi-markov")) {
            if (options.getInt("debugLvl") > 1) {
                Util.printDbg("Creating semi-markov model");
            }
            NestedFeatureGenImpl nfgen = new NestedFeatureGenImpl(nlabels, options);
            featureGen = nfgen;
            crfModel = new NestedCRF(featureGen.numStates(), nfgen, options);
        } else {
            featureGen = new FeatureGenImpl(modelGraphType, nlabels);
            crfModel = new CRF(featureGen.numStates(), featureGen, options);
        }
    }

    /**
     * A token sequence to be labelled, together with the label assigned to
     * each position. One instance is reused across records via {@link #init}.
     */
    class TestRecord implements SegmentDataSequence {
        String seq[];
        int path[];

        TestRecord(String s[]) {
            seq = s;
            path = new int[seq.length];
        }

        /**
         * Points this record at a new token array. The label array is only
         * reallocated when it is too short, so it may be longer than
         * {@code seq}; use {@link #length()} for the valid range.
         */
        void init(String s[]) {
            seq = s;
            if ((path == null) || (path.length < seq.length)) {
                path = new int[seq.length];
            }
        }

        public void set_y(int i, int l) {
            path[i] = l;
        }

        // not applicable for training data.
        public int y(int i) {
            return path[i];
        }

        public int length() {
            return seq.length;
        }

        public Object x(int i) {
            return seq[i];
        }

        /**
         * Returns the last index of the run of identical labels beginning at
         * {@code segmentStart}, or -1 when {@code segmentStart} carries the
         * same label as the position before it (i.e. it does not start a run).
         *
         * @see iitb.CRF.SegmentDataSequence#getSegmentEnd(int)
         */
        public int getSegmentEnd(int segmentStart) {
            if ((segmentStart > 0) && (y(segmentStart) == y(segmentStart - 1))) {
                return -1;
            }
            for (int i = segmentStart + 1; i < length(); i++) {
                if (y(i) != y(segmentStart)) {
                    return i - 1;
                }
            }
            return length() - 1;
        }

        /**
         * Assigns label {@code y} to every position from {@code segmentStart}
         * to {@code segmentEnd}, inclusive.
         *
         * @see iitb.CRF.SegmentDataSequence#setSegment(int, int, int)
         */
        public void setSegment(int segmentStart, int segmentEnd, int y) {
            for (int i = segmentStart; i <= segmentEnd; i++) {
                set_y(i, y);
            }
        }
    }

    /**
     * Labels one record with the current model. Tokens are preprocessed in
     * place, the CRF is applied, and states are mapped back to labels.
     *
     * @param testRecord  record to label; its tokens are overwritten with their preprocessed form
     * @param groupedToks not used by this method
     * @param collect     output array indexed by label; on return each entry holds the
     *                    space-joined tokens given that label, or {@code null} if none
     * @return the record's label array (the same array held by {@code testRecord.path})
     */
    public int[] segment(TestRecord testRecord, int[] groupedToks, String collect[]) {
        for (int i = 0; i < testRecord.length(); i++) {
            testRecord.seq[i] = AlphaNumericPreprocessor.preprocess(testRecord.seq[i]);
        }
        crfModel.apply(testRecord);
        featureGen.mapStatesToLabels(testRecord);

        int path[] = testRecord.path;
        for (int i = 0; i < nlabels; i++) {
            collect[i] = null;
        }
        for (int i = 0; i < testRecord.length(); i++) {
            int snew = path[i];
            if (snew >= 0) { // negative labels are skipped
                if (collect[snew] == null) {
                    collect[snew] = testRecord.seq[i];
                } else {
                    collect[snew] = collect[snew] + " " + testRecord.seq[i];
                }
            }
        }
        return path;
    }

    /**
     * Trains the feature generator and the CRF on the training file and writes
     * both to {@code <basedir>/learntModels/<outdir>}. When the {@code showModel}
     * option is set, the learnt model is displayed afterwards.
     *
     * @throws Exception if reading the data, training, or writing the model fails
     */
    public void train() throws Exception {
        String trainFile = baseDir + "/data/" + inName + "/" + inName + ".train";
        String modelDir = baseDir + "/learntModels/" + outDir;

        DataCruncher.createRaw(trainFile, tagDelimit);
        File dir = new File(modelDir);
        dir.mkdirs();
        TrainData trainData = DataCruncher.readTagged(nlabels, trainFile, trainFile,
                delimit, tagDelimit, impDelimit, labelMap);
        AlphaNumericPreprocessor.preprocess(trainData, nlabels);

        allocModel();
        featureGen.train(trainData);
        double featureWts[] = crfModel.train(trainData);
        if (options.getInt("debugLvl") > 1) {
            Util.printDbg("Training done");
        }
        crfModel.write(modelDir + "/crf");
        featureGen.write(modelDir + "/features");
        if (options.getInt("debugLvl") > 1) {
            Util.printDbg("Writing model to " + modelDir + "/crf");
        }
        if (options.getProperty("showModel") != null) {
            featureGen.displayModel(featureWts);
        }
    }

    /**
     * Loads a previously written model from {@code <basedir>/learntModels/<outdir>}
     * and runs {@link #doTest()}.
     *
     * @throws Exception if the model cannot be read or testing fails
     */
    public void test() throws Exception {
        allocModel();
        featureGen.read(baseDir + "/learntModels/" + outDir + "/features");
        crfModel.read(baseDir + "/learntModels/" + outDir + "/crf");
        doTest();
    }

    /**
     * Labels every record of the test file with the model currently held in
     * {@link #crfModel} and writes the result to
     * {@code <basedir>/out/<outdir>/<inname>.test}.
     *
     * @throws Exception if the test data cannot be read or the output cannot be written
     */
    public void doTest() throws Exception {
        String testFile = baseDir + "/data/" + inName + "/" + inName + ".test";
        String outputDir = baseDir + "/out/" + outDir;

        File dir = new File(outputDir);
        dir.mkdirs();
        TestData testData = new TestData(testFile, delimit, impDelimit, groupDelimit);
        TestDataWrite tdw = new TestDataWrite(outputDir + "/" + inName + ".test", testFile,
                delimit, tagDelimit, impDelimit, labelMap);
        try {
            String collect[] = new String[nlabels];
            TestRecord testRecord = new TestRecord(collect);
            for (String seq[] = testData.nextRecord(); seq != null; seq = testData.nextRecord()) {
                testRecord.init(seq);
                if (options.getInt("debugLvl") > 1) {
                    // Print the tokens themselves rather than the array's identity hash.
                    Util.printDbg("Invoking segment on " + arrayToString(seq));
                }
                int path[] = segment(testRecord, testData.groupedTokens(), collect);
                // path may be longer than the record, so pass the valid length explicitly.
                tdw.writeRecord(path, testRecord.length());
            }
        } finally {
            tdw.close(); // release the writer even if segmentation fails part-way
        }
    }

    TrainData taggedData = null;

    /** Returns the label of every position of {@code tr} as a new array. */
    int[] allLabels(TrainRecord tr) {
        int[] labs = new int[tr.length()];
        for (int i = 0; i < labs.length; i++) {
            labs[i] = tr.y(i);
        }
        return labs;
    }

    /** Concatenates the elements of {@code ar}, each followed by a single space. */
    String arrayToString(Object[] ar) {
        StringBuffer st = new StringBuffer();
        for (int i = 0; i < ar.length; i++) {
            st.append(ar[i]).append(" ");
        }
        return st.toString();
    }

    /**
     * Returns {@code numerator / denominator} as a percentage truncated to three
     * decimal places, or 0 when {@code denominator} is 0. The scaling is done in
     * {@code long} so that large counts do not overflow.
     */
    private static double percent(int numerator, int denominator) {
        if (denominator == 0) {
            return 0;
        }
        return ((double) (numerator * 100000L / denominator)) / 1000;
    }

    /** Returns the harmonic mean of precision and recall, or 0 when both are 0. */
    private static double f1(double prec, double recall) {
        double sum = prec + recall;
        return (sum == 0) ? 0 : 2 * prec * recall / sum;
    }

    /**
     * Compares the automatically tagged output with the manually tagged test
     * file token by token and prints, on standard output, an optional confusion
     * matrix (when the "confusion" option was given) and a per-label table of
     * true positives, marked positives, actual positives, precision, recall and
     * F1, followed by an overall row. Per-record details go to standard error
     * when {@code debugLvl} is greater than 0.
     *
     * <p>Records for which the automatic tagging contains a negative label are
     * excluded from the statistics.
     *
     * @throws Exception if either tagged file or the raw file cannot be read
     */
    public void calc() throws Exception {
        String manFile = baseDir + "/data/" + inName + "/" + inName + ".test";
        String autoFile = baseDir + "/out/" + outDir + "/" + inName + ".test";

        Vector s = new Vector();
        TrainData tdMan = DataCruncher.readTagged(nlabels, manFile, manFile,
                delimit, tagDelimit, impDelimit, labelMap);
        TrainData tdAuto = DataCruncher.readTagged(nlabels, autoFile, manFile,
                delimit, tagDelimit, impDelimit, labelMap);
        DataCruncher.readRaw(s, manFile, "", "");

        // Index nlabels of each count array accumulates the overall totals.
        int truePos[] = new int[nlabels + 1];
        int totalMarkedPos[] = new int[nlabels + 1];
        int totalPos[] = new int[nlabels + 1];
        int confuseMatrix[][] = new int[nlabels][nlabels];

        boolean printDetails = (options.getInt("debugLvl") > 0);

        int len = tdAuto.size();
        if ((len != tdMan.size()) || (s.size() < len)) { // Sanity Check
            System.out.println("Length Mismatch - Raw: " + s.size() + " Auto: " + tdAuto.size()
                    + " Man: " + tdMan.size());
            // Only the records present in all three inputs can be compared.
            len = Math.min(len, Math.min(tdMan.size(), s.size()));
        }

        for (int i = 0; i < len; i++) {
            String raw[] = (String[]) (s.get(i));
            TrainRecord trMan = tdMan.nextRecord();
            TrainRecord trAuto = tdAuto.nextRecord();
            int tokenMan[] = allLabels(trMan);
            int tokenAuto[] = allLabels(trAuto);

            if (tokenMan.length != tokenAuto.length) { // Sanity Check
                System.out.println("Length Mismatch - Manual: " + tokenMan.length
                        + " Auto: " + tokenAuto.length);
            }
            // Compare only positions that exist in both taggings.
            int tlen = Math.min(tokenMan.length, tokenAuto.length);

            // remove invalid tagging.
            boolean invalidMatch = false;
            for (int j = 0; j < tlen; j++) {
                if (printDetails) {
                    System.err.println(tokenMan[j] + " " + tokenAuto[j]);
                }
                if (tokenAuto[j] < 0) {
                    invalidMatch = true;
                    break;
                }
            }
            if (invalidMatch) {
                if (printDetails) {
                    System.err.println("No valid path");
                }
                continue;
            }

            int correctTokens = 0;
            for (int j = 0; j < tlen; j++) {
                totalMarkedPos[tokenAuto[j]]++;
                totalMarkedPos[nlabels]++;
                totalPos[tokenMan[j]]++;
                totalPos[nlabels]++;
                confuseMatrix[tokenMan[j]][tokenAuto[j]]++;
                if (tokenAuto[j] == tokenMan[j]) {
                    correctTokens++;
                    truePos[tokenMan[j]]++;
                    truePos[nlabels]++;
                }
            }

            if (printDetails) {
                System.err.println("Stats: " + correctTokens + " " + (tlen));
                for (int j = 0; j < raw.length; j++) {
                    System.err.print(raw[j] + " ");
                }
                System.err.println();
            }

            // Per label: manual text vs. automatic text, prefixed with "W" when they differ.
            for (int j = 0; j < nlabels; j++) {
                String mstr = "";
                for (int k = 0; k < trMan.numSegments(j); k++) {
                    mstr += arrayToString(trMan.tokens(j, k));
                }
                String astr = "";
                for (int k = 0; k < trAuto.numSegments(j); k++) {
                    astr += arrayToString(trAuto.tokens(j, k));
                }
                if (printDetails) {
                    if (!mstr.equalsIgnoreCase(astr)) {
                        System.err.print("W");
                    }
                    System.err.println(j + ": " + mstr + " : " + astr);
                }
            }
            if (printDetails) {
                System.err.println();
            }
        }

        if (confuseSet != null) {
            System.out.println("Confusion Matrix:");
            System.out.print("M\\A");
            for (int i = 0; i < nlabels; i++) {
                if (confuseSet[i]) {
                    System.out.print("\t" + (i));
                }
            }
            System.out.println();
            for (int i = 0; i < nlabels; i++) {
                if (confuseSet[i]) {
                    System.out.print(i);
                    for (int j = 0; j < nlabels; j++) {
                        if (confuseSet[j]) {
                            System.out.print("\t" + confuseMatrix[i][j]);
                        }
                    }
                    System.out.println();
                }
            }
        }

        System.out.println("\n\nCalculations:");
        System.out.println();
        System.out.println("Label\tTrue+\tMarked+\tActual+\tPrec.\tRecall\tF1");
        double prec, recall;
        for (int i = 0; i < nlabels; i++) {
            prec = percent(truePos[i], totalMarkedPos[i]);
            recall = percent(truePos[i], totalPos[i]);
            System.out.println((i) + ":\t" + truePos[i] + "\t" + totalMarkedPos[i] + "\t" + totalPos[i]
                    + "\t" + prec + "\t" + recall + "\t" + f1(prec, recall));
        }
        System.out.println("---------------------------------------------------------");
        prec = percent(truePos[nlabels], totalMarkedPos[nlabels]);
        recall = percent(truePos[nlabels], totalPos[nlabels]);
        System.out.println("Ov:\t" + truePos[nlabels] + "\t" + totalMarkedPos[nlabels] + "\t" + totalPos[nlabels]
                + "\t" + prec + "\t" + recall + "\t" + f1(prec, recall));
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -