📄 trainer.java
package iitb.CRF;

import riso.numerical.*;
import cern.colt.function.*;
import cern.colt.matrix.*;
import cern.colt.matrix.impl.*;
import iitb.CRF.HistoryManager.*;

/**
 * @author Sunita Sarawagi
 */
public class Trainer {
    protected int numF, numY;
    double gradLogli[];
    double diag[];
    double lambda[];
    protected boolean reuseM, initMDone = false;
    protected double ExpF[];
    double scale[], rLogScale[];
    protected DoubleMatrix2D Mi_YY;
    protected DoubleMatrix1D Ri_Y;
    protected DoubleMatrix1D alpha_Y, newAlpha_Y;
    protected DoubleMatrix1D beta_Y[];
    protected DoubleMatrix1D tmp_Y;

    static class MultFunc implements DoubleDoubleFunction {
        public double apply(double a, double b) { return a * b; }
    }
    static class SumFunc implements DoubleDoubleFunction {
        public double apply(double a, double b) { return a + b; }
    }
    static MultFunc multFunc = new MultFunc();
    protected static SumFunc sumFunc = new SumFunc();

    class MultSingle implements DoubleFunction {
        public double multiplicator = 1.0;
        public double apply(double a) { return a * multiplicator; }
    }
    MultSingle constMultiplier = new MultSingle();

    protected DataIter diter;
    FeatureGenerator featureGenerator;
    protected CrfParams params;
    EdgeGenerator edgeGen;
    protected int icall;
    Evaluator evaluator = null;
    FeatureGenCache featureGenCache;

    protected double norm(double ar[]) {
        double v = 0;
        for (int f = 0; f < ar.length; f++)
            v += ar[f] * ar[f];
        return Math.sqrt(v);
    }

    public Trainer(CrfParams p) {
        params = p;
    }

    public void train(CRF model, DataIter data, double[] l, Evaluator eval) {
        init(model, data, l);
        evaluator = eval;
        if (params.debugLvl > 0) {
            Util.printDbg("Number of features :" + lambda.length);
        }
        doTrain();
    }

    double getInitValue() {
        // returns a negative value to avoid overflow in the initial stages.
        // if (params.initValue == 0)
        //     return -1*Math.log(numY);
        return params.initValue;
    }

    protected void init(CRF model, DataIter data, double[] l) {
        edgeGen = model.edgeGen;
        lambda = l;
        numY = model.numY;
        diter = data;
        featureGenerator = model.featureGenerator;
        numF = featureGenerator.numFeatures();

        gradLogli = new double[numF];
        diag = new double[numF]; // needed by the optimizer
        ExpF = new double[lambda.length];
        initMatrices();
        reuseM = params.reuseM;

        if (params.miscOptions.getProperty("cache", "false").equals("true")) {
            featureGenCache = new FeatureGenCache(featureGenerator);
            featureGenerator = featureGenCache;
        } else
            featureGenCache = null;
    }

    void initMatrices() {
        Mi_YY = new DenseDoubleMatrix2D(numY, numY);
        Ri_Y = new DenseDoubleMatrix1D(numY);
        alpha_Y = new DenseDoubleMatrix1D(numY);
        newAlpha_Y = new DenseDoubleMatrix1D(numY);
        tmp_Y = new DenseDoubleMatrix1D(numY);
    }
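    /*
     * Usage sketch (illustrative, not part of the original file). Only
     * Trainer(CrfParams) and train(CRF, DataIter, double[], Evaluator) are
     * defined here; how the model, params, and the data iterator get built
     * is assumed to happen elsewhere in iitb.CRF:
     *
     *   double[] lambda = new double[model.featureGenerator.numFeatures()];
     *   Trainer trainer = new Trainer(params);
     *   trainer.train(model, trainData, lambda, null); // null: no evaluator
     *   // on return, lambda holds the trained feature weights
     */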
    void doTrain() {
        double f, xtol = 1.0e-16; // machine precision
        int iprint[] = new int[2], iflag[] = new int[1];
        icall = 0;
        iprint[0] = params.debugLvl - 2;
        iprint[1] = params.debugLvl - 1;
        iflag[0] = 0;

        for (int j = 0; j < lambda.length; j++) {
            // lambda[j] = 1.0/lambda.length;
            lambda[j] = getInitValue();
        }
        do {
            f = computeFunctionGradient(lambda, gradLogli);
            f = -1 * f; // since the routine below minimizes and we want to maximize logli
            for (int j = 0; j < lambda.length; j++) {
                gradLogli[j] *= -1;
            }
            if ((evaluator != null) && (evaluator.evaluate() == false))
                break;
            try {
                LBFGS.lbfgs(numF, params.mForHessian, lambda, f, gradLogli, false,
                        diag, iprint, params.epsForConvergence, xtol, iflag);
            } catch (LBFGS.ExceptionWithIflag e) {
                System.err.println("CRF: lbfgs failed.\n" + e);
                if (e.iflag == -1) {
                    System.err.println("Possible reasons could be: \n \t 1. Bug in the feature generation or data handling code\n\t 2. Not enough features to make observed feature value==expected value\n");
                }
                return;
            }
            icall += 1;
        } while ((iflag[0] != 0) && (icall <= params.maxIters));
    }

    protected double computeFunctionGradient(double lambda[], double grad[]) {
        initMDone = false;
        if (params.trainerType.equals("ll"))
            return computeFunctionGradientLL(lambda, grad);
        double logli = 0;
        try {
            for (int f = 0; f < lambda.length; f++) {
                grad[f] = -1 * lambda[f] * params.invSigmaSquare;
                logli -= ((lambda[f] * lambda[f]) * params.invSigmaSquare) / 2;
            }
            boolean doScaling = params.doScaling;
            diter.startScan();
            if (featureGenCache != null) featureGenCache.startDataScan();
            int numRecord = 0;
            for (numRecord = 0; diter.hasNext(); numRecord++) {
                DataSequence dataSeq = (DataSequence) diter.next();
                if (featureGenCache != null) featureGenCache.nextDataIndex();
                if (params.debugLvl > 1) {
                    Util.printDbg("Read next seq: " + numRecord + " logli " + logli);
                }
                alpha_Y.assign(1);
                for (int f = 0; f < lambda.length; f++)
                    ExpF[f] = 0;
                if ((beta_Y == null) || (beta_Y.length < dataSeq.length())) {
                    beta_Y = new DenseDoubleMatrix1D[2 * dataSeq.length()];
                    for (int i = 0; i < beta_Y.length; i++)
                        beta_Y[i] = new DenseDoubleMatrix1D(numY);
                    scale = new double[2 * dataSeq.length()];
                }
                // compute beta values in a backward scan.
                // also scale beta-values to 1 to avoid numerical problems.
                scale[dataSeq.length() - 1] = (doScaling) ? numY : 1;
                beta_Y[dataSeq.length() - 1].assign(1.0 / scale[dataSeq.length() - 1]);
                for (int i = dataSeq.length() - 1; i > 0; i--) {
                    if (params.debugLvl > 2) {
                        Util.printDbg("Features fired");
                        // featureGenerator.startScanFeaturesAt(dataSeq, i);
                        // while (featureGenerator.hasNext()) {
                        //     Feature feature = featureGenerator.next();
                        //     Util.printDbg(feature.toString());
                        // }
                    }
                    // compute the Mi matrix
                    initMDone = computeLogMi(featureGenerator, lambda, dataSeq, i,
                            Mi_YY, Ri_Y, true, reuseM, initMDone);
                    tmp_Y.assign(beta_Y[i]);
                    tmp_Y.assign(Ri_Y, multFunc);
                    RobustMath.Mult(Mi_YY, tmp_Y, beta_Y[i - 1], 1, 0, false, edgeGen);
                    // Mi_YY.zMult(tmp_Y, beta_Y[i-1]);

                    // need to scale the beta-s to avoid overflow
                    scale[i - 1] = doScaling ? beta_Y[i - 1].zSum() : 1;
                    if ((scale[i - 1] < 1) && (scale[i - 1] > -1))
                        scale[i - 1] = 1;
                    constMultiplier.multiplicator = 1.0 / scale[i - 1];
                    beta_Y[i - 1].assign(constMultiplier);
                }
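                // A sketch of the recursions at work here (notation added for
                // this listing, not from the original source). With Mi and Ri
                // holding the exponentiated transition and state potentials:
                //   backward (above):  beta[i-1] = Mi * (Ri .* beta[i])
                //   forward  (below):  alpha[i]  = (Mi^T * alpha[i-1]) .* Ri
                // Each vector is divided by scale[i] so that long sequences do
                // not overflow double precision; the log-likelihood is
                // corrected later by subtracting the sum of log(scale[i]).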
                double thisSeqLogli = 0;
                for (int i = 0; i < dataSeq.length(); i++) {
                    // compute the Mi matrix
                    initMDone = computeLogMi(featureGenerator, lambda, dataSeq, i,
                            Mi_YY, Ri_Y, true, reuseM, initMDone);
                    // find features that fire at this position..
                    featureGenerator.startScanFeaturesAt(dataSeq, i);
                    if (i > 0) {
                        tmp_Y.assign(alpha_Y);
                        RobustMath.Mult(Mi_YY, tmp_Y, newAlpha_Y, 1, 0, true, edgeGen);
                        // Mi_YY.zMult(tmp_Y, newAlpha_Y, 1, 0, true);
                        newAlpha_Y.assign(Ri_Y, multFunc);
                    } else {
                        newAlpha_Y.assign(Ri_Y);
                    }
                    while (featureGenerator.hasNext()) {
                        Feature feature = featureGenerator.next();
                        int f = feature.index();
                        int yp = feature.y();
                        int yprev = feature.yprev();
                        float val = feature.value();
                        if ((dataSeq.y(i) == yp)
                                && (((i - 1 >= 0) && (yprev == dataSeq.y(i - 1))) || (yprev < 0))) {
                            grad[f] += val;
                            thisSeqLogli += val * lambda[f];
                        }
                        if (yprev < 0) {
                            ExpF[f] += newAlpha_Y.get(yp) * val * beta_Y[i].get(yp);
                        } else {
                            ExpF[f] += alpha_Y.get(yprev) * Ri_Y.get(yp)
                                    * Mi_YY.get(yprev, yp) * val * beta_Y[i].get(yp);
                        }
                    }
                    alpha_Y.assign(newAlpha_Y);
                    // now scale the alpha-s to avoid overflow problems.
                    constMultiplier.multiplicator = 1.0 / scale[i];
                    alpha_Y.assign(constMultiplier);

                    if (params.debugLvl > 2) {
                        System.out.println("Alpha-i " + alpha_Y.toString());
                        System.out.println("Ri " + Ri_Y.toString());
                        System.out.println("Mi " + Mi_YY.toString());
                        System.out.println("Beta-i " + beta_Y[i].toString());
                    }
                }
                double Zx = alpha_Y.zSum();
                thisSeqLogli -= log(Zx);
                // correct for the fact that alpha-s were scaled.
                for (int i = 0; i < dataSeq.length(); i++) {
                    thisSeqLogli -= log(scale[i]);
                }
                logli += thisSeqLogli;
                // update grad.
                for (int f = 0; f < grad.length; f++)
                    grad[f] -= ExpF[f] / Zx;

                if (params.debugLvl > 1) {
                    System.out.println("Sequence " + thisSeqLogli + " logli " + logli
                            + " log(Zx) " + Math.log(Zx) + " Zx " + Zx);
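                // Gradient identity implemented above (standard CRF
                // maximum-likelihood training; notation added for this
                // listing): for each feature f,
                //   dL/dlambda[f] = observed count of f
                //                   - ExpF[f]/Zx                 (model expectation)
                //                   - lambda[f]*invSigmaSquare   (Gaussian prior)
                // grad[f] accumulates the observed counts during the forward
                // scan and subtracts ExpF[f]/Zx once per training sequence.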