📄 crf.java
字号:
destinationNames[k] = labels[j]+','+labels[k]; addState (labels[i]+','+labels[j], 0.0, 0.0, destinationNames, labels); } } }

/**
 * Adds states for a second-order Markov model on labels, adding only those
 * transitions that occur in the given training set.
 *
 * For every label pair (i, j) flagged by <code>labelConnectionsIn</code>, one state
 * named "label_i,label_j" is added. Its destinations are the states
 * "label_j,label_k" for every k that is flagged as following j, and the label
 * emitted on each such transition is label_k.
 *
 * Assumes the entries of the output alphabet are Strings (they are cast).
 *
 * @param trainingSet the instances whose label-to-label connections decide
 *                    which states and transitions are created
 */
public void addStatesForBiLabelsConnectedAsIn (InstanceList trainingSet)
{
  int numLabels = outputAlphabet.size();
  boolean[][] connections = labelConnectionsIn (trainingSet);
  for (int i = 0; i < numLabels; i++) {
    for (int j = 0; j < numLabels; j++) {
      if (!connections[i][j])
        continue;
      // First pass: count the successors of j so the arrays can be sized exactly.
      int numDestinations = 0;
      for (int k = 0; k < numLabels; k++)
        if (connections[j][k])
          numDestinations++;
      String[] destinationNames = new String[numDestinations];
      String[] labels = new String[numDestinations];
      // Second pass: fill in the destination state names and transition labels.
      int destinationIndex = 0;
      for (int k = 0; k < numLabels; k++)
        if (connections[j][k]) {
          destinationNames[destinationIndex] = (String)outputAlphabet.lookupObject(j)+','+(String)outputAlphabet.lookupObject(k);
          labels[destinationIndex] = (String)outputAlphabet.lookupObject(k);
          destinationIndex++;
        }
      addState ((String)outputAlphabet.lookupObject(i)+','+
                (String)outputAlphabet.lookupObject(j), 0.0, 0.0,
                destinationNames, labels);
    }
  }
}

/**
 * Adds one state for every ordered triple of labels, named
 * "label_i,label_j,label_k". Each state has a transition to
 * "label_j,label_k,label_l" for every label l, so the number of states is the
 * cube of the output alphabet size.
 *
 * Assumes the entries of the output alphabet are Strings (they are cast).
 */
public void addFullyConnectedStatesForTriLabels ()
{
  String[] labels = new String[outputAlphabet.size()];
  // This is assuming the entries in the outputAlphabet are Strings!
  for (int i = 0; i < outputAlphabet.size(); i++) {
    logger.info ("CRF: outputAlphabet.lookup class = "+
                 outputAlphabet.lookupObject(i).getClass().getName());
    labels[i] = (String) outputAlphabet.lookupObject(i);
  }
  for (int i = 0; i < labels.length; i++) {
    for (int j = 0; j < labels.length; j++) {
      for (int k = 0; k < labels.length; k++) {
        String[] destinationNames = new String[labels.length];
        for (int l = 0; l < labels.length; l++)
          destinationNames[l] = labels[j]+','+labels[k]+','+labels[l];
        addState (labels[i]+','+labels[j]+','+labels[k], 0.0, 0.0,
                  destinationNames, labels);
      }
    }
  }
}

/**
 * Adds a single state that transitions only to itself, with one self-transition
 * per label in the output alphabet.
 *
 * Assumes the entries of the output alphabet are Strings (they are cast).
 *
 * @param name the name of the state to add; also used as every destination name
 */
public void addSelfTransitioningStateForAllLabels (String name)
{
  String[] labels = new String[outputAlphabet.size()];
  String[] destinationNames = new String[outputAlphabet.size()];
  // This is assuming the entries in the outputAlphabet are Strings!
  for (int i = 0; i < outputAlphabet.size(); i++) {
    logger.info ("CRF: outputAlphabet.lookup class = "+
                 outputAlphabet.lookupObject(i).getClass().getName());
    labels[i] = (String) outputAlphabet.lookupObject(i);
    destinationNames[i] = name;
  }
  addState (name, 0.0, 0.0, destinationNames, labels);
}

/**
 * Replaces the weight vector stored at the given index.
 *
 * @param weightsIndex      position in the weights array to overwrite
 * @param transitionWeights the new vector; its size must be defaultFeatureIndex+1
 * @throws IllegalArgumentException if the vector has the wrong size or the
 *                                  index is outside the weights array
 */
public void setWeights (int weightsIndex, DenseVector transitionWeights)
{
  if (transitionWeights.singleSize() != (defaultFeatureIndex+1))
    throw new IllegalArgumentException ("Vector transitionWeights has incorrect size = " + transitionWeights.singleSize());
  if (weightsIndex >= weights.length || weightsIndex < 0)
    throw new IllegalArgumentException ("weightsIndex "+weightsIndex+" is out of bounds");
  weights[weightsIndex] = transitionWeights;
}

/**
 * Replaces the weight vector registered under the given name. The name is
 * resolved with {@link #getWeightsIndex(String)}, which creates a new entry when
 * the name is new.
 */
public void setWeights (String weightName, DenseVector transitionWeights)
{
  setWeights (getWeightsIndex (weightName), transitionWeights);
}

/** Returns the name stored in the weight alphabet at the given index. */
public String getWeightsName (int weightIndex)
{
  return (String) weightAlphabet.lookupObject (weightIndex);
}

/**
 * Returns the weight vector registered under the given name. The name is
 * resolved with {@link #getWeightsIndex(String)}, which creates a new entry when
 * the name is new.
 */
public DenseVector getWeights (String weightName)
{
  return weights[getWeightsIndex (weightName)];
}

/** Returns the weight vector stored at the given index (no bounds check beyond the array's own). */
public DenseVector getWeights (int weightIndex)
{
  return weights[weightIndex];
}

/**
 * Increases the size of the weights[] parameters to match a (new, larger) input
 * alphabet size.
 *
 * Each weight vector is replaced by one of size inputAlphabet.size()+1. The old
 * contents are copied in, the value that sat in the old default-feature slot is
 * moved to the new last slot, and the old slot is reset to 0. The CRF is made
 * untrainable and the cached cost and gradient are marked stale.
 */
public void growWeightsDimensionToInputAlphabet ()
{
  int vs = inputAlphabet.size();
  if (vs == this.defaultFeatureIndex)
    // Doesn't need to grow
    return;
  assert (vs > this.defaultFeatureIndex);
  setTrainable (false);
  // weights stays null until the first getWeightsIndex() call; in that case
  // there is nothing to resize and only the default-feature index moves.
  if (weights != null) {
    for (int i = 0; i < weights.length; i++) {
      DenseVector newWeights = new DenseVector (vs+1);
      newWeights.arrayCopyFrom (0, weights[i]);
      // Order matters: carry the default-feature weight to its new slot
      // before clearing the slot it used to occupy.
      newWeights.setValue (vs, weights[i].value (defaultFeatureIndex));
      newWeights.setValue (defaultFeatureIndex, 0);
      weights[i] = newWeights;
    }
  }
  this.defaultFeatureIndex = vs;
  cachedCostStale = true;
  cachedGradientStale = true;
}

/**
 * Returns the index of the named weight vector, creating a new weight vector if
 * weightName is new.
 *
 * Creating a vector makes the CRF untrainable (setTrainable(false)).
 *
 * @param weightName name to look up in the weight alphabet
 * @return the index of the weight vector for that name
 * @throws IllegalArgumentException if the weight alphabet lookup returns -1
 */
public int getWeightsIndex (String weightName)
{
  int wi = weightAlphabet.lookupIndex (weightName);
  if (wi == -1)
    throw new IllegalArgumentException ("Alphabet frozen, and no weight with name "+ weightName);
  if (weights == null) {
    assert (wi == 0);
    weights = new DenseVector[1];
    weights[0] = new DenseVector ((defaultFeatureIndex+1));
    setTrainable (false);
  } else if (wi == weights.length) {
    // The alphabet just handed out the next index: extend the array by one.
    DenseVector[] newWeights = new DenseVector[weights.length+1];
    System.arraycopy (weights, 0, newWeights, 0, weights.length);
    newWeights[wi] = new DenseVector ((defaultFeatureIndex+1));
    weights = newWeights;
    setTrainable (false);
  }
  return wi;
}

/** Returns the number of states. */
public int numStates ()
{
  return states.size();
}

/** Returns the state stored at the given index. */
public Transducer.State getState (int index)
{
  return (Transducer.State) states.get(index);
}

/** Returns an iterator over the initial states. */
public Iterator initialStateIterator ()
{
  return initialStates.iterator ();
}

/** Returns whether this CRF is currently flagged as trainable. */
public boolean isTrainable ()
{
  return trainable;
}

/**
 * Switches the CRF into or out of trainable mode; does nothing when the flag
 * already has the requested value.
 *
 * Turning training on allocates one constraints vector and one expectations
 * vector per weight vector, each of the same size as that weight vector.
 * Turning it off drops both arrays. In both cases the flag is forwarded to every
 * state.
 *
 * @param f true to make the CRF trainable, false to release training storage
 */
public void setTrainable (boolean f)
{
  if (f != trainable) {
    if (f) {
      // weights stays null until the first getWeightsIndex() call; treat
      // that as "no weight vectors" instead of failing on weights.length.
      int numWeights = (weights == null) ? 0 : weights.length;
      constraints = new DenseVector[numWeights];
      expectations = new DenseVector[numWeights];
      for (int i = 0; i < numWeights; i++) {
        constraints[i] = new DenseVector (weights[i].singleSize());
        expectations[i] = new DenseVector (weights[i].singleSize());
      }
    } else
      constraints = expectations = null;
    for (int i = 0; i < numStates(); i++)
      ((State)getState(i)).setTrainable(f);
    trainable = f;
  }
}

/**
 * Sets one weight of the transition between two states and marks the cached cost
 * and gradient stale.
 *
 * @param sourceStateIndex index of the state the transition leaves
 * @param destStateIndex   index of the state the transition enters
 * @param featureIndex     index into the transition's weight vector
 * @param value            the new weight
 * @throws IllegalArgumentException if the source state has no transition to the
 *                                  destination state
 */
public void setParameter (int sourceStateIndex, int destStateIndex, int featureIndex, double value)
{
  // The caches are invalidated before the lookup, so they are marked stale even
  // when the lookup throws.
  cachedCostStale = cachedGradientStale = true;
  DenseVector v = weightsForTransition (sourceStateIndex, destStateIndex);
  v.setValue (featureIndex, value);
}

/**
 * Returns one weight of the transition between two states.
 *
 * @param sourceStateIndex index of the state the transition leaves
 * @param destStateIndex   index of the state the transition enters
 * @param featureIndex     index into the transition's weight vector
 * @param value            unused; kept so the signature stays unchanged
 * @return the weight stored for that feature on that transition
 * @throws IllegalArgumentException if the source state has no transition to the
 *                                  destination state
 */
public double getParameter (int sourceStateIndex, int destStateIndex, int featureIndex, double value)
{
  DenseVector v = weightsForTransition (sourceStateIndex, destStateIndex);
  return v.value (featureIndex);
}

/**
 * Finds the weight vector of the transition from one state to another by
 * matching the destination state's name against the source's destination names.
 */
private DenseVector weightsForTransition (int sourceStateIndex, int destStateIndex)
{
  State source = (State) getState (sourceStateIndex);
  State dest = (State) getState (destStateIndex);
  for (int rowIndex = 0; rowIndex < source.destinationNames.length; rowIndex++)
    if (source.destinationNames[rowIndex].equals (dest.name))
      return weights[source.weightsIndices[rowIndex]];
  throw new IllegalArgumentException ("No transition from state "+sourceStateIndex+" to state "+destStateIndex+".");
}

/** Not supported by CRFs; always throws UnsupportedOperationException. */
public void reset ()
{
  throw new UnsupportedOperationException ("Not used in CRFs");
}

/**
 * Not implemented.
 *
 * @throws IllegalStateException         if the CRF is not currently trainable
 * @throws UnsupportedOperationException otherwise
 */
public void estimate ()
{
  if (!trainable)
    throw new IllegalStateException ("This transducer not currently trainable.");
  // xxx Put stuff in here.
  throw new UnsupportedOperationException ("Not yet implemented. Never?");
}

/**
 * Prints a description of every state to System.out: its name, its initial and
 * final cost, each outgoing transition, and every non-zero weight on that
 * transition as "source -> destination: feature = value". The default-feature
 * slot is shown as &lt;DEFAULT_FEATURE&gt;.
 */
public void print ()
{
  StringBuffer sb = new StringBuffer();
  for (int i = 0; i < numStates(); i++) {
    State s = (State) getState (i);
    sb.append (s.name);
    sb.append (" (");
    sb.append (s.destinations.length);
    sb.append (" outgoing transitions)\n");
    sb.append (" ");
    sb.append ("initialCost = ");
    sb.append (s.initialCost);
    sb.append ('\n');
    sb.append (" ");
    sb.append ("finalCost = ");
    sb.append (s.finalCost);
    sb.append ('\n');
    for (int j = 0; j < s.destinations.length; j++) {
      sb.append (" -> ");
      sb.append (s.destinations[j].name);
      sb.append ('\n');
      DenseVector transitionWeights = weights[s.weightsIndices[j]];
      // Weights are read back by rank rather than by feature index.
      RankedFeatureVector rfv = new RankedFeatureVector (inputAlphabet, transitionWeights);
      for (int k = 0; k < transitionWeights.singleSize(); k++) {
        double v = rfv.getValueAtRank(k);
        int index = rfv.getIndexAtRank(k);
        Object feature = index == defaultFeatureIndex
          ? "<DEFAULT_FEATURE>"
          : inputAlphabet.lookupObject (index);
        if (v != 0) {
          sb.append (" ");
          sb.append (s.name);
          sb.append (" -> ");
          sb.append (s.destinations[j].name);
          sb.append (": ");
          sb.append (feature);
          sb.append (" = ");
          sb.append (v);
          sb.append ('\n');
        }
      }
    }
  }
  System.out.println (sb.toString());
}

// Java question:
// If I make a non-static inner class CRF.Trainer,
// can that class be subclassed in another .java file,
// and can that subclass still have access to all the CRF's
// instance variables?
/**
 * Trains on the given instances with no validation set, testing set or
 * evaluator.
 *
 * @return true if training converged
 */
public boolean train (InstanceList ilist)
{
  return train (ilist, (InstanceList)null, (InstanceList)null);
}

/**
 * Trains on the given instances with no evaluator.
 *
 * @return true if training converged
 */
public boolean train (InstanceList ilist, InstanceList validation, InstanceList testing)
{
  return train (ilist, validation, testing, (TransducerEvaluator)null);
}

/**
 * Trains on the given instances for at most 9999 iterations.
 *
 * @return true if training converged
 */
public boolean train (InstanceList ilist, InstanceList validation, InstanceList testing,
                      TransducerEvaluator eval)
{
  return train (ilist, validation, testing, eval, 9999);
}

/**
 * Trains by running the minimizer one step at a time, for at most numIterations
 * steps.
 *
 * After each step the evaluator, if any, is called and may stop training by
 * returning false. An IllegalArgumentException thrown by the minimizer is
 * printed and treated as convergence. When the loop ends the CRF is made
 * untrainable again, which frees the expectations and constraints.
 *
 * @param ilist         training instances; asserted to be non-empty
 * @param validation    passed through to the evaluator; may be null
 * @param testing       passed through to the evaluator; may be null
 * @param eval          called after every step; may be null
 * @param numIterations maximum number of minimizer steps
 * @return true if the minimizer reported convergence (or threw
 *         IllegalArgumentException); false otherwise, including when
 *         numIterations is not positive
 */
public boolean train (InstanceList ilist, InstanceList validation, InstanceList testing,
                      TransducerEvaluator eval, int numIterations)
{
  if (numIterations <= 0)
    return false;
  assert (ilist.size() > 0);
  MinimizableCRF mc = new MinimizableCRF (ilist, this);
  //Minimizer.ByGradient minimizer = new ConjugateGradient (0.001);
  Minimizer.ByGradient minimizer = new LimitedMemoryBFGS();
  boolean converged = false;
  for (int i = 0; i < numIterations; i++) {
    try {
      converged = minimizer.minimize (mc, 1);
    } catch (IllegalArgumentException e) {
      // Deliberate best effort: stop training instead of propagating.
      e.printStackTrace();
      logger.info ("Catching exception; saying converged.");
      converged = true;
    }
    if (eval != null) {
      // The second argument is true on the last evaluation of this run.
      boolean continueTraining = eval.evaluate (this, (converged || i == numIterations-1), i,
                                                converged, mc.getCost(),
                                                ilist, validation, testing);
      if (!continueTraining)
        break;
    }
    if (converged)
      break;
  }
  logger.info ("About to setTrainable(false)");
  // Free the memory of the expectations and constraints
  setTrainable (false);
  logger.info ("Done setTrainable(false)");
  return converged;
}

/**
 * Trains in rounds on growing samples of the training data, then on all of it.
 *
 * For each entry p of trainingProportions, the training list is split with
 * proportions {p, 1-p} using a Random seeded with 1, and the first part is
 * trained on for numIterationsPerProportion iterations. A final round then
 * trains on the full list for the remaining numIterations minus the iterations
 * already budgeted. If that remainder is not positive the final round does no
 * training and this method returns false.
 *
 * @param training                   the full training list
 * @param validation                 passed through to the evaluator; may be null
 * @param testing                    passed through to the evaluator; may be null
 * @param eval                       called after every step; may be null
 * @param numIterations              total iteration budget across all rounds
 * @param numIterationsPerProportion iteration budget of each sampled round
 * @param trainingProportions        sample sizes of the preliminary rounds; null
 *                                   or empty means only the full round is run
 * @return the result of the final round on the full training list
 */
public boolean train (InstanceList training, InstanceList validation, InstanceList testing,
                      TransducerEvaluator eval, int numIterations,
                      int numIterationsPerProportion, double[] trainingProportions)
{
  int trainingIteration = 0;
  // The null check has to guard the loop itself: reading .length in the loop
  // header would throw before a check inside the body could run.
  if (trainingProportions != null) {
    for (int i = 0; i < trainingProportions.length; i++) {
      // NOTE(review): the value is used below as a fraction (p, 1-p), so the
      // "%" in this message looks misleading -- confirm before changing it.
      logger.info ("Training on "+trainingProportions[i]+"% of the data this round.");
      InstanceList[] sampledTrainingData = training.split (new Random(1),
        new double[] {trainingProportions[i], 1-trainingProportions[i]});
      // Whether a sampled round converged is not used; only the final round's
      // result is returned.
      this.train (sampledTrainingData[0], validation, testing, eval, numIterationsPerProportion);
      trainingIteration += numIterationsPerProportion;
    }
  }
  logger.info ("Training on 100% of the data this round, for "+
               (numIterations-trainingIteration)+" iterations.");
  return this.train (training, validation, testing, eval, numIterations - trainingIteration);
}

public boolean trainWithFeatureInduction (InstanceList trainingData, InstanceList validationData, InstanceList testingData,
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -