⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mbr.cpp.svn-base

📁 moses开源的机器翻译系统
💻 SVN-BASE
字号:
#include <iostream>#include <fstream>#include <sstream>#include <iomanip>#include <vector>#include <map>#include <stdlib.h>#include <math.h>#include <algorithm>#include <stdio.h>#include "TrellisPathList.h"#include "TrellisPath.h"#include "StaticData.h"#include "Util.h"#include "mbr.h"using namespace std ;/* Input :    1. a sorted  n-best list, with duplicates filtered out in the following  format    0 ||| amr moussa is currently on a visit to libya , tomorrow , sunday , to hold talks with regard to the in sudan . ||| 0 -4.94418 0 0 -2.16036 0 0 -81.4462 -106.593 -114.43 -105.55 -12.7873 -26.9057 -25.3715 -52.9336 7.99917 -24 ||| -4.58432   2. a weight vector    3. bleu order ( default = 4)   4. scaling factor to weigh the weight vector (default = 1.0)   Output :   translations that minimise the Bayes Risk of the n-best list*/int BLEU_ORDER = 4;int SMOOTH = 1;int DEBUG = 0;float min_interval = 1e-4;void extract_ngrams(const vector<const Factor* >& sentence, map < vector < const Factor* >, int >  & allngrams){  vector< const Factor* > ngram;  for (int k = 0; k < BLEU_ORDER; k++)  {    for(int i =0; i < max((int)sentence.size()-k,0); i++)    {      for ( int j = i; j<= i+k; j++)      {        ngram.push_back(sentence[j]);      }      ++allngrams[ngram];      ngram.clear();    }  }}float calculate_score(const vector< vector<const Factor*> > & sents, int ref, int hyp,  vector < map < vector < const Factor *>, int > > & ngram_stats ) {  int comps_n = 2*BLEU_ORDER+1;  vector<int> comps(comps_n);  float logbleu = 0.0, brevity;    int hyp_length = sents[hyp].size();  for (int i =0; i<BLEU_ORDER;i++)  {    comps[2*i] = 0;    comps[2*i+1] = max(hyp_length-i,0);  }  map< vector < const Factor * > ,int > & hyp_ngrams = ngram_stats[hyp] ;  map< vector < const Factor * >, int > & ref_ngrams = ngram_stats[ref] ;  for (map< vector< const Factor * >, int >::iterator it = hyp_ngrams.begin();       it != hyp_ngrams.end(); it++)  {    map< vector< const Factor * >, int >::iterator ref_it = ref_ngrams.find(it->first);    if(ref_it != ref_ngrams.end())    {      comps[2* (it->first.size()-1)] += min(ref_it->second,it->second);    }  }  comps[comps_n-1] = sents[ref].size();  if (DEBUG)  {    for ( int i = 0; i < comps_n; i++)      cerr << "Comp " << i << " : " << comps[i];  }  for (int i=0; i<BLEU_ORDER; i++)  {    if (comps[0] == 0)      return 0.0;    if ( i > 0 )      logbleu += log((float)comps[2*i]+SMOOTH)-log((float)comps[2*i+1]+SMOOTH);    else      logbleu += log((float)comps[2*i])-log((float)comps[2*i+1]);  }  logbleu /= BLEU_ORDER;  brevity = 1.0-(float)comps[comps_n-1]/comps[1]; // comps[comps_n-1] is the ref length, comps[1] is the test length  if (brevity < 0.0)    logbleu += brevity;  return exp(logbleu);}vector<const Factor*> doMBR(const TrellisPathList& nBestList){//   cerr << "Sentence " << sent << " has " << sents.size() << " candidate translations" << endl;  float marginal = 0;  vector<float> joint_prob_vec;  vector< vector<const Factor*> > translations;  float joint_prob;  vector< map < vector <const Factor *>, int > > ngram_stats;  TrellisPathList::const_iterator iter;  TrellisPath* hyp = NULL;	for (iter = nBestList.begin() ; iter != nBestList.end() ; ++iter)	{		const TrellisPath &path = **iter;    joint_prob = UntransformScore(StaticData::Instance().GetMBRScale() * path.GetScoreBreakdown().InnerProduct(StaticData::Instance().GetAllWeights()));    marginal += joint_prob;    joint_prob_vec.push_back(joint_prob);    //Cache ngram counts    map < vector < const Factor *>, int > counts;    vector<const Factor*> translation;    GetOutputFactors(path, translation);        //TO DO    extract_ngrams(translation,counts);    ngram_stats.push_back(counts);    translations.push_back(translation);   }      vector<float> mbr_loss;   float bleu, weightedLoss;   float weightedLossCumul = 0;   float minMBRLoss = 1000000;   int minMBRLossIdx = -1;      /* Main MBR computation done here */   for (int i = 0; i < nBestList.GetSize(); i++){       weightedLossCumul = 0;       for (int j = 0; j < nBestList.GetSize(); j++){            if ( i != j) {               bleu = calculate_score(translations, j, i,ngram_stats );               weightedLoss = ( 1 - bleu) * ( joint_prob_vec[j]/marginal);               weightedLossCumul += weightedLoss;               if (weightedLossCumul > minMBRLoss)                   break;             }       }       if (weightedLossCumul < minMBRLoss){           minMBRLoss = weightedLossCumul;           minMBRLossIdx = i;       }   }   /* Find sentence that minimises Bayes Risk under 1- BLEU loss */   return translations[minMBRLossIdx];}void GetOutputFactors(const TrellisPath &path, vector <const Factor*> &translation){	const std::vector<const Hypothesis *> &edges = path.GetEdges();	const std::vector<FactorType>& outputFactorOrder = StaticData::Instance().GetOutputFactorOrder();	assert (outputFactorOrder.size() == 1);	// print the surface factor of the translation	for (int currEdge = (int)edges.size() - 1 ; currEdge >= 0 ; currEdge--)	{		const Hypothesis &edge = *edges[currEdge];		const Phrase &phrase = edge.GetCurrTargetPhrase();		size_t size = phrase.GetSize();		for (size_t pos = 0 ; pos < size ; pos++)		{						const Factor *factor = phrase.GetFactor(pos, outputFactorOrder[0]);			translation.push_back(factor);		}	}}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -