staticdata.cpp.svn-base

来自「解码器是基于短语的统计机器翻译系统的核心模块」· SVN-BASE 代码 · 共 746 行 · 第 1/2 页

SVN-BASE
746
字号
			}
			// NOTE(review): this is the tail of a lexical-reordering loader whose
			// beginning lies outside this view; kept verbatim.
			//update the distortionModelWeight vector to remove these weights
			for(size_t i=numWeightsInTable; i<distortionModelWeights.size(); i++)
			{
				newLexWeights.push_back(distortionModelWeights[i]);
			}
			distortionModelWeights = newLexWeights;
		}
		// the per-table weight count must now match what the model expects
		assert(m_lexWeights.size() == numWeightsInTable);
		//the end result should be a weight vector of the same size as the user configured model
		//			TRACE_ERR( "distortion-weights: ");
		//for(size_t weight=0; weight<m_lexWeights.size(); weight++)
		//{
		//	TRACE_ERR( m_lexWeights[weight] << "\t");
		//}
		//TRACE_ERR( endl);

		// loading the file
		std::string	filePath= specification[3];
		PrintUserTime(string("Start loading distortion table ") + filePath);
		m_reorderModels.push_back(new LexicalReordering(filePath, orientation, direction, condition, m_lexWeights, input, output));
	}
	return true;
}

/** Load all language models named by the "lmodel-file" config parameter.
 *  Each entry has the format "LM-TYPE FACTOR-TYPE NGRAM-ORDER filePath";
 *  the i-th "weight-l" value is paired with the i-th LM and also appended
 *  to the global m_allWeights vector.  Sets m_fLMsLoaded on completion
 *  (phrase-table loading asserts on that flag).
 *  Returns false (after reporting via UserMessage) on a malformed entry or
 *  when the factory cannot build the requested LM implementation.
 */
bool StaticData::LoadLanguageModels()
{
	if (m_parameter->GetParam("lmodel-file").size() > 0)
	{
		// weights
		vector<float> weightAll = Scan<float>(m_parameter->GetParam("weight-l"));

		//TRACE_ERR( "weight-l: ");
		//
		for (size_t i = 0 ; i < weightAll.size() ; i++)
		{
			//	TRACE_ERR( weightAll[i] << "\t");
			m_allWeights.push_back(weightAll[i]);
		}
		//TRACE_ERR( endl);

		// initialize n-gram order for each factor. populated only by factored lm
		const vector<string> &lmVector = m_parameter->GetParam("lmodel-file");
		for(size_t i=0; i<lmVector.size(); i++)
		{
			vector<string>	token		= Tokenize(lmVector[i]);
			if (token.size() != 4 )
			{
				UserMessage::Add("Expected format 'LM-TYPE FACTOR-TYPE NGRAM-ORDER filePath'");
				return false;
			}
			// type = implementation, SRI, IRST etc
			LMImplementation lmImplementation = static_cast<LMImplementation>(Scan<int>(token[0]));

			// factorType = 0 = Surface, 1 = POS, 2 = Stem, 3 = Morphology, etc
			vector<FactorType> 	factorTypes		= Tokenize<FactorType>(token[1], ",");

			// nGramOrder = 2 = bigram, 3 = trigram, etc
			size_t nGramOrder = Scan<int>(token[2]);

			string &languageModelFile = token[3];
			PrintUserTime(string("Start loading LanguageModel ") + languageModelFile);

			LanguageModel *lm = LanguageModelFactory::CreateLanguageModel(lmImplementation, factorTypes
									, nGramOrder, languageModelFile, weightAll[i], m_factorCollection);
      // NULL means the requested implementation was not compiled in
      if (lm == NULL)
      {
      	UserMessage::Add("no LM created. We probably don't have it compiled");
      	return false;
      }
			m_languageModel.push_back(lm);
		}
	}
  // flag indicating that language models were loaded,
  // since phrase table loading requires their presence
  m_fLMsLoaded = true;
  PrintUserTime("Finished loading LanguageModels");
  return true;
}

/** Load generation dictionaries named by "generation-file".
 *  Supports two entry formats: the old 3-token form
 *  "input-factors output-factors filePath" (implies one feature, with a
 *  warning), and the 4-token form "input output numFeatures filePath".
 *  Consumes one "weight-generation" value per feature, appending each to
 *  m_allWeights; warns (but does not fail) if the configured weight count
 *  does not match the total features consumed.
 *  Returns false when a dictionary file is missing or fails to load.
 */
bool StaticData::LoadGenerationTables()
{
	if (m_parameter->GetParam("generation-file").size() > 0)
	{
		const vector<string> &generationVector = m_parameter->GetParam("generation-file");
		const vector<float> &weight = Scan<float>(m_parameter->GetParam("weight-generation"));
		TRACE_ERR( "weight-generation: ");
		for (size_t i = 0 ; i < weight.size() ; i++)
		{
				TRACE_ERR( weight[i] << "\t");
		}
		TRACE_ERR( endl);
		// running index into the flat weight-generation list; each table
		// consumes numFeatures consecutive weights
		size_t currWeightNum = 0;

		for(size_t currDict = 0 ; currDict < generationVector.size(); currDict++)
		{
			vector<string>			token		= Tokenize(generationVector[currDict]);
			// 3 tokens = old format (no explicit feature count)
			bool oldFormat = (token.size() == 3);
			vector<FactorType> 	input		= Tokenize<FactorType>(token[0], ",")
												,output	= Tokenize<FactorType>(token[1], ",");
      m_maxFactorIdx[1] = CalcMax(m_maxFactorIdx[1], input, output);
			string							filePath;
			size_t							numFeatures = 1;
			if (oldFormat)
				filePath = token[2];
			else {
				numFeatures = Scan<size_t>(token[2]);
				filePath = token[3];
			}
			if (!FileExists(filePath))
			{
				stringstream strme;
				strme << "Generation dictionary '"<<filePath<<"' does not exist!\n";
				UserMessage::Add(strme.str());
				return false;
			}
			TRACE_ERR( filePath << endl);
			if (oldFormat) {
				TRACE_ERR( "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
				             "  [WARNING] config file contains old style generation config format.\n"
				             "  Only the first feature value will be read.  Please use the 4-format\n"
				             "  form (similar to the phrase table spec) to specify the # of features.\n"
				             "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n");
			}
			m_generationDictionary.push_back(new GenerationDictionary(numFeatures));
			assert(m_generationDictionary.back() && "could not create GenerationDictionary");
			if (!m_generationDictionary.back()->Load(input
														, output
														, m_factorCollection
														, filePath
														, Output				// always target, should we allow source?
														, oldFormat))
			{
				delete m_generationDictionary.back();
				return false;
			}
			for(size_t i = 0; i < numFeatures; i++) {
				assert(currWeightNum < weight.size());
				m_allWeights.push_back(weight[currWeightNum++]);
			}
		}
		if (currWeightNum != weight.size()) {
			TRACE_ERR( "  [WARNING] config file has " << weight.size() << " generation weights listed, but the configuration for generation files indicates there should be " << currWeightNum << "!\n");
		}
	}

	return true;
}

/** Load phrase translation tables named by "ttable-file".
 *  Must run after LoadLanguageModels (asserts m_fLMsLoaded).  Entry format:
 *  "input-factors output-factors numScores filePath".  For the first table,
 *  when m_inputType is non-zero (confusion-network input), the "weight-i"
 *  input scores are prepended to that table's weight vector (m_numInputScores);
 *  translation weights are then taken from "weight-t" at a running offset.
 *  Chooses the in-memory implementation unless a binary index
 *  (filePath + ".binphr.idx") exists, in which case the on-disk tree
 *  adaptor is used.  Returns false on weight-count mismatch or load failure.
 */
bool StaticData::LoadPhraseTables()
{
	VERBOSE(2,"About to LoadPhraseTables" << endl);

	// language models must be loaded prior to loading phrase tables
	assert(m_fLMsLoaded);

	// load phrase translation tables
  if (m_parameter->GetParam("ttable-file").size() > 0)
	{
		// weights
		vector<float> weightAll									= Scan<float>(m_parameter->GetParam("weight-t"));

		//TRACE_ERR("weight-t: ");
		//for (size_t i = 0 ; i < weightAll.size() ; i++)
		//{
		//		TRACE_ERR(weightAll[i] << "\t");
		//}
		//TRACE_ERR( endl;

		const vector<string> &translationVector = m_parameter->GetParam("ttable-file");
		vector<size_t>	maxTargetPhrase					= Scan<size_t>(m_parameter->GetParam("ttable-limit"));

		//TRACE_ERR("ttable-limits: ";copy(maxTargetPhrase.begin(),maxTargetPhrase.end(),ostream_iterator<size_t>(cerr," "));cerr<<"\n");

		size_t index = 0;
		// position in the flat weight-t list; advanced by numScoreComponent per table
		size_t weightAllOffset = 0;
		for(size_t currDict = 0 ; currDict < translationVector.size(); currDict++)
		{
			vector<string>                  token           = Tokenize(translationVector[currDict]);

			//characteristics of the phrase table
			vector<FactorType>      input           = Tokenize<FactorType>(token[0], ",")
				,output = Tokenize<FactorType>(token[1], ",");
			m_maxFactorIdx[0] = CalcMax(m_maxFactorIdx[0], input);
			m_maxFactorIdx[1] = CalcMax(m_maxFactorIdx[1], output);
      m_maxNumFactors = std::max(m_maxFactorIdx[0], m_maxFactorIdx[1]) + 1;
			string filePath= token[3];
			size_t numScoreComponent = Scan<size_t>(token[2]);

			// weights for this phrase dictionary
			// first InputScores (if any), then translation scores
			vector<float> weight;

			if(currDict==0 && m_inputType)
			{	// TODO. find what the assumptions made by confusion network about phrase table output which makes
				// it only work with binrary file. This is a hack
				m_numInputScores=m_parameter->GetParam("weight-i").size();
				for(unsigned k=0;k<m_numInputScores;++k)
					weight.push_back(Scan<float>(m_parameter->GetParam("weight-i")[k]));
			}
			else{
				m_numInputScores=0;
			}

			for (size_t currScore = 0 ; currScore < numScoreComponent; currScore++)
				weight.push_back(weightAll[weightAllOffset + currScore]);

			if(weight.size() - m_numInputScores != numScoreComponent)
			{
				stringstream strme;
				strme << "Your phrase table has " << numScoreComponent
							<< " scores, but you specified " << weight.size() << " weights!";
				UserMessage::Add(strme.str());
				return false;
			}

			weightAllOffset += numScoreComponent;
			// from here on numScoreComponent counts input scores as well
			numScoreComponent += m_numInputScores;

			assert(numScoreComponent==weight.size());
			std::copy(weight.begin(),weight.end(),std::back_inserter(m_allWeights));

			PrintUserTime(string("Start loading PhraseTable ") + filePath);
			// a ".binphr.idx" companion file marks a binarised table
			if (!FileExists(filePath+".binphr.idx"))
			{
				VERBOSE(2,"using standard phrase tables");
				PhraseDictionaryMemory *pd=new PhraseDictionaryMemory(numScoreComponent);
				if (!pd->Load(input
								 , output
								 , m_factorCollection
								 , filePath
								 , weight
								 , maxTargetPhrase[index]
								 , GetAllLM()
								 , GetWeightWordPenalty()
								 , *this))
				{
					delete pd;
					return false;
				}
				m_phraseDictionary.push_back(pd);
			}
			else
			{
				TRACE_ERR( "using binary phrase tables for idx "<<currDict<<"\n");
				PhraseDictionaryTreeAdaptor *pd=new PhraseDictionaryTreeAdaptor(numScoreComponent,(currDict==0 ? m_numInputScores : 0));
				if (!pd->Load(input,output,m_factorCollection,filePath,weight,
									 maxTargetPhrase[index],
									 GetAllLM(),
									 GetWeightWordPenalty()))
				{
					delete pd;
					return false;
				}
				m_phraseDictionary.push_back(pd);
			}

			index++;
		}
	}

	PrintUserTime("Finished loading phrase tables");
	return true;
}

/** Build the decode-step chain from the "mapping" config parameter.
 *  Each entry is "T index" (translation step using phrase dictionary
 *  [index]) or anything else for "G" (generation step using generation
 *  dictionary [index]); steps are linked to the previous one in order.
 *  Must run after LoadPhraseTables / LoadGenerationTables so the indices
 *  can be validated.  Returns false on a malformed entry or out-of-range
 *  index.  InsertNullFertilityWord is declared but not implemented (asserts).
 */
bool StaticData::LoadMapping()
{
	// mapping
	const vector<string> &mappingVector = m_parameter->GetParam("mapping");
	DecodeStep *prev = 0;
	for(size_t i=0; i<mappingVector.size(); i++)
	{
		vector<string>	token		= Tokenize(mappingVector[i]);
		if (token.size() == 2)
		{
			// "T" selects a translation step; any other tag means generation
			DecodeType decodeType = token[0] == "T" ? Translate : Generate;
			size_t index = Scan<size_t>(token[1]);
			DecodeStep* decodeStep = 0;
			switch (decodeType) {
				case Translate:
					if(index>=m_phraseDictionary.size())
						{
							stringstream strme;
							strme << "No phrase dictionary with index "
										<< index << " available!";
							UserMessage::Add(strme.str());
							return false;
						}
					decodeStep = new DecodeStepTranslation(m_phraseDictionary[index], prev);
				break;
				case Generate:
					if(index>=m_generationDictionary.size())
						{
							stringstream strme;
							strme << "No generation dictionary with index "
										<< index << " available!";
							UserMessage::Add(strme.str());
							return false;
						}
					decodeStep = new DecodeStepGeneration(m_generationDictionary[index], prev);
				break;
				case InsertNullFertilityWord:
					assert(!"Please implement NullFertilityInsertion.");
				break;
			}
			assert(decodeStep);
			m_decodeStepList.push_back(decodeStep);
			prev = decodeStep;
		} else {
			UserMessage::Add("Malformed mapping!");
			return false;
		}
	}

	return true;
}

/** Release per-sentence state in all phrase dictionaries, generation
 *  dictionaries and language models after a sentence has been translated. */
void StaticData::CleanUpAfterSentenceProcessing()
{
	for(size_t i=0;i<m_phraseDictionary.size();++i)
		m_phraseDictionary[i]->CleanUp();
	for(size_t i=0;i<m_generationDictionary.size();++i)
		m_generationDictionary[i]->CleanUp();

  //something LMs could do after each sentence
  LMList::const_iterator iterLM;
	for (iterLM = m_languageModel.begin() ; iterLM != m_languageModel.end() ; ++iterLM)
	{
		LanguageModel &languageModel = **iterLM;
    languageModel.CleanUpAfterSentenceProcessing();
	}
}

/** initialize the translation and language models for this sentence
    (includes loading of translation table entries on demand, if
    binary format is used) */
void StaticData::InitializeBeforeSentenceProcessing(InputType const& in)
{
	for(size_t i=0;i<m_phraseDictionary.size();++i)
	{
		m_phraseDictionary[i]->InitializeForInput(in);
  }

  //something LMs could do before translating a sentence
  LMList::const_iterator iterLM;
	for (iterLM = m_languageModel.begin() ; iterLM != m_languageModel.end() ; ++iterLM)
	{
		LanguageModel &languageModel = **iterLM;
    languageModel.InitializeBeforeSentenceProcessing();
	}
}

/** Overwrite the slice of the global weight vector owned by one score
 *  producer.  The producer's bookkeeping ID determines the [begin,end)
 *  range in m_allWeights; weights.size() must equal that range's length
 *  (asserted).  Grows m_allWeights if the range extends past its end. */
void StaticData::SetWeightsForScoreProducer(const ScoreProducer* sp, const std::vector<float>& weights)
{
  const size_t id = sp->GetScoreBookkeepingID();
  const size_t begin = m_scoreIndexManager.GetBeginIndex(id);
  const size_t end = m_scoreIndexManager.GetEndIndex(id);
  assert(end - begin == weights.size());
  if (m_allWeights.size() < end)
    m_allWeights.resize(end);
  std::vector<float>::const_iterator weightIter = weights.begin();
  for (size_t i = begin; i < end; i++)
    m_allWeights[i] = *weightIter++;
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?