s3_astar.c

来自「CMU大名鼎鼎的SPHINX-3大词汇量连续语音识别系统」· C语言 代码 · 共 1,267 行 · 第 1/3 页

C
1,267
字号
    /* Propagate best value from below into root, if any */    l = root->left;    r = root->right;    if (! l) {	if (! r) {	    listelem_free ((char *) root, sizeof(aheap_t));	    return NULL;	} else {	    root->ppath = r->ppath;	    root->right = aheap_pop (r);	    root->nr--;	}    } else {	if ((! r) || (l->ppath->tscr >= r->ppath->tscr)) {	    root->ppath = l->ppath;	    root->left = aheap_pop (l);	    root->nl--;	} else {	    root->ppath = r->ppath;	    root->right = aheap_pop (r);	    root->nr--;	}    }    return root;}/** * Check if pplist already contains a better (better pscr) path identical to the * extension of lmhist by node.  Return 1 if true, 0 if false.  Also, if false, but * an inferior path did exist, mark it as pruned. */static int32 ppath_dup (ppath_t *hlist, ppath_t *lmhist, dagnode_t *node,			uint32 hval, int32 pscr){    ppath_t *h1, *h2;        /* Compare each entry in hlist to new, proposed path */    for (; hlist; hlist = hlist->hashnext) {	if ((hlist->dagnode != node) || (hlist->histhash != hval))	    continue;		for (h1 = hlist->lmhist, h2 = lmhist; h1 && h2; h1 = h1->lmhist, h2 = h2->lmhist) {	    if ((h1 == h2) ||	/* Histories converged; identical */		(dict_basewid (dict, h1->dagnode->wid) != dict_basewid (dict, h2->dagnode->wid)))		break;	}	if (h1 == h2) {	    /* Identical history already exists */	    if (hlist->pscr >= pscr)	/* Existing history is superior */		return 1;	    else {		/*		 * New path is better; prune existing one.  There may be other paths		 * in the list as well, but all of them must have been pruned by		 * hlist or others later in the list!		 */		hlist->pruned = 1;		return 0;	    }	}    }    return 0;}/** * Create a new ppath node for dagnode reached from top via link l.  Assign the * proper scores for the new, extended path and insert it in the sorted heap of * ppath nodes.  But first check if it's a duplicate of an existing ppath but has an * inferior score, in which case do not insert. */static void ppath_insert (ppath_t *top, daglink_t *l, int32 pscr, int32 tscr, int32 lscr){    ppath_t *pp, *lmhist;    uint32 h, hmod;    s3wid_t w;        /* Extend path score; Add acoustic and LM scores for link */    pscr = top->pscr + l->ascr + lscr;        /*     * Check if extended path would be a duplicate one with an inferior score.     * First, find hash value for new node.     */    lmhist = filler_word(top->dagnode->wid) ? top->lmhist : top;    w = lmhist->dagnode->wid;    h = lmhist->histhash - w + dict_basewid (dict, w);    h = (h >> 5) | (h << 27);	/* Rotate right 5 bits */    h += l->node->wid;    hmod = h % HISTHASH_MOD;        /* If new node would be an inferior duplicate, skip creating it */    if (ppath_dup (hash_list[hmod], lmhist, l->node, h, pscr))	return;    /* Add heuristic score from END OF l until end of utterance */    tscr = pscr + l->hscr;    /* Initialize new partial path node */    pp = (ppath_t *) listelem_alloc (sizeof(ppath_t));    pp->dagnode = l->node;    pp->hist = top;    pp->lmhist = lmhist;    pp->lscr = lscr;    pp->pscr = pscr;    pp->tscr = tscr;    pp->histhash = h;    pp->hashnext = hash_list[hmod];    hash_list[hmod] = pp;    pp->pruned = 0;    pp->next = ppath_list;    ppath_list = pp;        heap_root = aheap_insert (heap_root, pp);        n_ppath++;}static int32 ppath_free ( void ){    ppath_t *pp;    int32 n;        n = 0;    while (ppath_list) {	pp = ppath_list->next;	listelem_free ((char *) ppath_list, sizeof(ppath_t));	ppath_list = pp;	n++;    }        return n;}static void ppath_seg_write (FILE *fp, ppath_t *pp, int32 ascr){    int32 lscr_base;        if (pp->hist)	ppath_seg_write (fp, pp->hist, pp->pscr - pp->hist->pscr - pp->lscr);    lscr_base = pp->hist ? lm_rawscore (lm, pp->lscr, 1.0) : 0;    fprintf (fp, " %d %d %d %s",	     pp->dagnode->sf, ascr, lscr_base, dict_wordstr (dict, pp->dagnode->wid));}static void nbest_hyp_write (FILE *fp, ppath_t *top, int32 pscr, int32 nfr){    int32 lscr, lscr_base;    ppath_t *pp;        lscr_base = 0;    for (lscr = 0, pp = top; pp; lscr += pp->lscr, pp = pp->hist) {	if (pp->hist)	    lscr_base += lm_rawscore (lm, pp->lscr, 1.0);	else	    assert (pp->lscr == 0);    }    fprintf (fp, "T %d A %d L %d", pscr, pscr - lscr, lscr_base);    ppath_seg_write (fp, top, pscr - top->pscr);    fprintf (fp, " %d\n", nfr);    fflush (fp);}#if 0static ppath_dump (ppath_t *p){    printf ("PPATH:\n");    for (; p; p = p->hist) {	printf ("pscr= %11d, lscr= %9d, tscr= %11d, hash= %11u, pruned= %d, sf= %5d, %s\n",		p->pscr, p->lscr, p->tscr, p->histhash, p->pruned,		p->dagnode->sf, dict_wordstr (dict, p->dagnode->wid));    }}#endifvoid nbest_search (char *filename, char *uttid){    FILE *fp;    float32 f32arg;    float64 f64arg;    int32 nbest_max, n_pop, n_exp, n_hyp, n_pp;    int32 besthyp, worsthyp, besttscr;    ppath_t *top, *pp;    dagnode_t *d;    daglink_t *l;    int32 lscr, pscr, tscr;    s3wid_t bw0, bw1, bw2;    int32 i, k;    int32 ispipe;    int32 ppathdebug;        /* Create Nbest file and write header comments */    if ((fp = fopen_comp (filename, "w", &ispipe)) == NULL) {	E_ERROR("fopen_comp (%s,w) failed\n", filename);	fp = stdout;    }    fprintf (fp, "# %s\n", uttid);    fprintf (fp, "# frames %d\n", dag.nfrm);    f32arg = *((float32 *) cmd_ln_access ("-logbase"));    fprintf (fp, "# logbase %e\n", f32arg);    f32arg = *((float32 *) cmd_ln_access ("-lw"));    fprintf (fp, "# langwt %e\n", f32arg);    f32arg = *((float32 *) cmd_ln_access ("-inspen"));    fprintf (fp, "# inspen %e\n", f32arg);    f64arg = *((float64 *) cmd_ln_access ("-beam"));    fprintf (fp, "# beam %e\n", f64arg);    ppathdebug = *((int32 *) cmd_ln_access ("-ppathdebug"));        assert (heap_root == NULL);    assert (ppath_list == NULL);        /*     * Set limit on max LM ops allowed after which utterance is aborted.     * Limit is lesser of absolute max and per frame max.     */    maxlmop = *((int32 *) cmd_ln_access ("-maxlmop"));    k = *((int32 *) cmd_ln_access ("-maxlpf"));    k *= dag.nfrm;    if (maxlmop > k)	maxlmop = k;    lmop = 0;        /* Set limit on max #ppaths allocated before aborting utterance */    maxppath = *((int32 *) cmd_ln_access ("-maxppath"));    n_ppath = 0;        for (i = 0; i < HISTHASH_MOD; i++)	hash_list[i] = NULL;        /* Insert start node into heap and into list of nodes-by-frame */    pp = (ppath_t *) listelem_alloc (sizeof(ppath_t));    pp->dagnode = dag.entry.node;    pp->hist = NULL;    pp->lmhist = NULL;    pp->lscr = 0;    pp->pscr = 0;    pp->tscr = 0;	/* HACK!! Not really used as it is popped off rightaway */    pp->histhash = pp->dagnode->wid;    pp->hashnext = NULL;    pp->pruned = 0;        pp->next = NULL;    ppath_list = pp;    /* Insert into heap of partial paths to be expanded */    heap_root = aheap_insert (heap_root, pp);        /* Insert at head of (empty) list of ppaths with same hashmod value */    hash_list[pp->histhash % HISTHASH_MOD] = pp;        /* Astar-search */    n_hyp = n_pop = n_exp = n_pp = 0;    nbest_max = *((int32 *) cmd_ln_access ("-nbest"));    besthyp = besttscr = (int32)0x80000000;    worsthyp = (int32)0x7fffffff;        while ((n_hyp < nbest_max) && heap_root) {	/* Extract top node from heap */	top = heap_root->ppath;	heap_root = aheap_pop (heap_root);		n_pop++;		if (top->pruned)	    continue;		if (top->dagnode == dag.exit.node) {	/* Complete hypotheses; output */	    nbest_hyp_write (fp, top, top->pscr + dag.exit.ascr, dag.nfrm);	    n_hyp++;	    if (besthyp < top->pscr)		besthyp = top->pscr;	    if (worsthyp > top->pscr)		worsthyp = top->pscr;	    	    continue;	}		/* Find two word (trigram) history beginning at this node */	pp = (filler_word (top->dagnode->wid)) ? top->lmhist : top;	if (pp) {	    bw1 = dict_basewid(dict, pp->dagnode->wid);	    pp = pp->lmhist;	    bw0 = pp ? dict_basewid(dict, pp->dagnode->wid) : BAD_S3WID;	} else	    bw0 = bw1 = BAD_S3WID;		/* Expand to successors of top (i.e. via each link leaving top) */	d = top->dagnode;	for (l = d->succlist; l; l = l->next) {	    assert (l->node->reachable && (! l->is_filler_bypass));	    /* Obtain LM score for link */	    bw2 = dict_basewid (dict, l->node->wid);	    lscr = (filler_word (bw2)) ? fillpen(fpen, bw2) : lm_tg_score (lm, dict2lmwid[bw0], dict2lmwid[bw1], dict2lmwid[bw2],bw2);	    if (lmop++ > maxlmop) {		E_ERROR("%s: Max LM ops (%d) exceeded\n", uttid, maxlmop);		break;	    }	    	    /* Obtain partial path score and hypothesized total utt score */	    pscr = top->pscr + l->ascr + lscr;	    tscr = pscr + l->hscr;	    if (ppathdebug) {		printf ("pscr= %11d, tscr= %11d, sf= %5d, %s%s\n",			pscr, tscr, l->node->sf, dict_wordstr(dict, l->node->wid),			(tscr-beam >= besttscr) ? "" : " (pruned)");	    }	    	    /* Insert extended path if within beam of best so far */	    if (tscr - beam >= besttscr) {		ppath_insert (top, l, pscr, tscr, lscr);		if (n_ppath > maxppath) {		    E_ERROR("%s: Max PPATH limit (%d) exceeded\n", uttid, maxppath);		    break;		}				if (tscr > besttscr)		    besttscr = tscr;	    }	}	if (l)	/* Above loop was aborted */	    break;		n_exp++;    }    fprintf (fp, "End; best %d worst %d diff %d beam %d\n",	     besthyp + dag.exit.ascr, worsthyp + dag.exit.ascr, worsthyp - besthyp, beam);    fclose_comp (fp, ispipe);    if (n_hyp <= 0) {	unlink (filename);	E_ERROR("%s: A* search failed\n", uttid);    }        /* Free partial path nodes and any unprocessed heap */    while (heap_root)	heap_root = aheap_pop(heap_root);    n_pp = ppath_free ();    printf ("CTR(%s): %5d frm %4d hyp %6d pop %6d exp %8d pp\n",	    uttid, dag.nfrm, n_hyp, n_pop, n_exp, n_pp);}void nbest_init ( void ){    float64 *f64arg;    float32 lw, wip;    int32 fudge;    fudge = *((int32 *) cmd_ln_access ("-dagfudge"));    if ((fudge < 0) || (fudge > 2))	E_FATAL("Bad -dagfudge argument: %d, must be in range 0..2\n", fudge);        /* dict = dict_getdict (); */    /* Some key word ids */    startwid = dict_wordid (dict, S3_START_WORD);    finishwid = dict_wordid (dict, S3_FINISH_WORD);    if ((NOT_S3WID(startwid)) || (NOT_S3WID(finishwid)))	E_FATAL("%s or %s missing from dictionary\n", S3_START_WORD, S3_FINISH_WORD);    lw = *((float32 *) cmd_ln_access("-lw"));    wip = *((float32 *) cmd_ln_access("-inspen"));    f64arg = (float64 *) cmd_ln_access ("-beam");    beam = logs3 (*f64arg);    E_INFO("beam= %d\n", beam);        /* Initialize DAG and nbest search structures */    dag.list = NULL;    heap_root = NULL;    ppath_list = NULL;    hash_list = (ppath_t **) ckd_calloc (HISTHASH_MOD, sizeof(ppath_t *));}int32 dag_destroy ( void ){    dagnode_t *d, *nd;    daglink_t *l, *nl;        for (d = dag.list; d; d = nd) {	nd = d->alloc_next;		for (l = d->succlist; l; l = nl) {	    nl = l->next;	    listelem_free ((char *)l, sizeof(daglink_t));	}	for (l = d->predlist; l; l = nl) {	    nl = l->next;	    listelem_free ((char *)l, sizeof(daglink_t));	}	listelem_free ((char *)d, sizeof(dagnode_t));    }    dag.list = NULL;    return 0;}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?