📄 crf4.java
字号:
int numLabels = outputAlphabet.size(); while (historyIndexes[0] < numLabels) { logger.info("Preparing " + concatLabels(history)); if (allowedHistory(history, forbidden, allowed)) { String stateName = concatLabels(history); int nt = 0; String[] destNames = new String[numLabels]; String[] labelNames = new String[numLabels]; String[][] weightNames = new String[numLabels][orders.length]; for (int nextIndex = 0; nextIndex < numLabels; nextIndex++) { String next = (String)outputAlphabet.lookupObject(nextIndex); if (allowedTransition(history[order-1], next, forbidden, allowed) && (fullyConnected || connections[historyIndexes[order-1]][nextIndex])) { destNames[nt] = nextKGram(history, order, next); labelNames[nt] = next; for (int i = 0; i < orders.length; i++) { weightNames[nt][i] = nextKGram(history, orders[i]+1, next); if (defaults != null && defaults[i]) { int wi = getWeightsIndex (weightNames[nt][i]); // Using empty feature selection gives us only the // default features featureSelections[wi] = new FeatureSelection(trainingSet.getDataAlphabet()); } } nt++; } } if (nt < numLabels) { String[] newDestNames = new String[nt]; String[] newLabelNames = new String[nt]; String[][] newWeightNames = new String[nt][]; for (int t = 0; t < nt; t++) { newDestNames[t] = destNames[t]; newLabelNames[t] = labelNames[t]; newWeightNames[t] = weightNames[t]; } destNames = newDestNames; labelNames = newLabelNames; weightNames = newWeightNames; } for (int i = 0; i < destNames.length; i++) { StringBuffer b = new StringBuffer(); for (int j = 0; j < orders.length; j++) b.append(" ").append(weightNames[i][j]); logger.info(stateName + "->" + destNames[i] + "(" + labelNames[i] + ")" + b.toString()); } addState (stateName, 0.0, 0.0, destNames, labelNames, weightNames); } for (int o = order-1; o >= 0; o--) if (++historyIndexes[o] < numLabels) { history[o] = (String)outputAlphabet.lookupObject(historyIndexes[o]); break; } else if (o > 0) { historyIndexes[o] = 0; history[o] = label0; } } for (int i = 0; i < order; i++) history[i] = start; return concatLabels(history); } else { String[] stateNames = new String[outputAlphabet.size()]; for (int s = 0; s < outputAlphabet.size(); s++) stateNames[s] = (String)outputAlphabet.lookupObject(s); for (int s = 0; s < outputAlphabet.size(); s++) addState(stateNames[s], 0.0, 0.0, stateNames, stateNames, stateNames); return start; } } public State getState (String name) { return (State) name2state.get(name); } public void setWeights (int weightsIndex, SparseVector transitionWeights) { cachedValueStale = cachedGradientStale = true; if (weightsIndex >= weights.length || weightsIndex < 0) throw new IllegalArgumentException ("weightsIndex "+weightsIndex+" is out of bounds"); weights[weightsIndex] = transitionWeights; } public void setWeights (String weightName, SparseVector transitionWeights) { setWeights (getWeightsIndex (weightName), transitionWeights); } public String getWeightsName (int weightIndex) { return (String) weightAlphabet.lookupObject (weightIndex); } public SparseVector getWeights (String weightName) { return weights[getWeightsIndex (weightName)]; } public SparseVector getWeights (int weightIndex) { return weights[weightIndex]; } public double[] getDefaultWeights () { return defaultWeights; } // Methods added by Ryan McDonald // Purpose is for AGIS-Limited Memory Experiments // Allows one to train on AGIS for N iterations, and then // copy weights to begin training on Limited-Memory for the // rest. public SparseVector[] getWeights () { return weights; } public void setWeights (SparseVector[] m) { weights = m; } public void setDefaultWeights (double[] w) { defaultWeights = w; } public void setDefaultWeight (int widx, double val) { defaultWeights[widx] = val; } /** * Freezes a set of weights to their current values. * Frozen weights are used for labeling sequences (as in <tt>transduce</tt>), * but are not be modified by the <tt>train</tt> methods. * @param weightsIndex Index of weight set to freeze. */ public void freezeWeights (int weightsIndex) { weightsFrozen [weightsIndex] = true; } /** * Freezes a set of weights to their current values. * Frozen weights are used for labeling sequences (as in <tt>transduce</tt>), * but are not be modified by the <tt>train</tt> methods. * @param weightsName Name of weight set to freeze. */ public void freezeWeights (String weightsName) { int widx = getWeightsIndex (weightsName); freezeWeights (widx); } /** * Unfreezes a set of weights. * Frozen weights are used for labeling sequences (as in <tt>transduce</tt>), * but are not be modified by the <tt>train</tt> methods. * @param weightsName Name of weight set to unfreeze. */ public void unfreezeWeights (String weightsName) { int widx = getWeightsIndex (weightsName); weightsFrozen[widx] = false; } public void setFeatureSelection (int weightIdx, FeatureSelection fs) { featureSelections [weightIdx] = fs; } public void setWeightsDimensionAsIn (InstanceList trainingData) { int numWeights = 0; // The value doesn't actually change, because the "new" parameters will have zero value // but the gradient changes because the parameters now have different layout. cachedValueStale = cachedGradientStale = true; setTrainable (false); weightsPresent = new BitSet[weights.length]; for (int i = 0; i < weights.length; i++) weightsPresent[i] = new BitSet(); gatheringWeightsPresent = true; // Put in the weights that are already there for (int i = 0; i < weights.length; i++) for (int j = weights[i].numLocations()-1; j >= 0; j--) weightsPresent[i].set (weights[i].indexAtLocation(j)); // Put in the weights in the training set if (this.someTrainingDone) System.err.println("Some training done previously"); for (int i = 0; i < trainingData.size(); i++) { Instance instance = trainingData.getInstance(i); FeatureVectorSequence input = (FeatureVectorSequence) instance.getData(); FeatureSequence output = (FeatureSequence) instance.getTarget(); // Do it for the paths consistent with the labels... gatheringConstraints = true; // ****************************************************************************************** // CPAL - Beam Version could be used here forwardBackward (input, output, true); // forwardBackwardBeam (input, output, true); // ****************************************************************************************** // ...and also do it for the paths selected by the current model (so we will get some negative weights) gatheringConstraints = false; if (this.someTrainingDone && useSomeUnsupportedTrick) { logger.info ("CRF4: Incremental training detected. Adding weights for some supported features..."); // (do this once some training is done) // ****************************************************************************************** // CPAL - Beam Version could be used here forwardBackward (input, null, true); // ****************************************************************************************** //forwardBackwardBeam (input, output, true); } } gatheringWeightsPresent = false; SparseVector[] newWeights = new SparseVector[weights.length]; for (int i = 0; i < weights.length; i++) { int numLocations = weightsPresent[i].cardinality (); logger.info ("CRF weights["+weightAlphabet.lookupObject(i)+"] num features = "+numLocations); int[] indices = new int[numLocations]; for (int j = 0; j < numLocations; j++) { indices[j] = weightsPresent[i].nextSetBit (j == 0 ? 0 : indices[j-1]+1); //System.out.println ("CRF4 has index "+indices[j]); } newWeights[i] = new IndexedSparseVector (indices, new double[numLocations], numLocations, numLocations, false, false, false); newWeights[i].plusEqualsSparse (weights[i]); numWeights += (numLocations + 1); } logger.info("Number of weights = "+numWeights); weights = newWeights; } public void setWeightsDimensionDensely () { SparseVector[] newWeights = new SparseVector [weights.length]; int max = inputAlphabet.size(); int numWeights = 0; logger.info ("CRF using dense weights, num input features = "+max); for (int i = 0; i < weights.length; i++) { int nfeatures; if (featureSelections[i] == null) { nfeatures = max; newWeights [i] = new SparseVector (null, new double [max], max, max, false, false, false); } else { // Respect the featureSelection FeatureSelection fs = featureSelections[i]; nfeatures = fs.getBitSet ().cardinality (); int[] idxs = new int [nfeatures]; int j = 0, thisIdx = -1; while ((thisIdx = fs.nextSelectedIndex (thisIdx + 1)) >= 0) { idxs[j++] = thisIdx; } newWeights[i] = new SparseVector (idxs, new double [nfeatures], nfeatures, nfeatures, false, false, false); } newWeights [i].plusEqualsSparse (weights [i]); numWeights += (nfeatures + 1); } logger.info("Number of weights = "+numWeights); weights = newWeights; } /** Increase the size of the weights[] parameters to match (a new, larger) input Alphabet size */ // No longer needed /* public void growWeightsDimensionToInputAlphabet () { int vs = inputAlphabet.size(); if (vs == this.defaultFeatureIndex) // Doesn't need to grow return; assert (vs > this.defaultFeatureIndex); setTrainable (false); for (int i = 0; i < weights.length; i++) { DenseVector newWeights = new DenseVector (vs+1); newWeights.arrayCopyFrom (0, weights[i]); newWeights.setValue (vs, weights[i].value (defaultFeatureIndex)); newWeights.setValue (defaultFeatureIndex, 0); weights[i] = newWeights; } this.defaultFeatureIndex = vs; cachedValueStale = true; cachedGradientStale = true; } */ // Create a new weight Vector if weightName is new. public int getWeightsIndex (String weightName) { int wi = weightAlphabet.lookupIndex (weightName); if (wi == -1) throw new IllegalArgumentException ("Alphabet frozen, and no weight with name "+ weightName); if (weights == null) { assert (wi == 0); weights = new SparseVector[1]; defaultWeights = new double[1]; featureSelections = new FeatureSelection[1]; weightsFrozen = new boolean [1]; // Use initial capacity of 8 weights[0] = new IndexedSparseVector (); defaultWeights[0] = 0; featureSelections[0] = null; } else if (wi == weights.length) { SparseVector[] newWeights = new SparseVector[weights.length+1]; double[] newDefaultWeights = new double[weights.length+1]; FeatureSelection[] newFeatureSelections = new FeatureSelection[weights.length+1]; for (int i = 0; i < weights.length; i++) { newWeights[i] = weights[i]; newDefaultWeights[i] = defaultWeights[i]; newFeatureSelections[i] = featureSelections[i]; } newWeights[wi] = new IndexedSparseVector (); newDefaultWeights[wi] = 0; newFeatureSelections[wi] = null; weights = newWeights; defaultWeights = newDefaultWeights; featureSelections = newFeatureSelections; weightsFrozen = ArrayUtils.append (weightsFrozen, false); } setTrainable (false); return wi; } private void assertWeightsLength () { if (weights != null) {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -