📄 trainableestimator.java
字号:
*/ public void trainTokenModel(String token, String tag, String tagMinus1, String tokenMinus1) { // CONTEXT NODE BACKOFF // Tag, Tag-1, W-1 nodeTagTag1W1 nodeTagTag1 // Tag, Tag-1 nodeTagTag1 nodeTag // Tag nodeTag -null- if (tag == null || token == null) return; Node nodeTag = mRootTokenNode.getOrCreateChild(tag,null,mTagSymbolTable); nodeTag.incrementOutcome(token,mTokenSymbolTable); if (tagMinus1 == null) return; Node nodeTagTag1 = nodeTag.getOrCreateChild(tagMinus1,nodeTag,mTagSymbolTable); nodeTagTag1.incrementOutcome(token,mTokenSymbolTable); if (tokenMinus1 == null) return; Node nodeTagTag1W1 = nodeTagTag1.getOrCreateChild(tokenMinus1, nodeTagTag1,mTokenSymbolTable); nodeTagTag1W1.incrementOutcome(token,mTokenSymbolTable); } /** * Train the tag model with the specified events. Values * supplied may be <code>null</code>, in which case only * the non-<code>null</code> events and contexts are trained. * * @param tag Tag outcome. * @param tagMinus1 Tag assigned to previous token. * @param tokenMinus1 Previous token. * @param tokenMinus2 Token occurring two tokens back. */ public void trainTagModel(String tag, String tagMinus1, String tokenMinus1, String tokenMinus2) { // CONTEXT NODE BACKOFF // Tag-1, W-1, W-2 nodeTag1W1W2 nodeTag1W1 // Tag-1, W-1 nodeTag1W1 nodeTag1 // Tag-1 nodeTag1 -null- if (tag == null || tagMinus1 == null) return; Node nodeTag1 = mRootTagNode.getOrCreateChild(tagMinus1,null,mTagSymbolTable); nodeTag1.incrementOutcome(tag,mTagSymbolTable); if (tokenMinus1 == null) return; Node nodeTag1W1 = nodeTag1.getOrCreateChild(tokenMinus1, nodeTag1,mTokenSymbolTable); nodeTag1W1.incrementOutcome(tag,mTagSymbolTable); if (tokenMinus2 == null) return; Node nodeTag1W1W2 = nodeTag1W1.getOrCreateChild(tokenMinus2, nodeTag1W1,mTokenSymbolTable); nodeTag1W1W2.incrementOutcome(tag,mTagSymbolTable); } /** * Train the token outcome for the tag in the token model. * Only the context for the tag gets incremented; more specific * contexts are not affected. 
* * @param token Token outcome to train. * @param tag Tag outcome to train. */ public void trainTokenOutcome(String token, String tag) { trainTokenModel(token,tag,null,null); } /** * Returns the number of nodes in the tag model. * * @return Number of nodes in the tag model. */ public int numTagNodes() { return mRootTagNode.numNodes(); } /** * Returns the number of outcomes in the tag model. * * @return Number of outcomes in the tag model. */ public int numTagOutcomes() { return mRootTagNode.numCounters(); } /** * Returns the number of nodes in the token model. * * @return Number of nodes in the token model. */ public int numTokenNodes() { return mRootTokenNode.numNodes(); } /** * Returns the number of outcomes in the token model. * * @return Number of outcomes in the token model. */ public int numTokenOutcomes() { return mRootTokenNode.numCounters(); } /** * Prune the models to the specified thresholds in terms of number * of training events for a node required to maintain a node in * the model. * * @param thresholdTag Minimum number of training events to * preserve a node in the tag model. * @param thresholdToken Minimum number of training events to * preserve a node in the token model. */ public void prune(int thresholdTag, int thresholdToken) { mRootTagNode.prune(thresholdTag); mRootTokenNode.prune(thresholdToken); } /** * Smoothes the tag model by adding the specified count to every * legal transition. That is, the count is addeed to each * estimate of <code>P(Tag2|Tag1)</code> where * <code>Tag1,Tag2</code> is a legal sequence. This makes sure * that every legal transition is possible in the output, even if * it wasn't seen in the input. A higher count increases the * degree of smoothing by moving the estimate of tag sequences * closer to uniform; this makes tags for which there was little * or no training data more likely than they would be with just * the smoothed maximum likelihood estimates. 
There will be no * transition data for tags added only through a dictionary * * <P>Legality of a sequence is defined by {@link * Tags#illegalSequence(String,String)}. * * @param countToAdd Count to add to each legal sequence. */ public void smoothTags(int countToAdd) { // mTagSymbolTable.add(Tags.OUT); String[] tags = mTagSymbolTable.symbols(); for (int i = 0; i < tags.length; ++i) { String tag1 = tags[i]; for (int j = 0; j < tags.length; ++j) { String tag2 = tags[j]; if (Tags.illegalSequence(tag1,tag2)) continue; for (int k = 0; k < countToAdd; ++k) { trainTagModel(tag2,tag1,null,null); } } } } /** * Writes an estimator picked out by a specified root node to * the specified data output stream. * * @param rootNode Node to write to data output stream. */ private void writeEstimator(Node rootNode, ObjectOutput out) throws IOException { rootNode.compileEstimates(mLambdaFactor); indexNodes(rootNode); out.writeInt(rootNode.numNodes()); writeNodes(rootNode,out); out.writeInt(rootNode.numCounters()); writeOutcomes(rootNode,out); } /** * Writes an integer index on each node, following a breadth-first * walk of the trie structure. * */ private static void indexNodes(Node rootNode) { LinkedList nodeQueue = new LinkedList(); nodeQueue.addLast(rootNode); int index = 0; while (nodeQueue.size() > 0) { Node node = (Node) nodeQueue.removeFirst(); node.setIndex(index++); Iterator children = node.children().iterator(); while (children.hasNext()) nodeQueue.addLast(node.getChild(children.next().toString())); } } /** * Writes the nodes in the estimator to the specified data output * stream. * * @param out Data output stream to which symbol table is written. * @throws IOException If there is an exception on the underlying * output stream. 
*/ private static void writeNodes(Node rootNode, ObjectOutput out) throws IOException { LinkedList nodeQueue = new LinkedList(); nodeQueue.addLast(new Object[] {rootNode,null} ); int outcomesIndex = 0; int index = 0; while (nodeQueue.size() > 0) { // OUTPUT format per node. Nodes output in breadth-first // search order using a queue. // int, int, int, // symbolID, firstOutcomeIndex, firstChildIndex, // float, int // oneMinusLambda, backoffNodeIndex Object[] pair = (Object[])nodeQueue.removeFirst(); Node node = (Node) pair[0]; out.writeInt(node.getSymbolID()); out.writeInt(outcomesIndex); outcomesIndex += node.outcomes().size(); TreeSet children = new TreeSet(node.children()); if (children.size() == 0) { out.writeInt(index); } else { Iterator childIterator = children.iterator(); Node firstChild = node.getChild(childIterator.next().toString()); out.writeInt(firstChild.index()); index = firstChild.index() + node.children().size(); childIterator = children.iterator(); while (childIterator.hasNext()) { String childName = childIterator.next().toString(); Node childNode = node.getChild(childName); nodeQueue.addLast(new Object[] {childNode,childName}); } } out.writeFloat(node.oneMinusLambda()); out.writeInt(node.backoffNode() == null ? -1 : node.backoffNode().index()); } } /** * Writes the outcomes in the estimator to the specified data * output stream. * * @param out Data output stream to which symbol table is written. * @throws IOException If there is an exception on the underlying * output stream. 
*/
    private static void writeOutcomes(Node rootNode, ObjectOutput out)
        throws IOException {

        // Per outcome: int symbolID, float estimate. Nodes are visited
        // breadth-first with children in sorted order, the same order
        // writeNodes uses when it computes each node's first-outcome
        // offset; the two traversals must agree for those offsets to
        // point at the right outcomes.
        LinkedList nodeQueue = new LinkedList();
        nodeQueue.addLast(rootNode);
        while (!nodeQueue.isEmpty()) {
            Node node = (Node) nodeQueue.removeFirst();
            Iterator outcomesIterator = node.outcomes().iterator();
            while (outcomesIterator.hasNext()) {
                String outcome = outcomesIterator.next().toString();
                OutcomeCounter outcomeCounter = node.getOutcome(outcome);
                out.writeInt(outcomeCounter.getSymbolID());
                out.writeFloat(outcomeCounter.estimate());
            }
            Iterator childIt = new TreeSet(node.children()).iterator();
            while (childIt.hasNext()) {
                nodeQueue.addLast(node.getChild(childIt.next().toString()));
            }
        }
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -