⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 collinstrainer.java

📁 CRF1.2
💻 JAVA
字号:
package iitb.CRF;import java.util.*;/** * Implements the CollinsVotedPerceptron training algorithm * * @author Sunita Sarawagi * */ class CollinsTrainer extends Trainer {    int beamsize = 3;    double beta = 0.05;    boolean useUpdated = false;    boolean voted = true;    Soln solnPool[]; // for efficiency    public CollinsTrainer(CrfParams p) {	super(p);	if (params.miscOptions.getProperty("beamSize") != null)	    beamsize = Integer.parseInt(params.miscOptions.getProperty("beamSize"));	if (params.miscOptions.getProperty("beta") != null)	    beta = Double.parseDouble(params.miscOptions.getProperty("beta"));	if (params.miscOptions.getProperty("UpdatedViterbi") != null) 	    useUpdated = params.miscOptions.getProperty("UpdatedViterbi").equalsIgnoreCase("true");	if (params.miscOptions.getProperty("voted") != null) 	    voted = params.miscOptions.getProperty("voted").equalsIgnoreCase("true");    }    public void train(CRF model, DataIter data, double[] l, Evaluator eval) {	init(model,data,l);	double grad[] = gradLogli; // reusing parent's structures.	Viterbi viterbiSearcher = model.getViterbi(beamsize); 	for (int i = 0; i < lambda.length; i++)	    lambda[i] = grad[i] = 0;	Vector viterbiS = new Vector();	for (int t = 0; t < params.maxIters; t++) {	    int numErrs = 0;	    diter.startScan();	    for (int numRecord = 0; diter.hasNext(); numRecord++) {		DataSequence dataSeq = (DataSequence)diter.next();		viterbiSearcher.viterbiSearch(dataSeq,(useUpdated)?lambda:grad, false);		Soln corrSoln = getCorrectSoln(dataSeq,(useUpdated)?lambda:grad);		double corrScore = corrSoln.score;		int maxNum = viterbiSearcher.numSolutions();		viterbiS.clear();		for (int k = 0; k < maxNum; k++) {		    Soln viterbi  = viterbiSearcher.getBestSoln(k);		    if (viterbi.score < corrScore*(1-beta))			break;		    if ( !isCorrect(viterbi,corrSoln)) {			viterbiS.add(viterbi);			//System.out.println("adding " + viterbiS.size() + " " + viterbi.score + " " + corrScore + " grad " + norm(grad));		    }		}		if (viterbiS.size() > 0) {		    for (; corrSoln != null; corrSoln = corrSoln.prevSoln) {		    boolean differenceAtI = false;		    for (int s = 0; s < viterbiS.size(); s++) {			Soln viterbi = (Soln) viterbiS.elementAt(s);			if ((viterbi == null) || !corrSoln.equals(viterbi)) {			    differenceAtI = true;			    break;			}		    }		    if (differenceAtI) {			numErrs++;			updateWeights(corrSoln, 1.0, grad, dataSeq);			for (int s = 0; s < viterbiS.size(); s++) {			    Soln viterbi = (Soln) viterbiS.elementAt(s);			    // if (within current frontier, i.e. endpoint overlaps with current segment			    for (;(viterbi != null) && (viterbi.pos > corrSoln.prevPos()); viterbi = viterbi.prevSoln) { 				updateWeights(viterbi, -1.0/viterbiS.size(), grad, dataSeq);			    }			}			/*System.out.println("gnorm at " + corrSoln.pos + " " + norm(grad));			for (int s = 0; s < viterbiS.size(); s++) {			    Soln viterbi = (Soln) viterbiS.elementAt(s);			    System.out.println(s + " viterbi " + viterbi.pos + " " + viterbi.label);			    }*/		    }		    // advance all viterbi solutions..		    for (int s = 0; s < viterbiS.size(); s++) {			Soln viterbi = (Soln) viterbiS.elementAt(s);			// if (within current frontier, i.e. endpoint overlaps with current segment			for (;(viterbi != null) && (viterbi.pos > corrSoln.prevPos()); viterbi = viterbi.prevSoln);			viterbiS.set(s,viterbi);		    }		    }		}		// voted perceptron, so add.		for (int f = 0; f < lambda.length; f++)		    lambda[f] += grad[f];			    }  // all records.	    if (params.debugLvl > 0)		Util.printDbg("Iteration " + t + " numErrs "+ numErrs);	    if (numErrs == 0)		break;	}    }    boolean isCorrect(Soln viterbi, Soln corr) { 	for (; (viterbi != null) && (corr != null); corr = corr.prevSoln, viterbi = viterbi.prevSoln) { 	    if (!viterbi.equals(corr))		return false;	}	return ((viterbi == null) && (corr == null));    }      int getSegmentEnd(DataSequence dataSeq, int ss) {	return ss;    }    void startFeatureGenerator(FeatureGenerator _featureGenerator, DataSequence dataSeq, Soln soln) {	_featureGenerator.startScanFeaturesAt(dataSeq, soln.pos);    }    void updateWeights(Soln soln, double wt, double grad[], DataSequence dataSeq) {	startFeatureGenerator(featureGenerator,dataSeq,soln);	while (featureGenerator.hasNext()) { 	    Feature feature = featureGenerator.next();	    int f = feature.index();	    int yp = feature.y();	    int yprev = feature.yprev();	    float val = feature.value();			    	    if ((soln.label == yp) && (((soln.prevPos() >= 0) && (yprev == soln.prevSoln.label)) || (yprev < 0))) {		grad[f] += wt*val;		/*		if (soln.prevPos() < 0)		    System.out.println("Updating " + soln.label +  " ");		else		    System.out.println("Updating " + soln.label +  " " + yprev + " " + soln.prevSoln.label);		*/	    }	}    }    Soln getCorrectSoln(DataSequence dataSeq, double grad[]) {	int se = 0;	Soln prevSoln = null;	if ((solnPool == null) || solnPool.length < dataSeq.length()) {	    solnPool = new Soln[dataSeq.length()];	    for (int i = 0; i < dataSeq.length(); solnPool[i++] = new Soln(0,0));	}	for (int ss = 0; ss < dataSeq.length(); ss = se+1) {	    se = getSegmentEnd(dataSeq, ss); 	    Soln soln = solnPool[ss];	    soln.pos = se;	    soln.label = dataSeq.y(ss);	    soln.prevSoln = prevSoln;	    soln.score = (prevSoln == null)?0:prevSoln.score;	    startFeatureGenerator(featureGenerator,dataSeq,soln);	    while (featureGenerator.hasNext()) { 		Feature feature = featureGenerator.next();		int f = feature.index();		int yp = feature.y();		int yprev = feature.yprev();		float val = feature.value();		if ((soln.label == yp) && (((soln.prevPos() >= 0) && (yprev == soln.prevSoln.label)) || (yprev < 0))) {		    soln.score += grad[f]*val;		}	    }	    prevSoln = soln;	}	return prevSoln;    }};

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -