
crf4.java

MALLET is an open-source project in the fields of natural language processing and machine learning.
Language: Java
Page 1 of 5
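This page is an excerpt from MALLET's CRF4 class: the body of a feature-induction training loop (collecting low-confidence tokens from a forward-backward lattice, then inducing features either per label transition or globally), two deprecated helpers (predict and evaluate), the write(File) serializer, and the beginning of writeObject. The listing opens mid-method and breaks off mid-statement; it continues on the next page.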
    int numLabels = outputAlphabet.size();
    this.globalFeatureSelection = trainingData.getFeatureSelection();
    if (this.globalFeatureSelection == null) {
      // Mask out all features; some will be added later by FeatureInducer.induceFeaturesFor(.)
      this.globalFeatureSelection = new FeatureSelection (trainingData.getDataAlphabet());
      trainingData.setFeatureSelection (this.globalFeatureSelection);
    }
    if (validationData != null) validationData.setFeatureSelection (this.globalFeatureSelection);
    if (testingData != null) testingData.setFeatureSelection (this.globalFeatureSelection);

    for (int featureInductionIteration = 0;
         featureInductionIteration < numFeatureInductions;
         featureInductionIteration++)
    {
      // Print out some feature information
      logger.info ("Feature induction iteration "+featureInductionIteration);

      // Train the CRF
      InstanceList theTrainingData = trainingData;
      if (trainingProportions != null && featureInductionIteration < trainingProportions.length) {
        logger.info ("Training on "+trainingProportions[featureInductionIteration]+"% of the data this round.");
        InstanceList[] sampledTrainingData = trainingData.split (new Random(1),
            new double[] {trainingProportions[featureInductionIteration],
                          1-trainingProportions[featureInductionIteration]});
        theTrainingData = sampledTrainingData[0];
        theTrainingData.setFeatureSelection (this.globalFeatureSelection); // xxx necessary?
        logger.info ("  which is "+theTrainingData.size()+" instances");
      }
      boolean converged = false;
      if (featureInductionIteration != 0)
        // Don't train until we have added some features
        converged = this.train (theTrainingData, validationData, testingData,
                                eval, numIterationsBetweenFeatureInductions);
      trainingIteration += numIterationsBetweenFeatureInductions;
      // xxx Remove this next line
      this.print ();
      logger.info ("Starting feature induction with "+inputAlphabet.size()+" features.");

      // Create the list of error tokens, for both unclustered and clustered feature induction
      InstanceList errorInstances = new InstanceList (trainingData.getDataAlphabet(),
                                                      trainingData.getTargetAlphabet());
      // This errorInstances.featureSelection will get examined by FeatureInducer,
      // so it can know how to add "new" singleton features
      errorInstances.setFeatureSelection (this.globalFeatureSelection);
      ArrayList errorLabelVectors = new ArrayList();
      InstanceList[][] clusteredErrorInstances = new InstanceList[numLabels][numLabels];
      ArrayList[][] clusteredErrorLabelVectors = new ArrayList[numLabels][numLabels];

      for (int i = 0; i < numLabels; i++)
        for (int j = 0; j < numLabels; j++) {
          clusteredErrorInstances[i][j] = new InstanceList (trainingData.getDataAlphabet(),
                                                            trainingData.getTargetAlphabet());
          clusteredErrorInstances[i][j].setFeatureSelection (this.globalFeatureSelection);
          clusteredErrorLabelVectors[i][j] = new ArrayList();
        }

      for (int i = 0; i < theTrainingData.size(); i++) {
        logger.info ("instance="+i);
        Instance instance = theTrainingData.getInstance(i);
        Sequence input = (Sequence) instance.getData();
        Sequence trueOutput = (Sequence) instance.getTarget();
        assert (input.size() == trueOutput.size());
        Transducer.Lattice lattice = this.forwardBackward (input, null, false,
                                                           (LabelAlphabet)theTrainingData.getTargetAlphabet());
        int prevLabelIndex = 0;  // This will put extra error instances in this cluster
        for (int j = 0; j < trueOutput.size(); j++) {
          Label label = (Label) ((LabelSequence)trueOutput).getLabelAtPosition(j);
          assert (label != null);
          //System.out.println ("Instance="+i+" position="+j+" fv="+lattice.getLabelingAtPosition(j).toString(true));
          LabelVector latticeLabeling = lattice.getLabelingAtPosition(j);
          double trueLabelProb = latticeLabeling.value(label.getIndex());
          int labelIndex = latticeLabeling.getBestIndex();
          //System.out.println ("position="+j+" trueLabelProb="+trueLabelProb);
          if (trueLabelProb < trueLabelProbThreshold) {
            logger.info ("Adding error: instance="+i+" position="+j+" prtrue="+trueLabelProb+
                         (label == latticeLabeling.getBestLabel() ? "  " : " *")+
                         " truelabel="+label+
                         " predlabel="+latticeLabeling.getBestLabel()+
                         " fv="+((FeatureVector)input.get(j)).toString(true));
            errorInstances.add (input.get(j), label, null, null);
            errorLabelVectors.add (latticeLabeling);
            clusteredErrorInstances[prevLabelIndex][labelIndex].add (input.get(j), label, null, null);
            clusteredErrorLabelVectors[prevLabelIndex][labelIndex].add (latticeLabeling);
          }
          prevLabelIndex = labelIndex;
        }
      }
      logger.info ("Error instance list size = "+errorInstances.size());

      if (clusteredFeatureInduction) {
        FeatureInducer[][] klfi = new FeatureInducer[numLabels][numLabels];
        for (int i = 0; i < numLabels; i++) {
          for (int j = 0; j < numLabels; j++) {
            // Note that we may see some "impossible" transitions here (like O->I in a OIB model)
            // because we are using lattice gammas to get the predicted label, not Viterbi.
            // I don't believe this does any harm, and may do some good.
            logger.info ("Doing feature induction for "+
                         outputAlphabet.lookupObject(i)+" -> "+outputAlphabet.lookupObject(j)+
                         " with "+clusteredErrorInstances[i][j].size()+" instances");
            if (clusteredErrorInstances[i][j].size() < 20) {
              logger.info ("..skipping because only "+clusteredErrorInstances[i][j].size()+" instances.");
              continue;
            }
            int s = clusteredErrorLabelVectors[i][j].size();
            LabelVector[] lvs = new LabelVector[s];
            for (int k = 0; k < s; k++)
              lvs[k] = (LabelVector) clusteredErrorLabelVectors[i][j].get(k);
            RankedFeatureVector.Factory gainFactory = null;
            if (gainName.equals ("exp"))
              gainFactory = new ExpGain.Factory (lvs, gaussianPriorVariance);
            else if (gainName.equals("grad"))
              gainFactory = new GradientGain.Factory (lvs);
            else if (gainName.equals("info"))
              gainFactory = new InfoGain.Factory ();
            klfi[i][j] = new FeatureInducer (gainFactory,
                                             clusteredErrorInstances[i][j],
                                             numFeaturesPerFeatureInduction,
                                             2*numFeaturesPerFeatureInduction,
                                             2*numFeaturesPerFeatureInduction);
            featureInducers.add(klfi[i][j]);
          }
        }
        for (int i = 0; i < numLabels; i++) {
          for (int j = 0; j < numLabels; j++) {
            logger.info ("Adding new induced features for "+
                         outputAlphabet.lookupObject(i)+" -> "+outputAlphabet.lookupObject(j));
            if (klfi[i][j] == null) {
              logger.info ("...skipping because no features induced.");
              continue;
            }
            // Note that this adds features globally, but not on a per-transition basis
            klfi[i][j].induceFeaturesFor (trainingData, false, false);
            if (testingData != null) klfi[i][j].induceFeaturesFor (testingData, false, false);
          }
        }
        klfi = null;
      } else {
        int s = errorLabelVectors.size();
        LabelVector[] lvs = new LabelVector[s];
        for (int i = 0; i < s; i++)
          lvs[i] = (LabelVector) errorLabelVectors.get(i);
        RankedFeatureVector.Factory gainFactory = null;
        if (gainName.equals ("exp"))
          gainFactory = new ExpGain.Factory (lvs, gaussianPriorVariance);
        else if (gainName.equals("grad"))
          gainFactory = new GradientGain.Factory (lvs);
        else if (gainName.equals("info"))
          gainFactory = new InfoGain.Factory ();
        FeatureInducer klfi =
          new FeatureInducer (gainFactory,
                              errorInstances,
                              numFeaturesPerFeatureInduction,
                              2*numFeaturesPerFeatureInduction,
                              2*numFeaturesPerFeatureInduction);
        featureInducers.add(klfi);
        // Note that this adds features globally, but not on a per-transition basis
        klfi.induceFeaturesFor (trainingData, false, false);
        if (testingData != null) klfi.induceFeaturesFor (testingData, false, false);
        logger.info ("CRF4 FeatureSelection now includes "+this.globalFeatureSelection.cardinality()+" features");
        klfi = null;
      }
      // This is done in CRF4.train() anyway
      //this.setWeightsDimensionAsIn (trainingData);
      ////this.growWeightsDimensionToInputAlphabet ();
    }
    return this.train (trainingData, validationData, testingData,
                       eval, numIterations - trainingIteration);
  }

  /** This method is deprecated. */
  public Sequence[] predict (InstanceList testing) {
    testing.setFeatureSelection(this.globalFeatureSelection);
    for (int i = 0; i < featureInducers.size(); i++) {
      FeatureInducer klfi = (FeatureInducer)featureInducers.get(i);
      klfi.induceFeaturesFor (testing, false, false);
    }
    Sequence[] ret = new Sequence[testing.size()];
    for (int i = 0; i < testing.size(); i++) {
      Instance instance = testing.getInstance(i);
      Sequence input = (Sequence) instance.getData();
      Sequence trueOutput = (Sequence) instance.getTarget();
      assert (input.size() == trueOutput.size());
      Sequence predOutput = viterbiPath(input).output();
      assert (predOutput.size() == trueOutput.size());
      ret[i] = predOutput;
    }
    return ret;
  }

  /** This method is deprecated. */
  public void evaluate (TransducerEvaluator eval, InstanceList testing) {
    testing.setFeatureSelection(this.globalFeatureSelection);
    for (int i = 0; i < featureInducers.size(); i++) {
      FeatureInducer klfi = (FeatureInducer)featureInducers.get(i);
      klfi.induceFeaturesFor (testing, false, false);
    }
    eval.evaluate (this, true, 0, true, 0.0, null, null, testing);
  }

  public void write (File f) {
    try {
      ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f));
      oos.writeObject(this);
      oos.close();
    }
    catch (IOException e) {
      System.err.println("Exception writing file " + f + ": " + e);
    }
  }

  public MaximizableCRF getMaximizableCRF (InstanceList ilist)
  {
    return new MaximizableCRF (ilist, this);
  }

  // Serialization
  // For CRF class
  private static final long serialVersionUID = 1;
  // Serial versions
  //  3: Add transduction type.
  //  4: Add weightsFrozen
  private static final int CURRENT_SERIAL_VERSION = 4;
  static final int NULL_INTEGER = -1;

  /* Need to check for null pointers. */
  private void writeObject (ObjectOutputStream out) throws IOException {
    int i, size;
    out.writeInt (CURRENT_SERIAL_VERSION);
    out.writeObject (inputPipe);
    out.writeObject (outputPipe);
    out.writeObject (inputAlphabet);
    out.writeObject (outputAlphabet);
    size = states.size();
    out.writeInt (size);
    for (i = 0; i < size; i++)
      out.writeObject (states.get(i));
    size = initialStates.size();
    out.writeInt (size);
    for (i = 0; i < size; i++)
      out.writeObject (initialStates.get(i));
    out.writeObject (name2state);
    if (weights != null) {
      size = weights.length;
      out.writeInt (size);
      for (i = 0; i < size; i++)
        out.writeObject (weights[i]);
    } else {
      out.writeInt (NULL_INTEGER);
    }
    if (constraints != null) {
      size = constraints.length;
      out.writeInt (size);
      for (i = 0; i < size; i++)
        out.writeObject (constraints[i]);
    } else {
      out.writeInt (NULL_INTEGER);
    }
    if (expectations != null) {
      size = expectations.length;
      out.writeInt (size);
      for (i = 0; i < size; i++)
        out.writeObject (expectations[i]);
    } else {
      out.writeInt (NULL_INTEGER);
    }
    if (defaultWeights != null) {
      size = defaultWeights.length;
      out.writeInt (size);
      for (i = 0; i < size; i++)
        out.writeDouble (defaultWeights[i]);
    } else {
      out.writeInt (NULL_INTEGER);
    }
    if (defaultConstraints != null) {
      size = defaultConstraints.length;
      out.writeInt (size);
      for (i = 0; i < size; i++)
        out.writeDouble (defaultConstraints[i]);
    } else {
      out.writeInt (NULL_INTEGER);
    }
    if (defaultExpectations != null) {
      size = defaultExpectations.length;
      out.writeInt (size);
      for (i = 0; i < size; i++)
        out.writeDouble (defaultExpectations[i]);
    } else {
      out.writeInt (NULL_INTEGER);
    }
    if (weightsPresent != null) {
      size = weightsPresent.length;
      out.writeInt (size);
      for (i = 0; i < size; i++)
        out.writeObject (weightsPresent[i]);
    } else {
      out.writeInt (NULL_INTEGER);
    }
    if (featureSelections != null) {
      size = featureSelections.length;
      out.writeInt (size);
      for (i = 0; i < size; i++)
        out.writeObject (featureSelections[i]);
    } else {
      out.writeInt (NULL_INTEGER);
    }
    if (weightsFrozen != null) {
      size = weightsFrozen.length;
      out.writeInt (size);
      for (i = 0; i < size; i++)
        out.writeBoolean (weightsFrozen[i]);
    } else {
      out.writeInt (NULL_INTEGER);
    }
    out.writeObject (globalFeatureSelection);
    out.writeObject (weightAlphabet);
    out.writeBoolean (trainable);
    out.writeBoolean (gatheringConstraints);
    out.writeBoolean (gatheringWeightsPresent);
    //out.writeInt (defaultFeatureIndex);
    out.writeBoolean (usingHyperbolicPrior);
    out.writeDouble (gaussianPriorVariance);
    out.writeDouble (hyperbolicPriorSlope);
    out.writeDouble (hyperbolicPriorSharpness);
    out.writeBoolean (cachedValueStale);
    out.writeBoolean (cachedGradientStale);
    out.writeBoolean (someTrainingDone);
    out.writeInt (featureInducers.size());
    for (i = 0; i < featureInducers.size(); i++) {
      out.writeObject (featureInducers.get(i));
    }
    out.writeBoolean (printGradient);
    out.writeBoolean (use
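The write(File) method above saves the whole model with standard Java object serialization, and the private writeObject hook defines the stream layout (ObjectOutputStream and ObjectInputStream call these private hooks automatically). Below is a minimal sketch of loading such a model back, assuming CRF4 is on the classpath and is serializable (implied by the writeObject above); the file name crf4.model is hypothetical:

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;

public class LoadCRF4 {
  public static void main (String[] args) throws IOException, ClassNotFoundException {
    // Read back a model previously saved with crf.write(new File("crf4.model")).
    // Deserialization automatically invokes CRF4's private readObject, the
    // counterpart of the writeObject shown above (it continues on later pages).
    ObjectInputStream ois = new ObjectInputStream (new FileInputStream (new File ("crf4.model")));
    CRF4 crf = (CRF4) ois.readObject ();
    ois.close ();
    System.out.println ("Loaded: " + crf);
  }
}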
