⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 nestedtrainer.java

📁 这是一个CRF(条件随机场)算法的实现,希望能对从事算法研究的人有所帮助.
💻 JAVA
字号:
package iitb.CRF;import cern.colt.matrix.impl.*;/** * * @author Sunita Sarawagi * */ class NestedTrainer extends Trainer {    public NestedTrainer(CrfParams p) {	super(p);    }    DenseDoubleMatrix1D alpha_Y_Array[];    protected double computeFunctionGradient(double lambda[], double grad[]) {	if (params.doScaling)	    return computeFunctionGradientLL(lambda,  grad);	try {	FeatureGeneratorNested featureGenNested = (FeatureGeneratorNested)featureGenerator;	double logli = 0;	for (int f = 0; f < lambda.length; f++) {	    grad[f] = -1*lambda[f]*params.invSigmaSquare;	    logli -= ((lambda[f]*lambda[f])*params.invSigmaSquare)/2;	}	boolean doScaling = false; // scaling as implemented below does not work.	diter.startScan();	if (featureGenCache != null) featureGenCache.startDataScan();	for (int numRecord = 0; diter.hasNext(); numRecord++) {	    SegmentDataSequence dataSeq = (SegmentDataSequence)diter.next();	    if (featureGenCache != null) featureGenCache.nextDataIndex();	    if (params.debugLvl > 1)		Util.printDbg("Read next seq: " + numRecord + " logli " + logli);	    for (int f = 0; f < lambda.length; f++)		ExpF[f] = 0;	    int base = -1;	    if ((alpha_Y_Array == null) || (alpha_Y_Array.length < dataSeq.length()-base)) {		alpha_Y_Array = new DenseDoubleMatrix1D[2*dataSeq.length()];		for (int i = 0; i < alpha_Y_Array.length; i++)		    alpha_Y_Array[i] = new DenseDoubleMatrix1D(numY);	    }	    if ((beta_Y == null) || (beta_Y.length < dataSeq.length())) {		beta_Y = new DenseDoubleMatrix1D[2*dataSeq.length()];		for (int i = 0; i < beta_Y.length; i++)		    beta_Y[i] = new DenseDoubleMatrix1D(numY);		scale = new double[2*dataSeq.length()];	    }	    // compute beta values in a backward scan.	    
// also scale beta-values as much as possible to avoid numerical overflows	    beta_Y[dataSeq.length()-1].assign(1.0);	    scale[dataSeq.length()-1]=1;	    for (int i = dataSeq.length()-2; i >= 0; i--) {		// we need to do delayed scaling, so scale everything by the first element of the current window.		if (doScaling && (i + featureGenNested.maxMemory() < dataSeq.length())) {		    int iL = i + featureGenNested.maxMemory();		    scale[iL] = beta_Y[iL].zSum();		    constMultiplier.multiplicator = 1.0/scale[iL];		    for (int j = i+1; j <= iL; j++) {			beta_Y[j].assign(constMultiplier);		    }		}		beta_Y[i].assign(0);		scale[i] = 1;		for (int ell = 1; (ell <= featureGenNested.maxMemory()) && (i+ell < dataSeq.length()); ell++) {		    		    // compute the Mi matrix		    featureGenNested.startScanFeaturesAt(dataSeq, i, i+ell);		    //if (! featureGenNested.hasNext())		    //	break;		    initMDone = computeLogMi(featureGenNested,lambda,Mi_YY,Ri_Y,true,reuseM,initMDone);		    tmp_Y.assign(beta_Y[i+ell]);		    tmp_Y.assign(Ri_Y,multFunc);		    Mi_YY.zMult(tmp_Y, beta_Y[i],1,1,false);		}	    }	    double thisSeqLogli = 0;	    alpha_Y_Array[0].assign(1);	    int segmentStart = 0;	    int segmentEnd = -1;	    boolean invalid = false;	    for (int i = 0; i < dataSeq.length(); i++) {		if (segmentEnd < i) {		    segmentStart = i;		    segmentEnd = dataSeq.getSegmentEnd(i);		}		if (segmentEnd-segmentStart+1 > featureGenNested.maxMemory()) {		    if (icall <= 1) {			System.out.println("Ignoring record with segment length greater than maxMemory " + numRecord);		    }		    invalid = true;		    break;		}		alpha_Y_Array[i-base].assign(0);		// compute the scale adjustment.		
float scaleProduct = 1;		for (int j = i-featureGenNested.maxMemory()-base; j <= i-1; j++)		    if (j >= 0) scaleProduct *= scale[j];		for (int ell = 1; (ell <= featureGenNested.maxMemory()) && (i-ell >= base); ell++) {		    // compute the Mi matrix		    featureGenNested.startScanFeaturesAt(dataSeq, i-ell,i);		    //		    if (!featureGenNested.hasNext())		    //	break;		    initMDone = computeLogMi(featureGenNested,lambda,Mi_YY,Ri_Y,true,reuseM,initMDone);		    // find features that fire at this position..		    featureGenNested.startScanFeaturesAt(dataSeq, i-ell,i);		    boolean isSegment = ((i-ell+1==segmentStart) && (i == segmentEnd));		    while (featureGenNested.hasNext()) { 			Feature feature = featureGenNested.next();			int f = feature.index();		    			int yp = feature.y();			int yprev = feature.yprev();			float val = feature.value();			boolean allEllMatch = isSegment && (dataSeq.y(i) == yp);			if (allEllMatch && (((i-ell >= 0) && (yprev == dataSeq.y(i-ell))) || (yprev < 0))) {			    grad[f] += val;			    thisSeqLogli += val*lambda[f];			}			if (yprev < 0) {			    for (yprev = 0; yprev < Mi_YY.rows(); yprev++) 				ExpF[f] += (alpha_Y_Array[i-ell-base].get(yprev)*Ri_Y.get(yp)*Mi_YY.get(yprev,yp)*val*beta_Y[i].get(yp))/scaleProduct;			} else {			    ExpF[f] += (alpha_Y_Array[i-ell-base].get(yprev)*Ri_Y.get(yp)*Mi_YY.get(yprev,yp)*val*beta_Y[i].get(yp))/scaleProduct;			}		    }		    Mi_YY.zMult(alpha_Y_Array[i-ell-base],tmp_Y,1,0,true);		    tmp_Y.assign(Ri_Y,multFunc);		    alpha_Y_Array[i-base].assign(tmp_Y, sumFunc);		}				// now do the delayed scaling of  the alpha-s to avoid overflow problems.		
if (i-base-featureGenNested.maxMemory() >= 0) {		    int iL = i-base-featureGenNested.maxMemory();		    constMultiplier.multiplicator = 1.0/scale[iL];		    for (int j = iL; j <= i-base; j++) {			alpha_Y_Array[j].assign(constMultiplier);		    }		}		if (params.debugLvl > 2) {		    System.out.println("Alpha-i " + alpha_Y_Array[i-base].toString());		    System.out.println("Ri " + Ri_Y.toString());		    System.out.println("Mi " + Mi_YY.toString());		    System.out.println("Beta-i " + beta_Y[i].toString());		}		if (params.debugLvl > 1) {		    System.out.println(" pos "  + i + " " + thisSeqLogli);		}	    }	    if (invalid)		continue;	    double Zx = alpha_Y_Array[dataSeq.length()-1-base].zSum();	    thisSeqLogli -= log(Zx);	    // correct for the fact that alpha-s were scaled.	    for (int i = 0; i < dataSeq.length()-base-featureGenNested.maxMemory(); i++) {	      thisSeqLogli -= log(scale[i]);	    }	    logli += thisSeqLogli;	    // update grad.	    for (int f = 0; f < grad.length; f++)		grad[f] -= ExpF[f]/Zx;	    	    if (params.debugLvl > 1) {		System.out.println("Sequence "  + thisSeqLogli + " " + logli + " " + Zx);		System.out.println("Last Alpha-i " + alpha_Y_Array[dataSeq.length()-1-base].toString());	    }	}	if (params.debugLvl > 2) {	    for (int f = 0; f < lambda.length; f++)		System.out.print(lambda[f] + " ");	    System.out.println(" :x");	    for (int f = 0; f < lambda.length; f++)		System.out.print(grad[f] + " ");	    System.out.println(" :g");	}		if (params.debugLvl > 0)	    Util.printDbg("Iter " + icall + " loglikelihood "+logli + " gnorm " + norm(grad) + " xnorm "+ norm(lambda));	return logli;	} catch (Exception e) {	    e.printStackTrace();	    System.exit(0);	}	return 0;    }    protected double computeFunctionGradientLL(double lambda[], double grad[]) {	try {	FeatureGeneratorNested featureGenNested = (FeatureGeneratorNested)featureGenerator;	double logli = 0;	for (int f = 0; f < lambda.length; f++) {	    grad[f] = -1*lambda[f]*params.invSigmaSquare;	   
 logli -= ((lambda[f]*lambda[f])*params.invSigmaSquare)/2;	}	diter.startScan();	if (featureGenCache != null) featureGenCache.startDataScan();	for (int numRecord = 0; diter.hasNext(); numRecord++) {	    SegmentDataSequence dataSeq = (SegmentDataSequence)diter.next();	    if (featureGenCache != null) featureGenCache.nextDataIndex();	    if (params.debugLvl > 1)		Util.printDbg("Read next seq: " + numRecord + " logli " + logli);	    for (int f = 0; f < lambda.length; f++)		ExpF[f] = RobustMath.LOG0;	    int base = -1;	    if ((alpha_Y_Array == null) || (alpha_Y_Array.length < dataSeq.length()-base)) {		alpha_Y_Array = new DenseDoubleMatrix1D[2*dataSeq.length()];		for (int i = 0; i < alpha_Y_Array.length; i++)		    alpha_Y_Array[i] = new DenseDoubleMatrix1D(numY);	    }	    if ((beta_Y == null) || (beta_Y.length < dataSeq.length())) {		beta_Y = new DenseDoubleMatrix1D[2*dataSeq.length()];		for (int i = 0; i < beta_Y.length; i++)		    beta_Y[i] = new DenseDoubleMatrix1D(numY);	    }	    // compute beta values in a backward scan.	    // also scale beta-values as much as possible to avoid numerical overflows	    beta_Y[dataSeq.length()-1].assign(0);	    for (int i = dataSeq.length()-2; i >= 0; i--) {		beta_Y[i].assign(RobustMath.LOG0);		for (int ell = 1; (ell <= featureGenNested.maxMemory()) && (i+ell < dataSeq.length()); ell++) {		    		    // compute the Mi matrix		    featureGenNested.startScanFeaturesAt(dataSeq, i, i+ell);		    //if (! 
featureGenNested.hasNext())		    //	break;		    initMDone = computeLogMi(featureGenNested,lambda,Mi_YY,Ri_Y,false,reuseM,initMDone);		    tmp_Y.assign(beta_Y[i+ell]);		    tmp_Y.assign(Ri_Y,sumFunc);		    RobustMath.logMult(Mi_YY, tmp_Y, beta_Y[i],1,1,false,edgeGen);		}	    }	    double thisSeqLogli = 0;	    alpha_Y_Array[0].assign(0);	    int segmentStart = 0;	    int segmentEnd = -1;	    boolean invalid = false;	    for (int i = 0; i < dataSeq.length(); i++) {		if (segmentEnd < i) {		    segmentStart = i;		    segmentEnd = dataSeq.getSegmentEnd(i);		}		if (segmentEnd-segmentStart+1 > featureGenNested.maxMemory()) {		    if (icall == 0) {			System.out.println("Ignoring record with segment length greater than maxMemory " + numRecord);		    }		    invalid = true;		    break;		}		alpha_Y_Array[i-base].assign(RobustMath.LOG0);		for (int ell = 1; (ell <= featureGenNested.maxMemory()) && (i-ell >= base); ell++) {		    // compute the Mi matrix		    featureGenNested.startScanFeaturesAt(dataSeq, i-ell,i);		    // if (!featureGenNested.hasNext())		    //	break;		    initMDone = computeLogMi(featureGenNested,lambda,Mi_YY,Ri_Y,false,reuseM,initMDone);		    // find features that fire at this position..		    
featureGenNested.startScanFeaturesAt(dataSeq, i-ell,i);		    boolean isSegment = ((i-ell+1==segmentStart) && (i == segmentEnd));		    while (featureGenNested.hasNext()) { 			Feature feature = featureGenNested.next();			int f = feature.index();			int yp = feature.y();			int yprev = feature.yprev();			float val = feature.value();			boolean allEllMatch = isSegment && (dataSeq.y(i) == yp);			if (allEllMatch && (((i-ell >= 0) && (yprev == dataSeq.y(i-ell))) || (yprev < 0))) {			    grad[f] += val;			    thisSeqLogli += val*lambda[f];			}			if ((yprev < 0) && (i-ell >= 0)) {			    for (yprev = 0; yprev < Mi_YY.rows(); yprev++) 				ExpF[f] = RobustMath.logSumExp(ExpF[f], (alpha_Y_Array[i-ell-base].get(yprev)+Ri_Y.get(yp)+Mi_YY.get(yprev,yp) + RobustMath.log(val)+beta_Y[i].get(yp)));			} else if (i-ell < 0) {			    ExpF[f] = RobustMath.logSumExp(ExpF[f], (Ri_Y.get(yp)+RobustMath.log(val)+beta_Y[i].get(yp)));			} else {			    ExpF[f] = RobustMath.logSumExp(ExpF[f], (alpha_Y_Array[i-ell-base].get(yprev)+Ri_Y.get(yp)+Mi_YY.get(yprev,yp)+RobustMath.log(val)+beta_Y[i].get(yp)));			}		    }		    if (i-ell >= 0) {			RobustMath.logMult(Mi_YY, alpha_Y_Array[i-ell-base],tmp_Y,1,0,true,edgeGen);			tmp_Y.assign(Ri_Y,sumFunc);			RobustMath.logSumExp(alpha_Y_Array[i-base],tmp_Y);		    } else {			RobustMath.logSumExp(alpha_Y_Array[i-base],Ri_Y);		    }		}		if (params.debugLvl > 2) {		    System.out.println("Alpha-i " + alpha_Y_Array[i-base].toString());		    System.out.println("Ri " + Ri_Y.toString());		    System.out.println("Mi " + Mi_YY.toString());		    System.out.println("Beta-i " + beta_Y[i].toString());		}		if (params.debugLvl > 1) {		    System.out.println(" pos "  + i + " " + thisSeqLogli);		}	    }	    if (invalid)		continue;	    double lZx = RobustMath.logSumExp(alpha_Y_Array[dataSeq.length()-1-base]);	    thisSeqLogli -= lZx;	    logli += thisSeqLogli;	    // update grad.	    
for (int f = 0; f < grad.length; f++)		grad[f] -= RobustMath.exp(ExpF[f]-lZx);	    	    if (params.debugLvl > 1) {		System.out.println("Sequence "  + thisSeqLogli + " " + logli + " " + Math.exp(lZx));		System.out.println("Last Alpha-i " + alpha_Y_Array[dataSeq.length()-1-base].toString());	    }	}	if (params.debugLvl > 2) {	    for (int f = 0; f < lambda.length; f++)		System.out.print(lambda[f] + " ");	    System.out.println(" :x");	    for (int f = 0; f < lambda.length; f++)		System.out.print(grad[f] + " ");	    System.out.println(" :g");	}		if (params.debugLvl > 0)	    Util.printDbg("Iter " + icall + " loglikelihood "+logli + " gnorm " + norm(grad) + " xnorm "+ norm(lambda));	return logli;	} catch (Exception e) {	    e.printStackTrace();	    System.exit(0);	}	return 0;    }};

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -