⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 crf2.java

📁 常用机器学习算法,java编写源代码,内含常用分类算法,包括说明文档
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
																															new double[] {trainingProportions[i],	1-trainingProportions[i]});				theTrainingData = sampledTrainingData[0];			}			boolean converged = this.train (theTrainingData, validation, testing, eval, numIterationsPerProportion);			trainingIteration += numIterationsPerProportion;		}		logger.info("Training on 100% of the data this round, for "+												(numIterations-trainingIteration)+" iterations.");		return this.train (training, validation, testing,											 eval, numIterations - trainingIteration);	}	public boolean trainWithFeatureInduction (InstanceList trainingData,																						InstanceList validationData, InstanceList testingData,																						TransducerEvaluator eval, int numIterations,																						int numIterationsBetweenFeatureInductions,																						int numFeatureInductions,																						int numFeaturesPerFeatureInduction,																						double trueLabelProbThreshold,																						boolean clusteredFeatureInduction,																						double[] trainingProportions)	{		int trainingIteration = 0;		int numLabels = outputAlphabet.size();				for (int featureInductionIteration = 0;				 featureInductionIteration < numFeatureInductions;				 featureInductionIteration++)		{			// Print out some feature information			logger.info("Feature induction iteration "+featureInductionIteration);			// Train the CRF			InstanceList theTrainingData = trainingData;			if (trainingProportions != null && featureInductionIteration < trainingProportions.length) {				logger.info("Training on "+trainingProportions[featureInductionIteration]+"% of the data this round.");				InstanceList[] sampledTrainingData = trainingData.split (new Random(1),																																	new double[] {trainingProportions[featureInductionIteration],																																								1-trainingProportions[featureInductionIteration]});				theTrainingData = sampledTrainingData[0];			}			boolean converged = this.train (theTrainingData, validationData, testingData, eval, numIterationsBetweenFeatureInductions);			trainingIteration += numIterationsBetweenFeatureInductions;			logger.info("Starting feature induction with "+inputAlphabet.size()+" features.");						// Create the list of error tokens, for both unclustered and clustered feature induction			InstanceList errorInstances = new InstanceList (trainingData.getDataAlphabet(),																											trainingData.getTargetAlphabet());			ArrayList errorLabelVectors = new ArrayList();			InstanceList clusteredErrorInstances[][] = new InstanceList[numLabels][numLabels];			ArrayList clusteredErrorLabelVectors[][] = new ArrayList[numLabels][numLabels];						for (int i = 0; i < numLabels; i++)				for (int j = 0; j < numLabels; j++) {					clusteredErrorInstances[i][j] = new InstanceList (trainingData.getDataAlphabet(),																														 trainingData.getTargetAlphabet());					clusteredErrorLabelVectors[i][j] = new ArrayList();				}						for (int i = 0; i < theTrainingData.size(); i++) {				logger.info("instance="+i);				Instance instance = theTrainingData.getInstance(i);				Sequence input = (Sequence) instance.getData();				Sequence trueOutput = (Sequence) instance.getTarget();				assert (input.size() == trueOutput.size());				Transducer.Lattice lattice = this.forwardBackward (input, null, false,																													 (LabelAlphabet)theTrainingData.getTargetAlphabet());				int prevLabelIndex = 0;					// This will put extra error instances in this cluster				for (int j = 0; j < trueOutput.size(); j++) {					Label label = (Label) ((LabelSequence)trueOutput).getLabelAtPosition(j);					assert (label != null);					//System.out.println ("Instance="+i+" position="+j+" fv="+lattice.getLabelingAtPosition(j).toString(true));					LabelVector latticeLabeling = lattice.getLabelingAtPosition(j);					double trueLabelProb = latticeLabeling.value(label.getIndex());					int labelIndex = latticeLabeling.getBestIndex();					//System.out.println ("position="+j+" trueLabelProb="+trueLabelProb);					if (trueLabelProb < trueLabelProbThreshold) {						logger.info("Adding error: instance="+i+" position="+j+" prtrue="+trueLabelProb+																(label == latticeLabeling.getBestLabel() ? "  " : " *")+																" truelabel="+label+																" predlabel="+latticeLabeling.getBestLabel()+																" fv="+((FeatureVector)input.get(j)).toString(true));						errorInstances.add (input.get(j), label, null, null);						errorLabelVectors.add (latticeLabeling);						clusteredErrorInstances[prevLabelIndex][labelIndex].add (input.get(j), label, null, null);						clusteredErrorLabelVectors[prevLabelIndex][labelIndex].add (latticeLabeling);					}					prevLabelIndex = labelIndex;				}			}			logger.info("Error instance list size = "+errorInstances.size());			if (clusteredFeatureInduction) {				FeatureInducer[][] klfi = new FeatureInducer[numLabels][numLabels];				for (int i = 0; i < numLabels; i++) {					for (int j = 0; j < numLabels; j++) {						// Note that we may see some "impossible" transitions here (like O->I in a OIB model)						// because we are using lattice gammas to get the predicted label, not Viterbi.						// I don't believe this does any harm, and may do some good.						logger.info("Doing feature induction for "+																outputAlphabet.lookupObject(i)+" -> "+outputAlphabet.lookupObject(j)+																" with "+clusteredErrorInstances[i][j].size()+" instances");						if (clusteredErrorInstances[i][j].size() < 20) {							logger.info("..skipping because only "+clusteredErrorInstances[i][j].size()+" instances.");							continue;						}						int s = clusteredErrorLabelVectors[i][j].size();						LabelVector[] lvs = new LabelVector[s];						for (int k = 0; k < s; k++)							lvs[k] = (LabelVector) clusteredErrorLabelVectors[i][j].get(k);						klfi[i][j] = new FeatureInducer (new ExpGain.Factory (lvs),																						 clusteredErrorInstances[i][j], numFeaturesPerFeatureInduction);					}				}				for (int i = 0; i < numLabels; i++) {					for (int j = 0; j < numLabels; j++) {						logger.info("Adding new induced features for "+																outputAlphabet.lookupObject(i)+" -> "+outputAlphabet.lookupObject(j));						if (klfi[i][j] == null) {							logger.info("...skipping because no features induced.");							continue;						}						klfi[i][j].induceFeaturesFor (trainingData, false, false);						klfi[i][j].induceFeaturesFor (testingData, false, false);					}				}				klfi = null;			} else if (true) {				int s = errorLabelVectors.size();				LabelVector[] lvs = new LabelVector[s];				for (int i = 0; i < s; i++)					lvs[i] = (LabelVector) errorLabelVectors.get(i);				FeatureInducer klfi = new FeatureInducer (new ExpGain.Factory (lvs),																									errorInstances, numFeaturesPerFeatureInduction);				logger.info("Inducing features for training set.");				klfi.induceFeaturesFor (trainingData, false, false);				logger.info("Inducing features for testing set.");				klfi.induceFeaturesFor (testingData, false, false);				klfi = null;			} else {				// Currently not ever used				FeatureInducer igfi = new FeatureInducer (new InfoGain.Factory(), errorInstances,																									numFeaturesPerFeatureInduction);				igfi.induceFeaturesFor (trainingData, false, false);				igfi.induceFeaturesFor (testingData, false, false);				igfi = null;			}			// This is done in CRF2.train() anyway			//this.setWeightsDimensionAsIn (trainingData);			////this.growWeightsDimensionToInputAlphabet ();		}		return this.train (trainingData, validationData, testingData,											 eval, numIterations - trainingIteration);	}		public void write (File f) {		try {			ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f));			oos.writeObject(this);			oos.close();		}		catch (IOException e) {			System.err.println("Exception writing file " + f + ": " + e);		}	}		public MinimizableCRF getMinimizableCRF (InstanceList ilist)	{		return new MinimizableCRF (ilist, this);	}	// Serialization	// For CRF class	private static final long serialVersionUID = 1;	private static final int CURRENT_SERIAL_VERSION = 0;	static final int NULL_INTEGER = -1;	/* Need to check for null pointers. */	private void writeObject (ObjectOutputStream out) throws IOException {		int i, size;		out.writeInt (CURRENT_SERIAL_VERSION);		out.writeObject (inputAlphabet);		out.writeObject (outputAlphabet);		size = states.size();		out.writeInt(size);		for (i = 0; i<size; i++) {			out.writeObject(states.get(i));		}		size = initialStates.size();		out.writeInt(size);		for (i = 0; i <size; i++) {			out.writeObject(initialStates.get(i));		}		out.writeObject(name2state);		if(weights != null) {			size = weights.length;			out.writeInt(size);			for (i=0; i<size; i++) {				out.writeObject(weights[i]);			}		}		else {			out.writeInt(NULL_INTEGER);		}		if(constraints != null) {			size = constraints.length;			out.writeInt(size);			for (i=0; i<size; i++) {				out.writeObject(constraints[i]);			}		}		else {			out.writeInt(NULL_INTEGER);		}		if (expectations != null) {			size = expectations.length;			out.writeInt(size);			for (i=0; i<size; i++) {				out.writeObject(expectations[i]);			}		}		else {			out.writeInt(NULL_INTEGER);		}		out.writeObject(weightAlphabet);		out.writeBoolean(trainable);		out.writeBoolean(gatheringConstraints);		//out.writeInt(defaultFeatureIndex);		out.writeBoolean(usingHyperbolicPrior);		out.writeDouble(gaussianPriorVariance);		out.writeDouble(hyperbolicPriorSlope);		out.writeDouble(hyperbolicPriorSharpness);		out.writeBoolean(cachedCostStale);		out.writeBoolean(cachedGradientStale);	}		private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {		int size, i;		int version = in.readInt ();		inputAlphabet = (Alphabet) in.readObject();		outputAlphabet = (Alphabet) in.readObject();		size = in.readInt();		states = new ArrayList();		for (i=0; i<size; i++) {			State s = (CRF2.State) in.readObject();			states.add(s);		}		size = in.readInt();		initialStates = new ArrayList();		for (i=0; i<size; i++) {			State s = (CRF2.State) in.readObject();			initialStates.add(s);		}		name2state = (HashMap) in.readObject();		size = in.readInt();		if (size == NULL_INTEGER) {			weights = null;		}		else {			weights = new SparseVector[size];			for(i=0; i< size; i++) {				weights[i] = (SparseVector) in.readObject();			}		}		size = in.readInt();		if (size == NULL_INTEGER) {			constraints = null;		}		else {			constraints = new SparseVector[size];			for(i=0; i< size; i++) {				constraints[i] = (SparseVector) in.readObject();			}		}		size = in.readInt();		if (size == NULL_INTEGER) {			expectations = null;		}		else {			expectations = new SparseVector[size];			for(i=0; i< size; i++) {				expectations[i] = (SparseVector)in.readObject();			}		}		weightAlphabet = (Alphabet) in.readObject();		trainable = in.readBoolean();		gatheringConstraints = in.readBoolean();		//defaultFeatureIndex = in.readInt();		usingHyperbolicPrior = in.readBoolean();		gaussianPriorVariance = in.readDouble();		hyperbolicPriorSlope = in.readDouble();		hyperbolicPriorSharpness = in.readDouble();		cachedCostStale = in.readBoolean();		cachedGradientStale = in.readBoolean();	}	public class MinimizableCRF implements Minimizable.ByGradient, Serializable	{		InstanceList trainingSet;		double cachedCost = -123456789;		DenseVector cachedGradient;		BitSet infiniteCosts = null;		int numParameters;		CRF2 crf;		protected MinimizableCRF (InstanceList ilist, CRF2 crf)		{			// Set up			this.numParameters = 2 * numStates() + defaultWeights.length;			for (int i = 0; i < weights.length; i++)				numParameters += weights[i].numLocations();			this.trainingSet = ilist;			this.crf = crf;			cachedGradient = (DenseVector) getNewMatrix ();			// This resets and values that may have been in expecations and constraints			setTrainable (true);			// Set the contraints by running forward-backward with the *output			// label sequence provided*, thus restricting it to only those			// paths that agree with the label sequence.			gatheringConstraints = true;			for (int i = 0; i < ilist.size(); i++) {				Instance instance = ilist.getInstance(i);				FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();				FeatureSequence output = (FeatureSequence) instance.getTarget();				//System.out.println ("Confidence-gathering forward-backward on instance "+i+" of "+ilist.size());				this.crf.forwardBackward (input, output, true);				//System.out.println ("Gathering constraints for Instance #"+i);			}			gatheringConstraints = false;		}		public Matrix getNewMatrix () { return new DenseVector (numParameters); }		// Negate initialCost and finalCost because the parameters are in

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -