⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mbr.cpp

📁 moses开源的机器翻译系统
💻 CPP
字号:
#include <iostream>
#include <fstream>
#include <sstream>
#include <iomanip>
#include <string>
#include <vector>
#include <map>
#include <algorithm>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <time.h>
#include <unistd.h>

using namespace std ;

/* Input :
   1. a sorted n-best list, with duplicates filtered out, in the following format

      0 ||| amr moussa is currently on a visit to libya , tomorrow , sunday , to hold talks with regard to the in sudan . ||| 0 -4.94418 0 0 -2.16036 0 0 -81.4462 -106.593 -114.43 -105.55 -12.7873 -26.9057 -25.3715 -52.9336 7.99917 -24 ||| -4.58432

   2. a weight vector
   3. bleu order ( default = 4)
   4. scaling factor to weigh the weight vector (default = 1.0)

   Output :
   translations that minimise the Bayes Risk of the n-best list
*/

int TABLE_LINE_MAX_LENGTH = 5000;  // kept for compatibility; line reading no longer has a length limit
vector<double> weights;            // feature weights, filled by main() from the -w file
float SCALE = 1.0;                 // -s : multiplier applied to every weighted feature
int BLEU_ORDER = 4;                // -n : longest n-gram length used by the BLEU score
int SMOOTH = 1;                    // added to the counts of the n-gram orders above 1
int DEBUG = 0;                     // non-zero dumps the BLEU components to stderr
double min_interval = 1e-4;

// Reads one delimited record into a fixed-size buffer and clears a pure
// "line too long" failure so that reading can continue. No longer used by the
// readers below, which rely on std::getline instead.
#define SAFE_GETLINE(_IS, _LINE, _SIZE, _DELIM) {_IS.getline(_LINE, _SIZE, _DELIM); if(_IS.fail() && !_IS.bad() && !_IS.eof()) _IS.clear();}

typedef string WORD;
typedef unsigned int WORD_ID;

map<WORD, WORD_ID>  lookup;  // word  -> id
vector< WORD > vocab;        // id    -> word

// One entry of the n-best list.
class candidate_t
{
public:
  candidate_t() : translation_size(0) {}

  vector<WORD_ID> translation;  // word ids of the hypothesis
  vector<double> features;      // feature values read from the third field
  int translation_size;         // number of words in `translation`
} ;

// Prints the command-line synopsis to stderr.
void usage(void)
{
  fprintf(stderr,
          "usage: mbr -s SCALE -n BLEU_ORDER -w weights.txt -i nbest.txt\n");
}

// strsep()-like splitter whose delimiter is a whole string rather than a set
// of characters. Returns the piece of *stringp that precedes the first
// occurrence of `delim` (terminating it in place with '\0') and advances
// *stringp past the delimiter. When no delimiter is left the remaining text is
// returned and *stringp becomes NULL; the following call then returns NULL.
char *strstrsep(char **stringp, const char *delim)
{
  char *match, *save;
  save = *stringp;
  if (*stringp == NULL)
    return NULL;
  match = strstr(*stringp, delim);
  if (match == NULL) {
    *stringp = NULL;
    return save;
  }
  *match = '\0';
  *stringp = match + strlen(delim);
  return save;
}

// Splits a C string on spaces and tabs; runs of separators produce no empty
// tokens.
vector<string> tokenize( const char input[] )
{
  vector< string > token;
  bool betweenWords = true;
  int start = 0;
  int i = 0;
  for (; input[i] != '\0'; i++) {
    const bool isSpace = (input[i] == ' ' || input[i] == '\t');
    if (!isSpace && betweenWords) {
      start = i;
      betweenWords = false;
    } else if (isSpace && !betweenWords) {
      token.push_back( string( input+start, i-start ) );
      betweenWords = true;
    }
  }
  // `i` indexes the terminating NUL here, so the last word is i-start
  // characters long; one more would copy the NUL into the token.
  if (!betweenWords)
    token.push_back( string( input+start, i-start ) );
  return token;
}

// Returns the id of `word`, registering it in `lookup` / `vocab` first when it
// has not been seen before.
WORD_ID storeIfNew( WORD word )
{
  map<WORD, WORD_ID>::const_iterator known = lookup.find( word );
  if ( known != lookup.end() )
    return known->second;

  WORD_ID id = static_cast<WORD_ID>( vocab.size() );
  vocab.push_back( word );
  lookup[ word ] = id;
  return id;
}

// Counts the occurrences of `delim` in `input`.
int count( string input, char delim )
{
  int occurrences = 0;
  for ( string::size_type i = 0; i < input.size(); i++ ) {
    if ( input[i] == delim )
      occurrences++;
  }
  return occurrences;
}

// Returns exp( sum_i feats[i] * weights[i] * SCALE ).
// A size mismatch is reported on stderr; the sum then runs over the common
// prefix only so that neither vector is read out of bounds.
double calculate_probability(const vector<double> & feats, const vector<double> & weights, double SCALE)
{
  if (feats.size() != weights.size())
    cerr << "ERROR : Number of features <> number of weights " << endl;

  const vector<double>::size_type usable = min(feats.size(), weights.size());
  double prob = 0;
  for ( vector<double>::size_type i = 0; i < usable; i++ ) {
    prob += feats[i]*weights[i]*SCALE;
  }
  return exp(prob);
}

// Adds the count of every n-gram of `sentence`, for n = 1 .. BLEU_ORDER, to
// `allngrams`.
void extract_ngrams(const vector<WORD_ID>& sentence, map < vector < WORD_ID>, int >  & allngrams)
{
  const int length = static_cast<int>(sentence.size());
  for (int k = 0; k < BLEU_ORDER; k++) {
    // An n-gram of k+1 words can start at positions 0 .. length-k-1.
    const int starts = max(length - k, 0);
    for (int i = 0; i < starts; i++) {
      const vector<WORD_ID> ngram(sentence.begin() + i, sentence.begin() + i + k + 1);
      ++allngrams[ngram];
    }
  }
}

// Sentence-level BLEU of candidate `hyp` scored against candidate `ref`.
// `ngram_stats[x]` must hold the counts extract_ngrams() produced for
// `sents[x]` with the current BLEU_ORDER.
// Layout of `comps`: comps[2*i] = matched (i+1)-grams, comps[2*i+1] = number
// of (i+1)-grams in the hypothesis, comps[2*BLEU_ORDER] = reference length.
// Returns 0 when not a single unigram matches.
double calculate_score(const vector<candidate_t*> & sents, int ref, int hyp,  vector < map < vector < WORD_ID>, int > > & ngram_stats )
{
  typedef map< vector< WORD_ID >, int > ngram_counts;

  const int comps_n = 2*BLEU_ORDER+1;
  vector<int> comps(comps_n, 0);  // std::vector instead of a non-standard variable-length array
  double logbleu = 0.0, brevity;

  const int hyp_length = sents[hyp]->translation_size;
  for (int i = 0; i < BLEU_ORDER; i++) {
    comps[2*i] = 0;
    comps[2*i+1] = max(hyp_length-i,0);
  }

  const ngram_counts & hyp_ngrams = ngram_stats[hyp] ;
  const ngram_counts & ref_ngrams = ngram_stats[ref] ;

  // Clipped matches: an n-gram counts at most as often as it occurs in the reference.
  for (ngram_counts::const_iterator it = hyp_ngrams.begin();
       it != hyp_ngrams.end(); ++it) {
    ngram_counts::const_iterator ref_it = ref_ngrams.find(it->first);
    if (ref_it != ref_ngrams.end()) {
      comps[2*(it->first.size()-1)] += min(ref_it->second, it->second);
    }
  }
  comps[comps_n-1] = sents[ref]->translation_size;

  if (DEBUG) {
    for ( int i = 0; i < comps_n; i++)
      cerr << "Comp " << i << " : " << comps[i];
  }

  // Without a unigram match the unsmoothed unigram precision would be log(0).
  if (comps[0] == 0)
    return 0.0;

  for (int i = 0; i < BLEU_ORDER; i++) {
    if ( i > 0 )
      logbleu += log(static_cast<double>(comps[2*i]+SMOOTH))-log(static_cast<double>(comps[2*i+1]+SMOOTH));
    else
      logbleu += log(static_cast<double>(comps[2*i]))-log(static_cast<double>(comps[2*i+1]));
  }
  logbleu /= BLEU_ORDER;

  // comps[comps_n-1] is the ref length, comps[1] is the test length
  brevity = 1.0-static_cast<double>(comps[comps_n-1])/comps[1];
  if (brevity < 0.0)
    logbleu += brevity;
  return exp(logbleu);
}

// Reads every whitespace-separated number of `fileName` into a vector, in file
// order. Exits with an error message when the file cannot be opened.
vector<double> read_weights(string fileName)
{
  ifstream inFile(fileName.c_str());
  if (!inFile) {
    // The previous fixed-buffer loop never reached EOF on an unopened stream
    // and therefore never terminated.
    cerr << "ERROR : cannot open weights file '" << fileName << "'" << endl;
    exit(1);
  }

  vector<double> parsed;
  string line;
  // std::getline also delivers a final line that lacks a trailing newline.
  while (getline(inFile, line)) {
    const vector<string> token = tokenize(line.c_str());
    for (vector<string>::size_type j = 0; j < token.size(); j++) {
      parsed.push_back(atof(token[j].c_str()));
    }
  }
  cerr << endl;
  return parsed;
}

// Index of the first smallest element of `vec`, or -1 when `vec` is empty.
int find_pos_of_min_element(const vector<double>& vec)
{
  if (vec.empty())
    return -1;
  return static_cast<int>(min_element(vec.begin(), vec.end()) - vec.begin());
}

// Writes to stdout the candidate of `sents` with the smallest expected
// (1 - BLEU) loss, the expectation being taken over the other candidates
// weighted by their normalised model probability. `sent` (the sentence number)
// is not used. Nothing is written for an empty candidate list.
void process(int sent, const vector<candidate_t*> & sents)
{
  (void) sent;
  if (sents.empty())
    return;

  double marginal = 0;
  vector<double> joint_prob_vec;
  vector< map < vector <WORD_ID>, int > > ngram_stats;

  for (vector<candidate_t*>::size_type i = 0; i < sents.size(); i++) {
    // Calculate marginal and cache the posteriors
    const double joint_prob = calculate_probability(sents[i]->features, weights, SCALE);
    marginal += joint_prob;
    joint_prob_vec.push_back(joint_prob);

    // Cache ngram counts
    map < vector <WORD_ID>, int > counts;
    extract_ngrams(sents[i]->translation, counts);
    ngram_stats.push_back(counts);
  }

  double minMBRLoss = 1000000;
  int minMBRLossIdx = -1;

  /* Main MBR computation done here */
  const int total = static_cast<int>(sents.size());
  for (int i = 0; i < total; i++) {
    double weightedLossCumul = 0;
    for (int j = 0; j < total; j++) {
      if (i == j)
        continue;
      const double bleu = calculate_score(sents, j, i, ngram_stats);
      weightedLossCumul += (1 - bleu) * (joint_prob_vec[j]/marginal);
      // The loss only grows, so this candidate can no longer win.
      if (weightedLossCumul > minMBRLoss)
        break;
    }
    if (weightedLossCumul < minMBRLoss) {
      minMBRLoss = weightedLossCumul;
      minMBRLossIdx = i;
    }
  }

  // No candidate wins when every loss is NaN (e.g. all probabilities
  // underflow to 0, making the marginal 0). Fall back to the first entry
  // of the list instead of indexing with -1.
  if (minMBRLossIdx < 0)
    minMBRLossIdx = 0;

  /* Find sentence that minimises Bayes Risk under 1- BLEU loss */
  const vector< WORD_ID > & best_translation = sents[minMBRLossIdx]->translation;
  for (vector<WORD_ID>::size_type i = 0; i < best_translation.size(); i++)
    cout << vocab[best_translation[i]] << " " ;
  cout << endl;
}

// Deletes every candidate in `cands` and empties the vector.
static void free_candidates(vector<candidate_t*> & cands)
{
  for (vector<candidate_t*>::size_type i = 0; i < cands.size(); i++)
    delete cands[i];
  cands.clear();
}

// Reads an n-best list whose lines look like
//     sentence-number ||| translation ||| feature values [||| ...]
// and calls process() once per run of consecutive lines that share a sentence
// number. Fields after the third are ignored, as are blank lines. Exits with
// an error message when the file cannot be opened.
void read_nbest_data(string fileName)
{
  ifstream inFile(fileName.c_str());
  if (!inFile) {
    cerr << "ERROR : cannot open n-best file '" << fileName << "'" << endl;
    exit(1);
  }

  vector<candidate_t*> testsents;  // owns the candidates of the current sentence
  int cur_sent = -1;
  string line;

  while (getline(inFile, line)) {
    // A blank line would otherwise become an empty candidate of sentence 0.
    if (line.find_first_not_of(" \t\r\n") == string::npos)
      continue;

    // strstrsep() and strtok() write into the text they split, so hand them a
    // private NUL-terminated copy of the line.
    vector<char> buf(line.begin(), line.end());
    buf.push_back('\0');
    char *rest = &buf[0];

    candidate_t *cand = new candidate_t;
    int sent_i = 0;
    int field = 0;
    char *tok;

    while (field < 3 && (tok = strstrsep(&rest, "|||")) != NULL) {
      if (field == 0) {
        sent_i = static_cast<int>(strtol(tok, NULL, 10));
      } else if (field == 1) {
        const vector<string> trans_str = tokenize(tok);
        vector<WORD_ID> trans_int;
        for (vector<string>::size_type j = 0; j < trans_str.size(); j++) {
          trans_int.push_back( storeIfNew( trans_str[j] ) );
        }
        cand->translation = trans_int;
        cand->translation_size = static_cast<int>(cand->translation.size());
      } else {
        vector<double> features;
        char *subtok = strtok(tok, " ");
        while (subtok != NULL) {
          features.push_back(atof(subtok));
          subtok = strtok(NULL, " ");
        }
        cand->features = features;
      }
      field++;
    }

    // A new sentence number closes the previous sentence.
    if (sent_i != cur_sent) {
      if (cur_sent != -1) {
        process(cur_sent, testsents);
      }
      cur_sent = sent_i;
      free_candidates(testsents);
    }
    testsents.push_back(cand);
  }

  if (!testsents.empty())
    process(cur_sent, testsents);
  free_candidates(testsents);
  cerr << endl;
}

// Entry point: mbr -s SCALE -n BLEU_ORDER -w weights.txt -i nbest.txt
// Returns 0 on success and 1 on a usage error.
int main(int argc, char **argv)
{
  time_t starttime = time(NULL);
  int c;

  string f_weight = "";
  string f_nbest = "";

  while ((c = getopt(argc, argv, "s:w:n:i:")) != -1) {
    switch (c) {
    case 's':
      SCALE = atof(optarg);
      break;
    case 'n':
      BLEU_ORDER = atoi(optarg);
      break;
    case 'w':
      f_weight = optarg;
      break;
    case 'i':
      f_nbest = optarg;
      break;
    default:
      usage();
      return 1;
    }
  }

  // Both input files are passed as options, so they are what has to be
  // checked; the number of leftover positional arguments says nothing.
  if (f_weight.empty() || f_nbest.empty()) {
    usage();
    return 1;
  }
  if (BLEU_ORDER < 1) {
    fprintf(stderr, "ERROR : BLEU_ORDER must be at least 1\n");
    return 1;
  }

  weights = read_weights(f_weight);
  read_nbest_data(f_nbest);

  time_t endtime = time(NULL);
  cerr << "Processed data in " << (endtime-starttime) << " seconds\n";
  return 0;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -