hypothesis.cpp.svn-base

来自「解码器是基于短语的统计机器翻译系统的核心模块」· SVN-BASE 代码 · 共 477 行 · 第 1/2 页

SVN-BASE
477
字号
			//cout<<"context factor: "<<languageModel.GetValue(contextFactor)<<endl;

			// main loop: slide the n-gram window across the newly added target
			// words.  endPos caps this phase at either the point where the
			// context window is full (startPos + nGramOrder - 2) or the last
			// new word (currEndPos), whichever comes first.
			size_t endPos = std::min(startPos + nGramOrder - 2
															, currEndPos);
			for (size_t currPos = startPos + 1 ; currPos <= endPos ; currPos++)
			{
				// shift all args down 1 place
				for (size_t i = 0 ; i < nGramOrder - 1 ; i++)
					contextFactor[i] = contextFactor[i + 1];

				// add last factor
				contextFactor.back() = &GetWord(currPos);

				lmScore	+= languageModel.GetValue(contextFactor);
				if (m_lmstats)
					// record the LM state of this call for later reporting
					languageModel.GetState(contextFactor, &(*m_lmstats)[lmIdx][nLmCallCount++]);
				//cout<<"context factor: "<<languageModel.GetValue(contextFactor)<<endl;
			}

			// end of sentence: if every source word is covered, also score the
			// final n-gram that includes the sentence-end marker
			if (m_sourceCompleted.IsComplete())
			{
				const size_t size = GetSize();
				contextFactor.back() = &languageModel.GetSentenceEndArray();

				// build the context for the sentence-end n-gram; positions that
				// fall before the start of the sentence are padded with the
				// sentence-start marker
				for (size_t i = 0 ; i < nGramOrder - 1 ; i ++)
				{
					int currPos = (int)(size - nGramOrder + i + 1);
					if (currPos < 0)
						contextFactor[i] = &languageModel.GetSentenceStartArray();
					else
						contextFactor[i] = &GetWord((size_t)currPos);
				}
				if (m_lmstats) {
					(*m_lmstats)[lmIdx].resize((*m_lmstats)[lmIdx].size() + 1); // extra space for the last call
					lmScore += languageModel.GetValue(contextFactor, &m_languageModelStates[lmIdx], &(*m_lmstats)[lmIdx][nLmCallCount++]);
				} else
					lmScore	+= languageModel.GetValue(contextFactor, &m_languageModelStates[lmIdx]);
			} else {
				// hypothesis is not complete: advance the context window over the
				// remaining new words without scoring them, only to obtain the
				// LM state that future extensions of this hypothesis will need
				for (size_t currPos = endPos+1; currPos <= currEndPos; currPos++) {
					for (size_t i = 0 ; i < nGramOrder - 1 ; i++)
						contextFactor[i] = contextFactor[i + 1];
					contextFactor.back() = &GetWord(currPos);
					if (m_lmstats)
						languageModel.GetState(contextFactor, &(*m_lmstats)[lmIdx][nLmCallCount++]);
				}
				m_languageModelStates[lmIdx]=languageModel.GetState(contextFactor);
			}
		}

		// accumulate this language model's total contribution
		m_scoreBreakdown.PlusEquals(&languageModel, lmScore);
	}
}

/**
 * compute the distortion (reordering) cost of this hypothesis relative to
 * the previous hypothesis' source range and add it to the score breakdown
 */
void Hypothesis::CalcDistortionScore()
{
	const DistortionScoreProducer *dsp = StaticData::Instance()->GetDistortionScoreProducer();
	float distortionScore = dsp->CalculateDistortionScore(
			m_prevHypo->GetCurrSourceWordsRange(),
			this->GetCurrSourceWordsRange()     );
	m_scoreBreakdown.PlusEquals(dsp, distortionScore);
}

/** zero all component scores and the cached future/total scores */
void Hypothesis::ResetScore()
{
	m_scoreBreakdown.ZeroAll();
	m_futureScore = m_totalScore = 0.0f;
}

/***
 * calculate the logarithm of our total translation score (sum up components)
 */
void Hypothesis::CalcScore(const StaticData& staticData, const SquareMatrix &futureScore)
{
	// DISTORTION COST
	CalcDistortionScore();

	// LANGUAGE MODEL COST
	CalcLMScore(staticData.GetAllLM());

	// WORD PENALTY (negative: one unit per target word produced)
	m_scoreBreakdown.PlusEquals(staticData.GetWordPenaltyProducer(), - (float) m_currTargetWordsRange.GetNumWordsCovered());

	// FUTURE COST
	CalcFutureScore(futureScore);

	//LEXICAL REORDERING COST
	// NOTE(review): despite the m_ prefix this is a local copy of the model
	// list, not a member variable — the name is misleading
	std::vector<LexicalReordering*> m_reorderModels = staticData.GetReorderModels();
	for(unsigned int i = 0; i < m_reorderModels.size(); i++)
	{
		m_scoreBreakdown.PlusEquals(m_reorderModels[i], m_reorderModels[i]->CalcScore(this));
	}

	// TOTAL = weighted sum of all feature scores + estimated future cost
	m_totalScore = m_scoreBreakdown.InnerProduct(staticData.GetAllWeights()) + m_futureScore;
}

/**
 * estimate the cost of translating the source words not yet covered.
 * Scans the coverage bitmap for maximal runs of uncovered words and sums
 * the precomputed estimates stored in the futureScore matrix.
 */
void Hypothesis::CalcFutureScore(const SquareMatrix &futureScore)
{
	const size_t maxSize= numeric_limits<size_t>::max(); // sentinel: "no open gap"
	size_t	start				= maxSize;
	m_futureScore	= 0.0f;
	for(size_t currPos = 0 ; currPos < m_sourceCompleted.GetSize() ; currPos++)
	{
		// an uncovered word opens a gap (if none is open yet)
		if(m_sourceCompleted.GetValue(currPos) == 0 && start == maxSize)
		{
			start = currPos;
		}
		// a covered word closes the open gap [start, currPos - 1]
		if(m_sourceCompleted.GetValue(currPos) == 1 && start != maxSize)
		{
//			m_score[ScoreType::FutureScoreEnum] += futureScore[start][currPos - 1];
			m_futureScore += futureScore.GetScore(start, currPos - 1);
			start = maxSize;
		}
	}
	// handle a gap that runs to the end of the sentence
	if (start != maxSize)
	{
//		m_score[ScoreType::FutureScoreEnum] += futureScore[start][m_sourceCompleted.GetSize() - 1];
		m_futureScore += futureScore.GetScore(start, m_sourceCompleted.GetSize() - 1);
	}

	// add future costs for distortion model
	if(StaticData::Instance()->UseDistortionFutureCosts())
		m_futureScore += m_sourceCompleted.GetFutureCosts(m_currSourceWordsRange.GetEndPos()) * StaticData::Instance()->GetWeightDistortion();

}

/** accessor for the hypothesis this one was extended from */
const Hypothesis* Hypothesis::GetPrevHypo()const
{
	return m_prevHypo;
}

/**
 * print hypothesis information for pharaoh-style logging
 */
void Hypothesis::PrintHypothesis(const InputType &source, float /*weightDistortion*/, float /*weightWordPenalty*/) const
{
  TRACE_ERR( "creating hypothesis "<< m_id <<" from "<< m_prevHypo->m_id<<" ( ");
  // show (at most) the last two words of the previous hypothesis' target phrase
  int end = (int)(m_prevHypo->m_targetPhrase.GetSize()-1);
  int start = end-1;
  if ( start < 0 ) start = 0;
  if ( m_prevHypo->m_currTargetWordsRange.GetStartPos() == NOT_FOUND ) {
    TRACE_ERR( "<s> ");
  }
  else {
    TRACE_ERR( "... ");
  }
  if (end>=0) {
    WordsRange range(start, end);
    TRACE_ERR( m_prevHypo->m_targetPhrase.GetSubString(range) << " ");
  }
  TRACE_ERR( ")"<<endl);
	TRACE_ERR( "\tbase score "<< (m_prevHypo->m_totalScore - m_prevHypo->m_futureScore) <<endl);
	TRACE_ERR( "\tcovering "<<m_currSourceWordsRange.GetStartPos()<<"-"<<m_currSourceWordsRange.GetEndPos()<<": "<< source.GetSubString(m_currSourceWordsRange)  <<endl);
	TRACE_ERR( "\ttranslated as: "<<m_targetPhrase<<endl); // <<" => translation cost "<<m_score[ScoreType::PhraseTrans];
	if (m_wordDeleted) TRACE_ERR( "\tword deleted"<<endl);
  //	TRACE_ERR( "\tdistance: "<<GetCurrSourceWordsRange().CalcDistortion(m_prevHypo->GetCurrSourceWordsRange())); // << " => distortion cost "<<(m_score[ScoreType::Distortion]*weightDistortion)<<endl;
  //	TRACE_ERR( "\tlanguage model cost "); // <<m_score[ScoreType::LanguageModelScore]<<endl;
  //	TRACE_ERR( "\tword penalty "); // <<(m_score[ScoreType::WordPenalty]*weightWordPenalty)<<endl;
	TRACE_ERR( "\tscore "<<m_totalScore - m_futureScore<<" + future cost "<<m_futureScore<<" = "<<m_totalScore<<endl);
  TRACE_ERR(  "\tunweighted feature scores: " << m_scoreBreakdown << endl);
	//PrintLMScores();
}

/** mark this hypothesis as the winning one for itself and all of its
 * recombined arcs (if any) */
void Hypothesis::InitializeArcs()
{
	// point this hypo's main hypo to itself
	SetWinningHypo(this);
	if (!m_arcList) return;

	// set all arc's main hypo variable to this hypo
	ArcList::iterator iter = m_arcList->begin();
	for (; iter != m_arcList->end() ; ++iter)
	{
		Hypothesis *arc = *iter;
		arc->SetWinningHypo(this);
	}
}

TO_STRING_BODY(Hypothesis)

// friend
ostream& operator<<(ostream& out, const Hypothesis& hypothesis)
{
	hypothesis.ToStream(out);
	// words bitmap
	out << "[" << hypothesis.m_sourceCompleted << "] ";

	// scores
	out << " [total=" << hypothesis.GetTotalScore() << "]";
	out << " " << hypothesis.GetScoreBreakdown();
	return out;
}

/** string form of the source span covered by this hypothesis, restricted to
 * the given factors; empty string for the initial (no-predecessor) hypothesis */
std::string Hypothesis::GetSourcePhraseStringRep(const vector<FactorType> factorsToPrint) const
{
	if (!m_prevHypo) { return ""; }
	// prefer the stored source phrase; fall back to the raw input otherwise
	if(m_sourcePhrase)
	{
		return m_sourcePhrase->GetSubString(m_currSourceWordsRange).GetStringRep(factorsToPrint);
	}
	else
	{
		return m_sourceInput.GetSubString(m_currSourceWordsRange).GetStringRep(factorsToPrint);
	}
}

/** string form of this hypothesis' target phrase, restricted to the given
 * factors; empty string for the initial (no-predecessor) hypothesis */
std::string Hypothesis::GetTargetPhraseStringRep(const vector<FactorType> factorsToPrint) const
{
	if (!m_prevHypo) { return ""; }
	return m_targetPhrase.GetStringRep(factorsToPrint);
}

/** convenience overload: source phrase string using all input factors */
std::string Hypothesis::GetSourcePhraseStringRep() const
{
	vector<FactorType> allFactors;
	const size_t maxSourceFactors = StaticData::Instance()->GetMaxNumFactors(Input);
	for(size_t i=0; i < maxSourceFactors; i++)
	{
		allFactors.push_back(i);
	}
	return GetSourcePhraseStringRep(allFactors);
}

/** convenience overload: target phrase string using all output factors */
std::string Hypothesis::GetTargetPhraseStringRep() const
{
	vector<FactorType> allFactors;
	const size_t maxTargetFactors = StaticData::Instance()->GetMaxNumFactors(Output);
	for(size_t i=0; i < maxTargetFactors; i++)
	{
		allFactors.push_back(i);
	}
	return GetTargetPhraseStringRep(allFactors);
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?