
📄 crf3.java

📁 Common machine learning algorithms implemented in Java; the source code includes the usual classification algorithms, with accompanying documentation
💻 Java
📖 Page 1 of 5
// crf3.java, page 1 of 5. The listing opens mid-file, inside the body of
// CRF3.train(): the tail of the numerical-minimizer loop, where convergence
// and per-iteration evaluation are handled.

      logger.info ("CRF finished one iteration of minimizer, i="+i);
    } catch (IllegalArgumentException e) {
      e.printStackTrace();
      logger.info ("Catching exception; saying converged.");
      converged = true;
    }
    if (eval != null) {
      continueTraining = eval.evaluate (this, (converged || i == numIterations-1), i,
                                        converged, mc.getCost(), ilist, validation, testing);
      if (!continueTraining)
        break;
    }
    if (converged) {
      logger.info ("CRF training has converged, i="+i);
      break;
    }
  }
  logger.info ("About to setTrainable(false)");
  // Free the memory of the expectations and constraints
  setTrainable (false);
  logger.info ("Done setTrainable(false)");
  return converged;
}

public boolean train (InstanceList training, InstanceList validation, InstanceList testing,
                      TransducerEvaluator eval, int numIterations,
                      int numIterationsPerProportion,
                      double[] trainingProportions)
{
  int trainingIteration = 0;
  for (int i = 0; i < trainingProportions.length; i++) {
    // Train the CRF
    InstanceList theTrainingData = training;
    if (trainingProportions != null && i < trainingProportions.length) {
      logger.info ("Training on "+trainingProportions[i]+"% of the data this round.");
      InstanceList[] sampledTrainingData = training.split (new Random(1),
        new double[] {trainingProportions[i], 1-trainingProportions[i]});
      theTrainingData = sampledTrainingData[0];
    }
    boolean converged = this.train (theTrainingData, validation, testing, eval, numIterationsPerProportion);
    trainingIteration += numIterationsPerProportion;
  }
  logger.info ("Training on 100% of the data this round, for "+
               (numIterations-trainingIteration)+" iterations.");
  return this.train (training, validation, testing,
                     eval, numIterations - trainingIteration);
}

public boolean trainWithFeatureInduction (InstanceList trainingData,
                                          InstanceList validationData, InstanceList testingData,
                                          TransducerEvaluator eval, int numIterations,
                                          int numIterationsBetweenFeatureInductions,
                                          int numFeatureInductions,
                                          int numFeaturesPerFeatureInduction,
                                          double trueLabelProbThreshold,
                                          boolean clusteredFeatureInduction,
                                          double[] trainingProportions)
{
  return trainWithFeatureInduction (trainingData, validationData, testingData,
                                    eval, numIterations, numIterationsBetweenFeatureInductions,
                                    numFeatureInductions, numFeaturesPerFeatureInduction,
                                    trueLabelProbThreshold, clusteredFeatureInduction,
                                    trainingProportions, "exp");
}

public boolean trainWithFeatureInduction (InstanceList trainingData,
                                          InstanceList validationData, InstanceList testingData,
                                          TransducerEvaluator eval, int numIterations,
                                          int numIterationsBetweenFeatureInductions,
                                          int numFeatureInductions,
                                          int numFeaturesPerFeatureInduction,
                                          double trueLabelProbThreshold,
                                          boolean clusteredFeatureInduction,
                                          double[] trainingProportions,
                                          String gainName)
{
  int trainingIteration = 0;
  int numLabels = outputAlphabet.size();

  this.globalFeatureSelection = trainingData.getFeatureSelection();
  if (this.globalFeatureSelection == null) {
    // Mask out all features; some will be added later by FeatureInducer.induceFeaturesFor(.)
    this.globalFeatureSelection = new FeatureSelection (trainingData.getDataAlphabet());
    trainingData.setFeatureSelection (this.globalFeatureSelection);
  }
  if (validationData != null) validationData.setFeatureSelection (this.globalFeatureSelection);
  if (testingData != null) testingData.setFeatureSelection (this.globalFeatureSelection);

  for (int featureInductionIteration = 0;
       featureInductionIteration < numFeatureInductions;
       featureInductionIteration++)
  {
    // Print out some feature information
    logger.info ("Feature induction iteration "+featureInductionIteration);

    // Train the CRF
    InstanceList theTrainingData = trainingData;
    if (trainingProportions != null && featureInductionIteration < trainingProportions.length) {
      logger.info ("Training on "+trainingProportions[featureInductionIteration]+"% of the data this round.");
      InstanceList[] sampledTrainingData = trainingData.split (new Random(1),
        new double[] {trainingProportions[featureInductionIteration],
                      1-trainingProportions[featureInductionIteration]});
      theTrainingData = sampledTrainingData[0];
      theTrainingData.setFeatureSelection (this.globalFeatureSelection); // xxx necessary?
      logger.info ("  which is "+theTrainingData.size()+" instances");
    }
    boolean converged = false;
    if (featureInductionIteration != 0)
      // Don't train until we have added some features
      converged = this.train (theTrainingData, validationData, testingData,
                              eval, numIterationsBetweenFeatureInductions);
    trainingIteration += numIterationsBetweenFeatureInductions;

    // xxx Remove this next line
    this.print ();
    logger.info ("Starting feature induction with "+inputAlphabet.size()+" features.");

    // Create the list of error tokens, for both unclustered and clustered feature induction
    InstanceList errorInstances = new InstanceList (trainingData.getDataAlphabet(),
                                                    trainingData.getTargetAlphabet());
    // This errorInstances.featureSelection will get examined by FeatureInducer,
    // so it can know how to add "new" singleton features
    errorInstances.setFeatureSelection (this.globalFeatureSelection);
    ArrayList errorLabelVectors = new ArrayList();
    InstanceList clusteredErrorInstances[][] = new InstanceList[numLabels][numLabels];
    ArrayList clusteredErrorLabelVectors[][] = new ArrayList[numLabels][numLabels];

    for (int i = 0; i < numLabels; i++)
      for (int j = 0; j < numLabels; j++) {
        clusteredErrorInstances[i][j] = new InstanceList (trainingData.getDataAlphabet(),
                                                          trainingData.getTargetAlphabet());
        clusteredErrorInstances[i][j].setFeatureSelection (this.globalFeatureSelection);
        clusteredErrorLabelVectors[i][j] = new ArrayList();
      }

    for (int i = 0; i < theTrainingData.size(); i++) {
      logger.info ("instance="+i);
      Instance instance = theTrainingData.getInstance(i);
      Sequence input = (Sequence) instance.getData();
      Sequence trueOutput = (Sequence) instance.getTarget();
      assert (input.size() == trueOutput.size());
      System.out.println("Fuchun Peng: " + theTrainingData.getTargetAlphabet());
      Transducer.Lattice lattice = this.forwardBackward (input, null, false,
                                                         (LabelAlphabet)theTrainingData.getTargetAlphabet());
      int prevLabelIndex = 0; // This will put extra error instances in this cluster
      for (int j = 0; j < trueOutput.size(); j++) {
        Label label = (Label) ((LabelSequence)trueOutput).getLabelAtPosition(j);
        assert (label != null);
        //System.out.println ("Instance="+i+" position="+j+" fv="+lattice.getLabelingAtPosition(j).toString(true));
        LabelVector latticeLabeling = lattice.getLabelingAtPosition(j);
        double trueLabelProb = latticeLabeling.value(label.getIndex());
        int labelIndex = latticeLabeling.getBestIndex();
        //System.out.println ("position="+j+" trueLabelProb="+trueLabelProb);
        if (trueLabelProb < trueLabelProbThreshold) {
          logger.info ("Adding error: instance="+i+" position="+j+" prtrue="+trueLabelProb+
                       (label == latticeLabeling.getBestLabel() ? "  " : " *")+
                       " truelabel="+label+
                       " predlabel="+latticeLabeling.getBestLabel()+
                       " fv="+((FeatureVector)input.get(j)).toString(true));
          errorInstances.add (input.get(j), label, null, null);
          errorLabelVectors.add (latticeLabeling);
          clusteredErrorInstances[prevLabelIndex][labelIndex].add (input.get(j), label, null, null);
          clusteredErrorLabelVectors[prevLabelIndex][labelIndex].add (latticeLabeling);
        }
        prevLabelIndex = labelIndex;
      }
    }
    logger.info ("Error instance list size = "+errorInstances.size());

    if (clusteredFeatureInduction) {
      FeatureInducer[][] klfi = new FeatureInducer[numLabels][numLabels];
      for (int i = 0; i < numLabels; i++) {
        for (int j = 0; j < numLabels; j++) {
          // Note that we may see some "impossible" transitions here (like O->I in a OIB model)
          // because we are using lattice gammas to get the predicted label, not Viterbi.
          // I don't believe this does any harm, and may do some good.
          logger.info ("Doing feature induction for "+
                       outputAlphabet.lookupObject(i)+" -> "+outputAlphabet.lookupObject(j)+
                       " with "+clusteredErrorInstances[i][j].size()+" instances");
          if (clusteredErrorInstances[i][j].size() < 20) {
            logger.info ("..skipping because only "+clusteredErrorInstances[i][j].size()+" instances.");
            continue;
          }
          int s = clusteredErrorLabelVectors[i][j].size();
          LabelVector[] lvs = new LabelVector[s];
          for (int k = 0; k < s; k++)
            lvs[k] = (LabelVector) clusteredErrorLabelVectors[i][j].get(k);
          RankedFeatureVector.Factory gainFactory = null;
          if (gainName.equals ("exp"))
            gainFactory = new ExpGain.Factory (lvs, gaussianPriorVariance);
          else if (gainName.equals("grad"))
            gainFactory = new GradientGain.Factory (lvs);
          else if (gainName.equals("info"))
            gainFactory = new InfoGain.Factory ();
          klfi[i][j] = new FeatureInducer (gainFactory,
                                           clusteredErrorInstances[i][j],
                                           numFeaturesPerFeatureInduction,
                                           2*numFeaturesPerFeatureInduction,
                                           2*numFeaturesPerFeatureInduction);
          featureInducers.add(klfi[i][j]);
        }
      }
      for (int i = 0; i < numLabels; i++) {
        for (int j = 0; j < numLabels; j++) {
          logger.info ("Adding new induced features for "+
                       outputAlphabet.lookupObject(i)+" -> "+outputAlphabet.lookupObject(j));
          if (klfi[i][j] == null) {
            logger.info ("...skipping because no features induced.");
            continue;
          }
          // Note that this adds features globally, but not on a per-transition basis
          klfi[i][j].induceFeaturesFor (trainingData, false, false);
          if (testingData != null) klfi[i][j].induceFeaturesFor (testingData, false, false);
        }
      }
      klfi = null;
    } else {
      int s = errorLabelVectors.size();
      LabelVector[] lvs = new LabelVector[s];
      for (int i = 0; i < s; i++)
        lvs[i] = (LabelVector) errorLabelVectors.get(i);
      RankedFeatureVector.Factory gainFactory = null;
      if (gainName.equals ("exp"))
        gainFactory = new ExpGain.Factory (lvs, gaussianPriorVariance);
      else if (gainName.equals("grad"))
        gainFactory = new GradientGain.Factory (lvs);
      else if (gainName.equals("info"))
        gainFactory = new InfoGain.Factory ();
      FeatureInducer klfi =
        new FeatureInducer (gainFactory,
                            errorInstances,
                            numFeaturesPerFeatureInduction,
                            2*numFeaturesPerFeatureInduction,
                            2*numFeaturesPerFeatureInduction);
      featureInducers.add(klfi);
      // Note that this adds features globally, but not on a per-transition basis
      klfi.induceFeaturesFor (trainingData, false, false);
      if (testingData != null) klfi.induceFeaturesFor (testingData, false, false);
      logger.info ("CRF3 FeatureSelection now includes "+this.globalFeatureSelection.cardinality()+" features");
      klfi = null;
    }
    // This is done in CRF3.train() anyway
    //this.setWeightsDimensionAsIn (trainingData);
    ////this.growWeightsDimensionToInputAlphabet ();
  }
  return this.train (trainingData, validationData, testingData,
                     eval, numIterations - trainingIteration);
}

/** This method is deprecated. */
public Sequence[] predict (InstanceList testing) {
  testing.setFeatureSelection(this.globalFeatureSelection);
  for (int i = 0; i < featureInducers.size(); i++) {
    FeatureInducer klfi = (FeatureInducer)featureInducers.get(i);
    klfi.induceFeaturesFor (testing, false, false);
  }
  Sequence[] ret = new Sequence[testing.size()];
  for (int i = 0; i < testing.size(); i++) {
    Instance instance = testing.getInstance(i);
    Sequence input = (Sequence) instance.getData();
    Sequence trueOutput = (Sequence) instance.getTarget();
    assert (input.size() == trueOutput.size());
    Sequence predOutput = viterbiPath(input).output();
    assert (predOutput.size() == trueOutput.size());
    ret[i] = predOutput;
  }
  return ret;
}

/** This method is deprecated. */
public void evaluate (TransducerEvaluator eval, InstanceList testing) {
  testing.setFeatureSelection(this.globalFeatureSelection);
  for (int i = 0; i < featureInducers.size(); i++) {
    FeatureInducer klfi = (FeatureInducer)featureInducers.get(i);
    klfi.induceFeaturesFor (testing, false, false);
  }
  eval.evaluate (this, true, 0, true, 0.0, null, null, testing);
}

public void write (File f) {
  try {
    ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f));
    oos.writeObject(this);
    oos.close();
  }
  catch (IOException e) {
    System.err.println("Exception writing file " + f + ": " + e);
  }
}

public MinimizableCRF getMinimizableCRF (InstanceList ilist)
{
  return new MinimizableCRF (ilist, this);
}

// Serialization
// For CRF class
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 1;
static final int NULL_INTEGER = -1;

/* Need to check for null pointers. */
private void writeObject (ObjectOutputStream out) throws IOException {
  int i, size;
  out.writeInt (CURRENT_SERIAL_VERSION);
  out.writeObject(inputPipe);
  out.writeObject(outputPipe);
  out.writeObject (inputAlphabet);
  out.writeObject (outputAlphabet);
  size = states.size();
  out.writeInt(size);
  for (i = 0; i < size; i++)
    out.writeObject(states.get(i));
  size = initialStates.size();
  out.writeInt(size);
  for (i = 0; i < size; i++)
    out.writeObject(initialStates.get(i));
  out.writeObject(name2state);
  if (weights != null) {
    size = weights.length;
    out.writeInt(size);
    for (i = 0; i < size; i++)
      out.writeObject(weights[i]);
  } else {
    out.writeInt(NULL_INTEGER);  // sentinel for a null array
  }
  if (constraints != null) {
    size = constraints.length;
    out.writeInt(size);
    for (i = 0; i < size; i++)
      out.writeObject(constraints[i]);
  } else {
    out.writeInt(NULL_INTEGER);
  }
  if (expectations != null) {
    size = expectations.length;
    out.writeInt(size);
    for (i = 0; i < size; i++)
      out.writeObject(expectations[i]);
  } else {
    // (listing continues on page 2 of 5)
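
To make the entry points above concrete, here is a minimal usage sketch for training and serializing a CRF3 model. It assumes the MALLET 0.4 API this file belongs to (edu.umass.cs.mallet.base.*); the class name Crf3TrainingSketch, the hyper-parameter values, and the use of TokenAccuracyEvaluator and addStatesForLabelsConnectedAsIn are illustrative assumptions, not definitions from crf3.java itself.

import edu.umass.cs.mallet.base.fst.CRF3;
import edu.umass.cs.mallet.base.fst.TokenAccuracyEvaluator;
import edu.umass.cs.mallet.base.fst.TransducerEvaluator;
import edu.umass.cs.mallet.base.types.InstanceList;
import java.io.File;

public class Crf3TrainingSketch {

  // Trains a CRF3 on an InstanceList whose instances already hold
  // FeatureVectorSequence data with LabelSequence targets, then saves it.
  public static CRF3 trainAndSave (InstanceList training, File modelFile) {
    // Input pipe is taken from the training data; no output pipe.
    CRF3 crf = new CRF3 (training.getPipe (), null);
    // One state per label, fully connected, as observed in the data.
    crf.addStatesForLabelsConnectedAsIn (training);

    TransducerEvaluator eval = new TokenAccuracyEvaluator ();

    // Plain training (no validation/testing sets), up to 99 iterations.
    crf.train (training, null, null, eval, 99);

    // Alternatively, interleave feature induction with training, matching
    // the signature above: 500 total iterations, 20 iterations between
    // inductions, 10 induction rounds of 50 features each, inducing on
    // tokens whose true label gets probability < 0.9, unclustered:
    // crf.trainWithFeatureInduction (training, null, null, eval,
    //                                500, 20, 10, 50, 0.9, false, null);

    crf.write (modelFile);  // serializes via the writeObject logic above
    return crf;
  }
}

The commented trainWithFeatureInduction call maps one-to-one onto the parameter list in the listing: total iterations, iterations between inductions, number of induction rounds, features per round, the true-label probability threshold, the clustered-induction flag, and the optional training-proportion schedule.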
