📄 collinstrainer.java
字号:
package iitb.CRF;import java.util.*;/** * Implements the CollinsVotedPerceptron training algorithm * * @author Sunita Sarawagi * */ class CollinsTrainer extends Trainer { int beamsize = 3; double beta = 0.05; boolean useUpdated = false; boolean voted = true; Soln solnPool[]; // for efficiency public CollinsTrainer(CrfParams p) { super(p); if (params.miscOptions.getProperty("beamSize") != null) beamsize = Integer.parseInt(params.miscOptions.getProperty("beamSize")); if (params.miscOptions.getProperty("beta") != null) beta = Double.parseDouble(params.miscOptions.getProperty("beta")); if (params.miscOptions.getProperty("UpdatedViterbi") != null) useUpdated = params.miscOptions.getProperty("UpdatedViterbi").equalsIgnoreCase("true"); if (params.miscOptions.getProperty("voted") != null) voted = params.miscOptions.getProperty("voted").equalsIgnoreCase("true"); } public void train(CRF model, DataIter data, double[] l, Evaluator eval) { init(model,data,l); double grad[] = gradLogli; // reusing parent's structures. Viterbi viterbiSearcher = model.getViterbi(beamsize); for (int i = 0; i < lambda.length; i++) lambda[i] = grad[i] = 0; Vector viterbiS = new Vector(); for (int t = 0; t < params.maxIters; t++) { int numErrs = 0; diter.startScan(); for (int numRecord = 0; diter.hasNext(); numRecord++) { DataSequence dataSeq = (DataSequence)diter.next(); viterbiSearcher.viterbiSearch(dataSeq,(useUpdated)?lambda:grad, false); Soln corrSoln = getCorrectSoln(dataSeq,(useUpdated)?lambda:grad); double corrScore = corrSoln.score; int maxNum = viterbiSearcher.numSolutions(); viterbiS.clear(); for (int k = 0; k < maxNum; k++) { Soln viterbi = viterbiSearcher.getBestSoln(k); if (viterbi.score < corrScore*(1-beta)) break; if ( !isCorrect(viterbi,corrSoln)) { viterbiS.add(viterbi); //System.out.println("adding " + viterbiS.size() + " " + viterbi.score + " " + corrScore + " grad " + norm(grad)); } } if (viterbiS.size() > 0) { for (; corrSoln != null; corrSoln = corrSoln.prevSoln) { boolean differenceAtI = false; for (int s = 0; s < viterbiS.size(); s++) { Soln viterbi = (Soln) viterbiS.elementAt(s); if ((viterbi == null) || !corrSoln.equals(viterbi)) { differenceAtI = true; break; } } if (differenceAtI) { numErrs++; updateWeights(corrSoln, 1.0, grad, dataSeq); for (int s = 0; s < viterbiS.size(); s++) { Soln viterbi = (Soln) viterbiS.elementAt(s); // if (within current frontier, i.e. endpoint overlaps with current segment for (;(viterbi != null) && (viterbi.pos > corrSoln.prevPos()); viterbi = viterbi.prevSoln) { updateWeights(viterbi, -1.0/viterbiS.size(), grad, dataSeq); } } /*System.out.println("gnorm at " + corrSoln.pos + " " + norm(grad)); for (int s = 0; s < viterbiS.size(); s++) { Soln viterbi = (Soln) viterbiS.elementAt(s); System.out.println(s + " viterbi " + viterbi.pos + " " + viterbi.label); }*/ } // advance all viterbi solutions.. for (int s = 0; s < viterbiS.size(); s++) { Soln viterbi = (Soln) viterbiS.elementAt(s); // if (within current frontier, i.e. endpoint overlaps with current segment for (;(viterbi != null) && (viterbi.pos > corrSoln.prevPos()); viterbi = viterbi.prevSoln); viterbiS.set(s,viterbi); } } } // voted perceptron, so add. for (int f = 0; f < lambda.length; f++) lambda[f] += grad[f]; } // all records. if (params.debugLvl > 0) Util.printDbg("Iteration " + t + " numErrs "+ numErrs); if (numErrs == 0) break; } } boolean isCorrect(Soln viterbi, Soln corr) { for (; (viterbi != null) && (corr != null); corr = corr.prevSoln, viterbi = viterbi.prevSoln) { if (!viterbi.equals(corr)) return false; } return ((viterbi == null) && (corr == null)); } int getSegmentEnd(DataSequence dataSeq, int ss) { return ss; } void startFeatureGenerator(FeatureGenerator _featureGenerator, DataSequence dataSeq, Soln soln) { _featureGenerator.startScanFeaturesAt(dataSeq, soln.pos); } void updateWeights(Soln soln, double wt, double grad[], DataSequence dataSeq) { startFeatureGenerator(featureGenerator,dataSeq,soln); while (featureGenerator.hasNext()) { Feature feature = featureGenerator.next(); int f = feature.index(); int yp = feature.y(); int yprev = feature.yprev(); float val = feature.value(); if ((soln.label == yp) && (((soln.prevPos() >= 0) && (yprev == soln.prevSoln.label)) || (yprev < 0))) { grad[f] += wt*val; /* if (soln.prevPos() < 0) System.out.println("Updating " + soln.label + " "); else System.out.println("Updating " + soln.label + " " + yprev + " " + soln.prevSoln.label); */ } } } Soln getCorrectSoln(DataSequence dataSeq, double grad[]) { int se = 0; Soln prevSoln = null; if ((solnPool == null) || solnPool.length < dataSeq.length()) { solnPool = new Soln[dataSeq.length()]; for (int i = 0; i < dataSeq.length(); solnPool[i++] = new Soln(0,0)); } for (int ss = 0; ss < dataSeq.length(); ss = se+1) { se = getSegmentEnd(dataSeq, ss); Soln soln = solnPool[ss]; soln.pos = se; soln.label = dataSeq.y(ss); soln.prevSoln = prevSoln; soln.score = (prevSoln == null)?0:prevSoln.score; startFeatureGenerator(featureGenerator,dataSeq,soln); while (featureGenerator.hasNext()) { Feature feature = featureGenerator.next(); int f = feature.index(); int yp = feature.y(); int yprev = feature.yprev(); float val = feature.value(); if ((soln.label == yp) && (((soln.prevPos() >= 0) && (yprev == soln.prevSoln.label)) || (yprev < 0))) { soln.score += grad[f]*val; } } prevSoln = soln; } return prevSoln; }};
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -