📄 crf4.java
字号:
assert defaultWeights != null; assert featureSelections != null; assert weightsFrozen != null; int n = weights.length; assert defaultWeights.length == n; assert featureSelections.length == n; assert weightsFrozen.length == n; } } public int numStates () { return states.size(); } public Transducer.State getState (int index) { return (Transducer.State) states.get(index); } public Iterator initialStateIterator () { return initialStates.iterator (); } public boolean isTrainable () { return trainable; } public void setTrainable (boolean f) { if (f != trainable) { if (f) { constraints = new SparseVector[weights.length]; expectations = new SparseVector[weights.length]; defaultConstraints = new double[weights.length]; defaultExpectations = new double[weights.length]; for (int i = 0; i < weights.length; i++) { constraints[i] = (SparseVector) weights[i].cloneMatrixZeroed (); expectations[i] = (SparseVector) weights[i].cloneMatrixZeroed (); } } else { constraints = expectations = null; defaultConstraints = defaultExpectations = null; } for (int i = 0; i < numStates(); i++) ((State)getState(i)).setTrainable(f); trainable = f; } } public double getParametersAbsNorm () { double ret = 0; for (int i = 0; i < numStates(); i++) { State s = (State) getState (i); ret += Math.abs (s.initialCost); ret += Math.abs (s.finalCost); } for (int i = 0; i < weights.length; i++) { ret += Math.abs (defaultWeights[i]); ret += weights[i].absNorm(); } return ret; } /** Only sets the parameter from the first group of parameters. 
*/ public void setParameter (int sourceStateIndex, int destStateIndex, int featureIndex, double value) { cachedValueStale = cachedGradientStale = true; State source = (State)getState(sourceStateIndex); State dest = (State) getState(destStateIndex); int rowIndex; for (rowIndex = 0; rowIndex < source.destinationNames.length; rowIndex++) if (source.destinationNames[rowIndex].equals (dest.name)) break; if (rowIndex == source.destinationNames.length) throw new IllegalArgumentException ("No transtition from state "+sourceStateIndex+" to state "+destStateIndex+"."); int weightsIndex = source.weightsIndices[rowIndex][0]; if (featureIndex < 0) defaultWeights[weightsIndex] = value; else { weights[weightsIndex].setValue (featureIndex, value); } someTrainingDone = true; } /** Only gets the parameter from the first group of parameters. */ public double getParameter (int sourceStateIndex, int destStateIndex, int featureIndex, double value) { State source = (State)getState(sourceStateIndex); State dest = (State) getState(destStateIndex); int rowIndex; for (rowIndex = 0; rowIndex < source.destinationNames.length; rowIndex++) if (source.destinationNames[rowIndex].equals (dest.name)) break; if (rowIndex == source.destinationNames.length) throw new IllegalArgumentException ("No transtition from state "+sourceStateIndex+" to state "+destStateIndex+"."); int weightsIndex = source.weightsIndices[rowIndex][0]; if (featureIndex < 0) return defaultWeights[weightsIndex]; else return weights[weightsIndex].value (featureIndex); } public void reset () { throw new UnsupportedOperationException ("Not used in CRFs"); } public void estimate () { if (!trainable) throw new IllegalStateException ("This transducer not currently trainable."); // xxx Put stuff in here. throw new UnsupportedOperationException ("Not yet implemented. 
Never?"); } //xxx experimental public Sequence transduce (Object unpipedInput) { Instance inst = new Instance (unpipedInput, null, null, null, inputPipe); return transduce ((FeatureVectorSequence) inst.getData ()); } public Sequence transduce (Sequence input) { if (!(input instanceof FeatureVectorSequence)) throw new IllegalArgumentException ("CRF4.transduce requires FeatureVectorSequence. This may be relaxed later..."); switch (transductionType) { case VITERBI: ViterbiPath lattice = viterbiPath (input); return lattice.output (); // CPAL - added viterbi "beam search" case VITERBI_FBEAM: ViterbiPathBeam lattice2 = viterbiPathBeam(input); return lattice2.output(); //CPAL - added viterbi backward beam search case VITERBI_BBEAM: ViterbiPathBeamB lattice3 = viterbiPathBeamB(input); return lattice3.output(); // CPAL - added viterbi forward backward beam // - we may call this constrained "max field" or something similar later case VITERBI_FBBEAM: ViterbiPathBeamFB lattice4 = viterbiPathBeamFB(input); return lattice4.output(); // CPAL - added an adaptive viterbi "beam search" case VITERBI_FBEAMKL: ViterbiPathBeamKL lattice5 = viterbiPathBeamKL(input); return lattice5.output(); default: throw new IllegalStateException ("Unknown CRF4 transuction type "+transductionType); } } public void print () { print (new PrintWriter (new OutputStreamWriter (System.out), true)); } // yyy public void print (PrintWriter out) { out.println ("*** CRF STATES ***"); for (int i = 0; i < numStates (); i++) { State s = (State) getState (i); out.print ("STATE NAME=\""); out.print (s.name); out.print ("\" ("); out.print (s.destinations.length); out.print (" outgoing transitions)\n"); out.print (" "); out.print ("initialCost = "); out.print (s.initialCost); out.print ('\n'); out.print (" "); out.print ("finalCost = "); out.print (s.finalCost); out.print ('\n'); out.println (" transitions:"); for (int j = 0; j < s.destinations.length; j++) { out.print (" "); out.print (s.name); out.print (" -> "); 
out.println (s.getDestinationState (j).getName ()); for (int k = 0; k < s.weightsIndices[j].length; k++) { out.print (" WEIGHTS = \""); int widx = s.weightsIndices[j][k]; out.print (weightAlphabet.lookupObject (widx).toString ()); out.print ("\"\n"); } } out.println (); } out.println ("\n\n\n*** CRF WEIGHTS ***"); for (int widx = 0; widx < weights.length; widx++) { out.println ("WEIGHTS NAME = " + weightAlphabet.lookupObject (widx)); out.print (": <DEFAULT_FEATURE> = "); out.print (defaultWeights[widx]); out.print ('\n'); SparseVector transitionWeights = weights[widx]; if (transitionWeights.numLocations () == 0) continue; RankedFeatureVector rfv = new RankedFeatureVector (inputAlphabet, transitionWeights); for (int m = 0; m < rfv.numLocations (); m++) { double v = rfv.getValueAtRank (m); int index = rfv.indexAtLocation (rfv.getIndexAtRank (m)); Object feature = inputAlphabet.lookupObject (index); if (v != 0) { out.print (": "); out.print (feature); out.print (" = "); out.println (v); } } } out.flush (); } // Java question: // If I make a non-static inner class CRF.Trainer, // can that class by subclassed in another .java file, // and can that subclass still have access to all the CRF's // instance variables? 
// ANSWER: Yes and yes, but you have to use special syntax in the subclass ctor (see mallet-dev archive) -cas

/** Trains on the whole list with no validation/testing sets. */
public boolean train (InstanceList ilist)
{
	return train (ilist, (InstanceList)null, (InstanceList)null);
}

/** Trains with optional validation/testing sets and no evaluator. */
public boolean train (InstanceList ilist, InstanceList validation, InstanceList testing)
{
	return train (ilist, validation, testing, (TransducerEvaluator)null);
}

/** Trains with an evaluator, using an effectively unbounded iteration cap (9999). */
public boolean train (InstanceList ilist, InstanceList validation, InstanceList testing,
                      TransducerEvaluator eval)
{
	return train (ilist, validation, testing, eval, 9999);
}

/**
 * Core training loop: sizes the weight structures from the data, then runs
 * up to numIterations rounds of L-BFGS maximization on the CRF likelihood,
 * invoking the evaluator (if any) after each round.
 * Returns true if the maximizer reported convergence.
 */
public boolean train (InstanceList ilist, InstanceList validation, InstanceList testing,
                      TransducerEvaluator eval, int numIterations)
{
	if (numIterations <= 0)
		return false;
	assert (ilist.size() > 0);
	// Dimension the weights either sparsely from the observed features or densely.
	if (useSparseWeights) {
		setWeightsDimensionAsIn (ilist);
	} else {
		setWeightsDimensionDensely ();
	}
	MaximizableCRF mc = new MaximizableCRF (ilist, this);
	//Maximizer.ByGradient minimizer = new ConjugateGradient (0.001);
	Maximizer.ByGradient maximizer = new LimitedMemoryBFGS();
	int i;
	boolean continueTraining = true;
	boolean converged = false;
	logger.info ("CRF about to train with "+numIterations+" iterations");
	for (i = 0; i < numIterations; i++) {
		try {
			// CPAL - added this to alter forward backward beam parameters based on iteration
			setCurIter(i); // CPAL - this resets the tctIter as well
			// One outer iteration of the optimizer per loop pass.
			converged = maximizer.maximize (mc, 1);
			logger.info ("CRF took " + tctIter + " intermediate iterations");
			//if (i!=0) {
			//	if(tctIter>1) {
			//		// increase the beam size
			//	}
			//}
			// CPAL - done
			logger.info ("CRF finished one iteration of maximizer, i="+i);
		} catch (IllegalArgumentException e) {
			// NOTE(review): optimizer line-search failures surface as
			// IllegalArgumentException and are treated as convergence rather
			// than propagated; the stack trace is printed, not logged.
			e.printStackTrace();
			logger.info ("Catching exception; saying converged.");
			converged = true;
		}
		if (eval != null) {
			// Evaluator may request early stopping by returning false.
			continueTraining = eval.evaluate (this, (converged || i == numIterations-1), i,
			                                  converged, mc.getValue(), ilist, validation, testing);
			if (!continueTraining)
				break;
		}
		if (converged) {
			logger.info ("CRF training has converged, i="+i);
			break;
		}
	}
	logger.info ("About to setTrainable(false)");
	// Free the memory of the expectations and constraints
	setTrainable (false);
	logger.info ("Done setTrainable(false)");
	return converged;
}

/**
 * Staged training: for each entry of trainingProportions, trains for
 * numIterationsPerProportion rounds on a random sample of that fraction of
 * the data, then finishes with the remaining iteration budget on 100% of it.
 * NOTE(review): the log message prints the raw fraction followed by "%",
 * so e.g. 0.2 is reported as "0.2%" -- presumably fractions, not percents.
 */
public boolean train (InstanceList training, InstanceList validation, InstanceList testing,
                      TransducerEvaluator eval, int numIterations,
                      int numIterationsPerProportion, double[] trainingProportions)
{
	int trainingIteration = 0;
	for (int i = 0; i < trainingProportions.length; i++) {
		// Train the CRF
		InstanceList theTrainingData = training;
		// NOTE(review): this null/bounds check is always true here -- the loop
		// condition above already dereferenced trainingProportions and bounds i.
		if (trainingProportions != null && i < trainingProportions.length) {
			logger.info ("Training on "+trainingProportions[i]+"% of the data this round.");
			// Fixed seed (1): the same sample split is drawn on every run.
			InstanceList[] sampledTrainingData = training.split (new Random(1),
				new double[] {trainingProportions[i], 1-trainingProportions[i]});
			theTrainingData = sampledTrainingData[0];
		}
		// NOTE(review): converged is never read; interim convergence is ignored.
		boolean converged = this.train (theTrainingData, validation, testing, eval,
		                                numIterationsPerProportion);
		trainingIteration += numIterationsPerProportion;
	}
	logger.info ("Training on 100% of the data this round, for "+
	             (numIterations-trainingIteration)+" iterations.");
	return this.train (training, validation, testing, eval,
	                   numIterations - trainingIteration);
}

/** Delegates to the full feature-induction trainer with gainName "exp". */
public boolean trainWithFeatureInduction (InstanceList trainingData,
                                          InstanceList validationData,
                                          InstanceList testingData,
                                          TransducerEvaluator eval, int numIterations,
                                          int numIterationsBetweenFeatureInductions,
                                          int numFeatureInductions,
                                          int numFeaturesPerFeatureInduction,
                                          double trueLabelProbThreshold,
                                          boolean clusteredFeatureInduction,
                                          double[] trainingProportions)
{
	return trainWithFeatureInduction (trainingData, validationData, testingData,
	                                  eval, numIterations, numIterationsBetweenFeatureInductions,
	                                  numFeatureInductions, numFeaturesPerFeatureInduction,
	                                  trueLabelProbThreshold, clusteredFeatureInduction,
	                                  trainingProportions, "exp");
}

// NOTE(review): the remainder of this method lies beyond the visible chunk;
// only its header and first statement are shown here.
public boolean trainWithFeatureInduction (InstanceList trainingData,
                                          InstanceList validationData,
                                          InstanceList testingData,
                                          TransducerEvaluator eval, int numIterations,
                                          int numIterationsBetweenFeatureInductions,
                                          int numFeatureInductions,
                                          int numFeaturesPerFeatureInduction,
                                          double trueLabelProbThreshold,
                                          boolean clusteredFeatureInduction,
                                          double[] trainingProportions,
                                          String gainName)
{
	int trainingIteration = 0;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -