decoder.cpp

来自「解码器是基于短语的统计机器翻译系统的核心模块」· C++ 代码 · 共 822 行 · 第 1/2 页

CPP
822
字号
			pp1 = make_pair(i, i+len);
			if (MaxP>FutureCost[pp1])
			   FutureCost[pp1] = MaxP;
		}
	}
}

/************************************************************************/
/* Beam Search                                                          */
/************************************************************************/
void Decoder::BeamSearch(int WordLen, const map<pair<int,int>,TransMap> &TransOption, 
						vector<vector<Hypotheses> > &HpStack,  map<pair<int,int>,double > &FutureCost)
{
	HpStack.resize(WordLen+1);
	Hypotheses hpinit(WordLen);
	Feature featinit(lambda.size());
	hpinit.feat = featinit;
	HpStack[0].push_back(hpinit);

	int i,j;
	int start,end;

	for (i=0; i<HpStack.size(); i++)
	{
		if (INFO)
		{
			logs << "\n<Stack ID=\"" << i << "\" size=\"" << HpStack[i].size() << "\">" <<endl;
			for (int k=0; k<HpStack[i].size(); k++)
			{
				logs << "\n<ID stack=\"" << i << "\"" << " number=\"" << k << "\">" << endl;
				HpStack[i][k].ShowHP(logs);
				logs << "</ID>" << endl;
			}
			logs << "</Stack>" << endl;
		}

		if (i == HpStack.size()-1)//the last stack don't need to extend
			 break;

		for (j=0; j<HpStack[i].size(); j++)
		{
			//extend a hypothesis in stack i  
			map<pair<int,int>, TransMap>::const_iterator it;  
			for (it = TransOption.begin(); it!= TransOption.end(); it++)
			{
				int i_Distortion_Dist = GetDistortionDistance( HpStack[i][j].CoveredWord,HpStack[i][j].CurBegEnd, (*it).first);
				if (i_Distortion_Dist == -1)  
					continue;

				//extend
				TransMap PhraseTrans = (*it).second;//Translation information for a Chinese phrase
				TransMap::iterator it2 = PhraseTrans.begin();
				for (it2=PhraseTrans.begin(); it2!=PhraseTrans.end(); it2++)
				{
					string LastEng = HpStack[i][j].LastEnglishWord;
					string phrasetrans = (*it2).first;

					vector<double> prob;
					for (int t=1; t<(*it2).second.size(); t++)
						 prob.push_back((*it2).second[t]);

					double lmprob = lm.getLMProb(LastEng,phrasetrans,lmngram);
					prob.push_back(lmprob);//LM model

					prob.push_back(-1 * i_Distortion_Dist); //distortion model

					Feature feat(lambda.size());
					feat.featfunc = prob;

					Hypotheses NewHp(i,j,(*it).first,phrasetrans,feat,HpStack[i][j],lmngram);
					NewHp.GetFutureCost(FutureCost);
					NewHp.ComputeFeature(HpStack[i][j].feat,lambda);

					//push the Hypothesis into stack 
					AddToStack(HpStack[NewHp.CoveredNumber],NewHp);

				}//End PhraseTrans
			}//End TransOption
		}//End HpStack[i]
	}//End HpStack
}


/************************************************************************/
/* Generate Nbest-list                                                  */
/************************************************************************/
void Decoder::GenerateNbest(vector<vector<Hypotheses> > &HpStack, vector<CandTrans> &CandNbest)
{
	// Frontier of partial back-traces; popmax_from_vec presumably yields the
	// best-scoring arc first (TODO confirm against its definition).
	vector<AddArc> NbestSearchNode;

	// Find the highest-numbered non-empty stack. The scan is bounded so that an
	// empty HpStack, or one whose stacks are all empty, no longer reads HpStack[-1].
	int pos = static_cast<int>(HpStack.size()) - 1;
	while (pos >= 0 && HpStack[pos].empty())
		pos--;

	// Seed the frontier with each final hypothesis and its recombined alternatives.
	if (pos >= 0)
	{
		for (size_t i = 0; i < HpStack[pos].size(); i++)
		{
			string e1 = HpStack[pos][i].CurEnglishTranslation;
			int p1 = HpStack[pos][i].PreStack;
			int p2 = HpStack[pos][i].PreStackNum;
			AddArc a1(p1,p2,HpStack[pos][i].feat,e1);
			insert_to_vec(NbestSearchNode,NBEST_LIST,a1);

			for (size_t j = 0; j < HpStack[pos][i].AdditionalArcs.size(); j++)
			{
				insert_to_vec(NbestSearchNode, NBEST_LIST, HpStack[pos][i].AdditionalArcs[j]);
			}
		}
	}

	// Walk arcs back towards stack 0, prepending each predecessor's phrase and
	// accumulating features; an arc whose predecessor is stack 0 is complete.
	while (!NbestSearchNode.empty()) 
	{
		AddArc next = popmax_from_vec(NbestSearchNode);

		int p1 = next.PreStack;
		int p2 = next.PreStackNum;

		if(p1 == 0)// head
		{
			CandTrans ct;
			ct.english = next.english;
			ct.feat = next.feat;
			insert_to_vec(CandNbest,NBEST_LIST,ct);
			if(CandNbest.size() == NBEST_LIST)
				break;
		}
		else if( p1> 0 )
		{
			// Continue through the predecessor itself ...
			string e1 = HpStack[p1][p2].CurEnglishTranslation +" " + next.english;
			Feature feat =  HpStack[p1][p2].feat;
			feat = feat + next.feat; 
			int pp1 = HpStack[p1][p2].PreStack;
			int pp2 = HpStack[p1][p2].PreStackNum;
			AddArc aa2(pp1,pp2,feat,e1);
			insert_to_vec(NbestSearchNode,NBEST_LIST,aa2);

			// ... and through every alternative recombined into it.
			for (size_t t = 0; t < HpStack[p1][p2].AdditionalArcs.size(); t++)
			{
				int pa1 = HpStack[p1][p2].AdditionalArcs[t].PreStack;
				int pa2 = HpStack[p1][p2].AdditionalArcs[t].PreStackNum;
				string e2 =  HpStack[p1][p2].AdditionalArcs[t].english + " " + next.english;
				Feature f2 = HpStack[p1][p2].AdditionalArcs[t].feat;
				f2 =f2+ next.feat;
				AddArc aa3(pa1,pa2,f2,e2);
				insert_to_vec(NbestSearchNode,NBEST_LIST,aa3);
			}
		}
	}

	if(INFO)
	{
		logs << "\n<Nbest_list Number = \"" << NBEST_LIST << "\">" << endl;

		for (size_t i = 0; i < CandNbest.size(); i++)
		{
			logs << "<Candidate No=\"" << i+1 << "\">" << endl;
			CandNbest[i].Show(logs);
			logs << "</Candidate>" << endl;
		}

		logs << "</Nbest_list>" << endl;
	}
}

/************************************************************************/
/* Generate 1 best                                                      */
/************************************************************************/
void Decoder::Generate1best(vector<vector<Hypotheses> > &HpStack, vector<CandTrans> &CandNbest)
{
	CandTrans best;

	// Find the highest-numbered non-empty stack. The scan is bounded so that an
	// empty HpStack, or one whose stacks are all empty, no longer reads HpStack[-1].
	int pos = static_cast<int>(HpStack.size()) - 1;
	while (pos >= 0 && HpStack[pos].empty())
		pos--;
	if (pos < 0)
	{
		// Nothing to back-trace: still append exactly one (default) candidate,
		// so this function always adds one entry to CandNbest.
		CandNbest.push_back(best);
		return;
	}

	// Pick the best hypothesis in that stack according to Feature's operator>.
	size_t num = 0;
	for (size_t i = 0; i < HpStack[pos].size(); i++)
	{
		if(HpStack[pos][i].feat > HpStack[pos][num].feat)
			num = i;
	}

	// Follow PreStack/PreStackNum back-pointers until stack 0, prepending each
	// predecessor's phrase and adding its features.
	int PreStack = HpStack[pos][num].PreStack;
	int PreStackNum = HpStack[pos][num].PreStackNum;
	best.english = HpStack[pos][num].CurEnglishTranslation;
	best.feat = HpStack[pos][num].feat;
	while (PreStack > 0)
	{
		int p1 = PreStack;
		int p2 = PreStackNum;
		PreStack = HpStack[p1][p2].PreStack;
		PreStackNum = HpStack[p1][p2].PreStackNum;

		Feature feat = HpStack[p1][p2].feat;
		feat = feat + best.feat;
		best.feat = feat;

		best.english = HpStack[p1][p2].CurEnglishTranslation + " " + best.english;
	}
	CandNbest.push_back(best);
}
/************************************************************************/
/* Push a hypothesis to HpStack                                         */
/************************************************************************/
bool Decoder::AddToStack(vector<Hypotheses> &HpStack, Hypotheses &hp)
{
	int i = 0, pos = 0;
	bool IsRecombine=false;
	for (i=0; i<HpStack.size(); i++)
	{
		if (HpStack[i] < HpStack[pos])
		{
			pos = i;
		}
		if (HpStack[i] == hp)//recombine
		{
			if (HpStack[i] < hp)
			{
				hp.AdditionalArcs = HpStack[i].AdditionalArcs;
				AddArc aarc(HpStack[i]) ;	
			//	hp.AdditionalArcs.push_back(aarc);
				insert_to_vec(hp.AdditionalArcs,HP_STACK_SIZE,aarc);
				HpStack[i] = hp;
			}
			else
			{
				AddArc aarc(hp);
				//HpStack[i].AdditionalArcs.push_back(aarc);
				insert_to_vec(hp.AdditionalArcs,HP_STACK_SIZE,aarc);
			}
			IsRecombine = true;
			break;
		}
	}

	if (i>=HpStack.size())
	{
		if (HpStack.size() == HP_STACK_SIZE)
		 {
		     if(HpStack[pos] < hp)
			HpStack[pos] = hp;
		 }
		else
		    HpStack.push_back(hp);
		IsRecombine = false;
	}

	return IsRecombine;
}

/************************************************************************/
/* Get distortion distance                                              */
/************************************************************************/
int Decoder::GetDistortionDistance(const vector<int> CoveredWord,
								   const pair<int,int> LastPhraseBeginEnd,
								   const pair<int,int> BeginEnd)
{
	// Decide whether source span BeginEnd (inclusive) may be translated next.
	// CoveredWord[k] is 1 when source word k is already translated.
	// Returns -1 when the span is not allowed: it is out of range, overlaps covered
	// words, would leave an uncovered gap spanning MAX_DISTORTION or more, or jumps
	// further than MAX_DISTORTION. Otherwise returns
	// |LastPhraseBeginEnd.second + 1 - BeginEnd.first|.
	// NOTE(review): CoveredWord is taken by value (a copy per call); switching to a
	// const reference needs the matching change in the class declaration.
	const int n = static_cast<int>(CoveredWord.size());
	if (BeginEnd.first < 0 || BeginEnd.second >= n)
		return -1;

	int i, j;
	for (i = BeginEnd.first; i <= BeginEnd.second; i++)
	{
		if (CoveredWord[i] == 1)
			return -1;
	}

	// i: nearest uncovered word left of the span (-1 if none).
	for (i = BeginEnd.first - 1; i >= 0; i--)
	{
		if (CoveredWord[i] == 0)
			break;
	}

	// j: nearest uncovered word right of the span (n if none).
	for (j = BeginEnd.second + 1; j < n; j++)
	{
		if (CoveredWord[j] == 0)
			break;
	}

	// Only constrained when uncovered words remain on both sides of the span.
	if (i >= 0 && j < n)
	{
		if ((j - i) >= MAX_DISTORTION)
			return -1;

		// Bounds are tested before indexing: the previous operand order read
		// CoveredWord[-1] once i walked off the front of the vector.
		while (i >= 0 && CoveredWord[i] == 0)
			i--;

		if (i > 0)
		{
			while (i >= 0 && CoveredWord[i] == 1)
				i--;

			if (i >= 0 && (BeginEnd.second - i) >= MAX_DISTORTION)
				return -1;
		}
	}

	const int distortion_dist = abs(LastPhraseBeginEnd.second + 1 - BeginEnd.first);
	return (distortion_dist <= MAX_DISTORTION) ? distortion_dist : -1;
}

/************************************************************************/
/* Read Chinese (format 863)                                            */
/************************************************************************/
void Decoder::ReadChinese(const char *filename, vector<string> &sents)
{
	// Collect the text of every line containing an "<s id" tag, with the
	// surrounding <s ...> and closing tag stripped. Exits(1) if the file cannot
	// be opened.
	ifstream in(filename);
	if (!in)
	{
		cerr << "open test file error in [Decoder::ReadChinese] " << filename << endl;
		exit(1);
	}

	string line;
	while (getline(in, line))
	{
		const string::size_type tag = line.find("<s id");
		if (tag == string::npos)
			continue;

		// Text starts after the '>' that closes this <s ...> tag (empty if it never closes).
		string::size_type begin = line.find('>', tag);
		begin = (begin == string::npos) ? line.size() : begin + 1;

		// Text ends at the last '<' (normally "</s>"). Without a closing tag, keep the
		// rest of the line; previously npos was stored in an int and erase() threw
		// std::out_of_range.
		string::size_type end = line.rfind('<');
		if (end == string::npos || end < begin)
			end = line.size();

		// Always push, so sentence indices stay aligned with the input.
		sents.push_back(line.substr(begin, end - begin));
	}
}

/************************************************************************/
/* Change Format                                                        */
/************************************************************************/
void Decoder::ChangeFormatTo863(const char *srcfile, const char *tempfile, const char *resultfile)
{
	// Rebuild an 863-format test-set file from the source-set file `srcfile`:
	// - <srcset becomes <tstset;
	// - each <doc tag gets site="Camel";
	// - the text of the k-th <s ...> line is replaced by line k of `tempfile`
	//   (untouched once tempfile runs out);
	// - the closing </srcset> line is replaced by a final </tstset>.
	ifstream in1(srcfile);
	ifstream in2(tempfile);
	if (!in1 || !in2)
	{
		cerr << "open file error in [Decoder::ChangeFormatTo863] "
			 << (!in1 ? srcfile : tempfile) << endl;
		return;
	}
	ofstream out(resultfile);
	if (!out)
	{
		cerr << "open file error in [Decoder::ChangeFormatTo863] " << resultfile << endl;
		return;
	}

	string line;
	vector<string> srcSent;
	vector<string> tgtSent;

	while (getline(in1, line))
		srcSent.push_back(line);

	while (getline(in2, line))
		tgtSent.push_back(line);

	size_t count = 0;	// next translation to use
	for (size_t i = 0; i < srcSent.size(); i++)
	{
		string &s = srcSent[i];

		string::size_type pos = s.find("<srcset");
		if (pos != string::npos)
		{
			// Rename the tag where it actually occurs (was hard-coded to column 0).
			s.replace(pos, 7, "<tstset");
			out << s << '\n';
			continue;
		}

		pos = s.find("<doc");
		if (pos != string::npos)
		{
			// Insert the attribute before the tag's own '>' instead of blindly
			// dropping the last character (which broke CRLF lines).
			const string::size_type gt = s.find('>', pos);
			if (gt != string::npos)
				s.insert(gt, " site=\"Camel\"");
			else
				s += " site=\"Camel\">";
			out << s << '\n';
			continue;
		}

		pos = s.find("<s ");
		if (pos != string::npos)
		{
			// Sentence text lies between the '>' closing the <s ...> tag and the last '<'.
			string::size_type start = s.find('>', pos);
			start = (start == string::npos) ? s.size() : start + 1;
			string::size_type end = s.rfind('<');
			if (end == string::npos || end < start)
				end = s.size();

			if (count < tgtSent.size())
				s.replace(start, end - start, tgtSent[count++]);
			out << s << '\n';
			continue;
		}

		// The source closing tag is re-emitted as </tstset> below; every other
		// line is copied through. (Previously the last line was dropped
		// unconditionally, whatever it contained.)
		if (s.find("</srcset") != string::npos)
			continue;
		out << s << '\n';
	}
	out << "</tstset>" << endl;
}

/************************************************************************/
/* Change the first letter to its upper case                            */
/************************************************************************/
void Decoder::TrueCase(string &s)
{
	// Upper-case the first character in place when it is an ASCII lowercase
	// letter; any other leading character (or an empty string) is left alone.
	if (s.empty())
		return;

	char &first = s[0];
	if ('a' <= first && first <= 'z')
		first = static_cast<char>(first - ('a' - 'A'));
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?