📄 crfbygisupdate.java
字号:
// Remainder of a definition that starts before this excerpt; left untouched.
for (int k = 0; k < labels.length; k++)
  destinationNames[k] = labels[j]+','+labels[k];
addState (labels[i]+','+labels[j], 0.0, 0.0, destinationNames, labels);
}
}
}

/**
 * Add states to create a second-order Markov model on labels, adding only
 * those transitions that occur in the given trainingSet.
 * <p>
 * For every label pair (i, j) for which
 * <code>labelConnectionsIn(trainingSet)[i][j]</code> is true, a state named
 * "label_i,label_j" is added.  Its destinations are the states
 * "label_j,label_k" for every k with <code>connections[j][k]</code> true,
 * and the label attached to each such transition is label_k.
 *
 * @param trainingSet the instances handed to <code>labelConnectionsIn</code>
 */
public void addStatesForBiLabelsConnectedAsIn (InstanceList trainingSet)
{
  final int labelCount = outputAlphabet.size();
  final boolean[][] connected = labelConnectionsIn (trainingSet);
  for (int first = 0; first < labelCount; first++) {
    for (int second = 0; second < labelCount; second++) {
      if (!connected[first][second])
        continue;
      // First pass: count the labels reachable from 'second' so that the
      // arrays handed to addState are sized exactly.
      int destinationCount = 0;
      for (int third = 0; third < labelCount; third++) {
        if (connected[second][third])
          destinationCount++;
      }
      String[] destinationNames = new String[destinationCount];
      String[] destinationLabels = new String[destinationCount];
      // This is assuming that the entries in the outputAlphabet are Strings!
      String secondLabel = (String) outputAlphabet.lookupObject (second);
      // Second pass: fill in the destination state names and their labels.
      int slot = 0;
      for (int third = 0; third < labelCount; third++) {
        if (!connected[second][third])
          continue;
        String thirdLabel = (String) outputAlphabet.lookupObject (third);
        destinationNames[slot] = secondLabel+','+thirdLabel;
        destinationLabels[slot] = thirdLabel;
        slot++;
      }
      String firstLabel = (String) outputAlphabet.lookupObject (first);
      addState (firstLabel+','+secondLabel, 0.0, 0.0, destinationNames, destinationLabels);
    }
  }
}

/**
 * Adds one state for every ordered triple of labels in the outputAlphabet.
 * The state "a,b,c" gets one destination per label d, namely the state
 * "b,c,d", and the label attached to that transition is d.
 */
public void addFullyConnectedStatesForTriLabels ()
{
  String[] labels = new String[outputAlphabet.size()];
  // This is assuming that the entries in the outputAlphabet are Strings!
  for (int i = 0; i < labels.length; i++) {
    logger.info ("CRF: outputAlphabet.lookup class = "+
                 outputAlphabet.lookupObject(i).getClass().getName());
    labels[i] = (String) outputAlphabet.lookupObject(i);
  }
  for (int i = 0; i < labels.length; i++) {
    for (int j = 0; j < labels.length; j++) {
      for (int k = 0; k < labels.length; k++) {
        // A fresh array per state: it is handed over to addState.
        String[] destinationNames = new String[labels.length];
        for (int l = 0; l < labels.length; l++)
          destinationNames[l] = labels[j]+','+labels[k]+','+labels[l];
        addState (labels[i]+','+labels[j]+','+labels[k], 0.0, 0.0,
                  destinationNames, labels);
      }
    }
  }
}

/**
 * Adds a single state that has one transition back to itself for every
 * label in the outputAlphabet.
 *
 * @param name the name of the new state; also used as the name of every
 *             destination
 */
public void addSelfTransitioningStateForAllLabels (String name)
{
  String[] labels = new String[outputAlphabet.size()];
  String[] destinationNames = new String[labels.length];
  // This is assuming that the entries in the outputAlphabet are Strings!
  for (int i = 0; i < labels.length; i++) {
    // Goes through the class logger, like addFullyConnectedStatesForTriLabels,
    // instead of writing straight to System.out.
    logger.info ("CRF: outputAlphabet.lookup class = "+
                 outputAlphabet.lookupObject(i).getClass().getName());
    labels[i] = (String) outputAlphabet.lookupObject(i);
    destinationNames[i] = name;
  }
  addState (name, 0.0, 0.0, destinationNames, labels);
}

/**
 * Looks a state up by name in name2state.
 *
 * @param name the state name
 * @return whatever name2state holds for that name, cast to State
 */
public State getState (String name)
{
  return (State) name2state.get(name);
}

/**
 * Replaces the weight vector stored at the given index and marks the
 * cached cost and cached GIS update as stale.
 *
 * @param weightsIndex      position in the weights array
 * @param transitionWeights the vector to store there
 * @throws IllegalArgumentException if weightsIndex is negative or not less
 *         than weights.length; nothing is modified in that case
 */
public void setWeights (int weightsIndex, SparseVector transitionWeights)
{
  // Validate before touching any state, so a rejected call has no side effects.
  if (weightsIndex < 0 || weightsIndex >= weights.length)
    throw new IllegalArgumentException ("weightsIndex "+weightsIndex+" is out of bounds");
  cachedCostStale = cachedGISUpdateStale = true;
  weights[weightsIndex] = transitionWeights;
}

/**
 * Replaces the weight vector registered under the given name.  The index is
 * obtained from getWeightsIndex(weightName).
 *
 * @param weightName        name of the weight set
 * @param transitionWeights the vector to store
 */
public void setWeights (String weightName, SparseVector transitionWeights)
{
  setWeights (getWeightsIndex (weightName), transitionWeights);
}

/**
 * @param weightIndex index into the weightAlphabet
 * @return the weightAlphabet entry at that index, cast to String
 */
public String getWeightsName (int weightIndex)
{
  return (String) weightAlphabet.lookupObject (weightIndex);
}

/**
 * @param weightName name of the weight set; resolved with getWeightsIndex
 * @return the weight vector stored at the resolved index
 */
public SparseVector getWeights (String weightName)
{
  return weights[getWeightsIndex (weightName)];
}

/**
 * @param weightIndex position in the weights array (not range-checked here)
 * @return the weight vector at that position
 */
public SparseVector getWeights (int weightIndex)
{
  return weights[weightIndex];
}

// Methods added by Ryan McDonald
Purpose is for AGIS-Limited Memory Experiments // Allows one to train on AGIS for N iterations, and then // copy weights to begin training on Limited-Memory for the // rest. public SparseVector[] getWeights () { return weights; } public void setWeights (SparseVector[] m) { weights = m; } public void setWeightsDimensionAsIn (InstanceList trainingData) { // The cost doesn't actually change, because the "new" parameters will have zero value // but the gradient changes because the parameters now have different layout. cachedCostStale = cachedGISUpdateStale = true; setTrainable (false); weightsPresent = new BitSet[weights.length]; for (int i = 0; i < weights.length; i++) weightsPresent[i] = new BitSet(); gatheringWeightsPresent = true; // Put in the weights that are already there for (int i = 0; i < weights.length; i++) for (int j = weights[i].numLocations()-1; j >= 0; j--) weightsPresent[i].set (weights[i].indexAtLocation(j)); // Put in the weights in the training set if (this.someTrainingDone) System.err.println("Some training done previously"); for (int i = 0; i < trainingData.size(); i++) { Instance instance = trainingData.getInstance(i); FeatureVectorSequence input = (FeatureVectorSequence) instance.getData(); FeatureSequence output = (FeatureSequence) instance.getTarget(); // Do it for the paths consistent with the labels... 
gatheringConstraints = true; forwardBackward (input, output, true); // ...and also do it for the paths selected by the current model (so we will get some negative weights) gatheringConstraints = false; if (this.someTrainingDone) // (do this once some training is done) forwardBackward (input, null, true); } gatheringWeightsPresent = false; SparseVector[] newWeights = new SparseVector[weights.length]; for (int i = 0; i < weights.length; i++) { int numLocations = weightsPresent[i].cardinality (); System.out.println ("CRF weights["+weightAlphabet.lookupObject(i)+"] num features = "+numLocations); int[] indices = new int[numLocations]; for (int j = 0; j < numLocations; j++) { indices[j] = weightsPresent[i].nextSetBit (j == 0 ? 0 : indices[j-1]+1); //System.out.println ("CRFByGISUpdate has index "+indices[j]); } newWeights[i] = new SparseVector (indices, new double[numLocations], numLocations, numLocations, false, false, false); newWeights[i].plusEqualsSparse (weights[i]); } weights = newWeights; } /** Increase the size of the weights[] parameters to match (a new, larger) input Alphabet size */ // No longer needed /* public void growWeightsDimensionToInputAlphabet () { int vs = inputAlphabet.size(); if (vs == this.defaultFeatureIndex) // Doesn't need to grow return; assert (vs > this.defaultFeatureIndex); setTrainable (false); for (int i = 0; i < weights.length; i++) { DenseVector newWeights = new DenseVector (vs+1); newWeights.arrayCopyFrom (0, weights[i]); newWeights.setValue (vs, weights[i].value (defaultFeatureIndex)); newWeights.setValue (defaultFeatureIndex, 0); weights[i] = newWeights; } this.defaultFeatureIndex = vs; cachedCostStale = true; cachedGISUpdateStale = true; } */ // Create a new weight Vector if weightName is new. 
public int getWeightsIndex (String weightName)
{
  final int index = weightAlphabet.lookupIndex (weightName);
  if (index == -1)
    throw new IllegalArgumentException ("Alphabet frozen, and no weight with name "+ weightName);
  if (weights == null) {
    // Very first weight set: create the three parallel arrays.
    assert (index == 0);
    weights = new SparseVector[1];
    defaultWeights = new double[1];             // element starts out as 0
    featureSelections = new FeatureSelection[1]; // element starts out as null
    // Use initial capacity of 8
    weights[0] = new SparseVector ();
  } else if (index == weights.length) {
    // A new name was just added to the alphabet: grow the three parallel
    // arrays by one slot, keeping the existing entries.
    final int oldCount = weights.length;
    SparseVector[] grownWeights = new SparseVector[oldCount + 1];
    double[] grownDefaults = new double[oldCount + 1];
    FeatureSelection[] grownSelections = new FeatureSelection[oldCount + 1];
    System.arraycopy (weights, 0, grownWeights, 0, oldCount);
    System.arraycopy (defaultWeights, 0, grownDefaults, 0, oldCount);
    System.arraycopy (featureSelections, 0, grownSelections, 0, oldCount);
    // The new default weight and feature selection keep their
    // array-default values of 0 and null.
    grownWeights[index] = new SparseVector ();
    weights = grownWeights;
    defaultWeights = grownDefaults;
    featureSelections = grownSelections;
  }
  setTrainable (false);
  return index;
}

/** @return the number of entries in the states list */
public int numStates ()
{
  return states.size();
}

/**
 * @param index position in the states list
 * @return the entry at that position, cast to Transducer.State
 */
public Transducer.State getState (int index)
{
  return (Transducer.State) states.get(index);
}

/** @return an iterator over initialStates */
public Iterator initialStateIterator ()
{
  return initialStates.iterator ();
}

/** @return the current value of the trainable flag */
public boolean isTrainable ()
{
  return trainable;
}

/**
 * Switches the trainable flag.  Nothing happens when the flag already has
 * the requested value.  Turning it on allocates zeroed constraint and
 * expectation vectors shaped like the current weights; turning it off drops
 * them.  Either way the new value is pushed to every state.
 *
 * @param f the requested value of the trainable flag
 */
public void setTrainable (boolean f)
{
  if (f == trainable)
    return;
  if (f) {
    final int weightSetCount = weights.length;
    constraints = new SparseVector[weightSetCount];
    expectations = new SparseVector[weightSetCount];
    defaultConstraints = new double[weightSetCount];
    defaultExpectations = new double[weightSetCount];
    for (int w = 0; w < weightSetCount; w++) {
      constraints[w] = (SparseVector) weights[w].cloneMatrixZeroed ();
      expectations[w] = (SparseVector) weights[w].cloneMatrixZeroed ();
    }
  } else {
    constraints = null;
    expectations = null;
    defaultConstraints = null;
    defaultExpectations = null;
  }
  for (int s = 0; s < numStates(); s++) {
    State state = (State) getState (s);
    state.setTrainable (f);
  }
  trainable = f;
}

/**
 * @return the sum of the absolute initial and final costs of every state,
 *         plus the absolute default weight and the absNorm() of every
 *         weight vector
 */
public double getParametersAbsNorm ()
{
  double norm = 0;
  final int stateCount = numStates();
  for (int s = 0; s < stateCount; s++) {
    State state = (State) getState (s);
    norm += Math.abs (state.initialCost);
    norm += Math.abs (state.finalCost);
  }
  for (int w = 0; w < weights.length; w++) {
    norm += Math.abs (defaultWeights[w]);
    norm += weights[w].absNorm();
  }
  return norm;
}

/**
 * Only sets the parameter from the first group of parameters.
 * <p>
 * The transition is found by comparing the destination names of the source
 * state with the name of the destination state.  A negative featureIndex
 * addresses the default weight of that transition's first weight set.
 * Marks the caches stale and sets someTrainingDone.
 *
 * @throws IllegalArgumentException if the source state has no transition to
 *         the destination state
 */
public void setParameter (int sourceStateIndex, int destStateIndex, int featureIndex, double value)
{
  cachedCostStale = cachedGISUpdateStale = true;
  State source = (State) getState (sourceStateIndex);
  State dest = (State) getState (destStateIndex);
  final int transitionCount = source.destinationNames.length;
  int row = 0;
  while (row < transitionCount && !source.destinationNames[row].equals (dest.name))
    row++;
  if (row == transitionCount)
    throw new IllegalArgumentException ("No transtition from state "+sourceStateIndex+" to state "+destStateIndex+".");
  int weightsIndex = source.weightsIndices[row][0];
  if (featureIndex >= 0)
    weights[weightsIndex].setValue (featureIndex, value);
  else
    defaultWeights[weightsIndex] = value;
  someTrainingDone = true;
}

/** Only gets the parameter from the first group of parameters.
*/
// The transition is located by matching the source state's destination
// names against the destination state's name.  A negative featureIndex
// reads the default weight of the transition's first weight set.  The
// 'value' argument is never read.
public double getParameter (int sourceStateIndex, int destStateIndex, int featureIndex, double value)
{
  State source = (State) getState (sourceStateIndex);
  State dest = (State) getState (destStateIndex);
  final int transitionCount = source.destinationNames.length;
  int row = 0;
  while (row < transitionCount && !source.destinationNames[row].equals (dest.name))
    row++;
  if (row == transitionCount)
    throw new IllegalArgumentException ("No transtition from state "+sourceStateIndex+" to state "+destStateIndex+".");
  int weightsIndex = source.weightsIndices[row][0];
  return (featureIndex < 0)
    ? defaultWeights[weightsIndex]
    : weights[weightsIndex].value (featureIndex);
}

/** Always throws UnsupportedOperationException: not used in CRFs. */
public void reset ()
{
  throw new UnsupportedOperationException ("Not used in CRFs");
}

/**
 * Not implemented.  Throws IllegalStateException when the transducer is not
 * trainable, and UnsupportedOperationException otherwise.
 */
public void estimate ()
{
  if (!trainable)
    throw new IllegalStateException ("This transducer not currently trainable.");
  // xxx Put stuff in here.
  throw new UnsupportedOperationException ("Not yet implemented. Never?");
}

// yyy
// (This method continues past the end of this excerpt; the visible part is
// reproduced unchanged.)
public void print ()
{
  StringBuffer sb = new StringBuffer();
  for (int i = 0; i < numStates(); i++) {
    State s = (State) getState (i);
    sb.append ("STATE NAME=\"");
    sb.append (s.name); sb.append ("\" ("); sb.append (s.destinations.length); sb.append (" outgoing transitions)\n");
    sb.append (" "); sb.append ("initialCost = "); sb.append (s.initialCost); sb.append ('\n');
    sb.append (" "); sb.append ("finalCost = "); sb.append (s.finalCost); sb.append ('\n');
    for (int j = 0; j < s.destinations.length; j++) {
      sb.append (" -> "); sb.append (s.getDestinationState(j).getName());
      for (int k = 0; k < s.weightsIndices[j].length; k++) {
        sb.append (" WEIGHTS NAME=\"");
        sb.append (weightAlphabet.lookupObject(s.weightsIndices[j][k]).toString());
        sb.append ("\"\n");
        sb.append (" ");
        sb.append (s.name); sb.append (" -> "); sb.append (s.destinations[j].name); sb.append (": ");
        sb.append ("<DEFAULT_FEATURE> = "); sb.append (defaultWeights[s.weightsIndices[j][k]]); sb.append('\n');
        SparseVector transitionWeights = weights[s.weightsIndices[j][k]];
        if (transitionWeights.numLocations() == 0)
          continue;
        RankedFeatureVector rfv = new RankedFeatureVector (inputAlphabet, transitionWeights);
        for (int m = 0; m < rfv.numLocations(); m++) {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -