⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 chmm.cpp

📁 简单明了的HMM-GMM-KEAMS调用接口
💻 CPP
📖 第 1 页 / 共 2 页
字号:
	fprintf(fp,"\n");
	fclose(fp);

	// Clean up
	delete[] lastLogP;
	delete[] currLogP;
	for ( i = 0; i < size; i++)
	{
		delete[] path[i];
	}
	delete[] path;

	prob = exp(prob / size);
	return prob;
}


/*	SampleFile: <size><dim><seq_size><seq_data>...<seq_size>...
*/
// Build a starting model for Train().
//
// 1. Sets fixed initial / transition probabilities (a starting point only;
//    they are not normalised and Train() re-estimates them).
// 2. Segments every training sequence uniformly over the states: frame j of
//    a sequence of length T is assigned to state floor(j * m_stateNum / T).
// 3. Writes each state's frames to "$chmm_s<i>.tmp" as <count><dim><frames>
//    and trains m_stateModel[i] on that file.
//
// sampleFileName: binary file <size><dim>, then per sequence <seq_size>
//                 followed by seq_size frames of dim doubles.
// VALIDATING:     non-zero if every frame is followed by 3 extra doubles,
//                 which are skipped.
void CHMM::Init(const char* sampleFileName,int VALIDATING)
{
	//--- Debug ---//
	//DumpSampleFile(sampleFileName);

	// Check the sample file
	ifstream sampleFile(sampleFileName, ios_base::binary);
	_ASSERT(sampleFile);
	int i,j;

	int size = 0;
	int dim = 0;
	sampleFile.read((char*)&size, sizeof(int));  // number of sequences
	sampleFile.read((char*)&dim, sizeof(int));   // feature dimension
	_ASSERT(size >= 3);
	_ASSERT(dim == m_stateModel[0]->GetDimNum());

	// Initial probabilities: the first three states get 0.5/3 each, and the
	// remaining states share the other 0.5.
	// Transitions: to itself and to the next state 0.3 each; every other
	// column (including the exit column j == m_stateNum) gets 0.4/m_stateNum.
	// The main purpose of Init is to initialise the mixture models below.
	for ( i = 0; i < m_stateNum; i++)
	{
		// The initial probabilities
		if(i<=2)
			m_stateInit[i]=0.5/3.0;
		else
			m_stateInit[i] = 0.5 / float(m_stateNum-3);

		// The transition probabilities (rows have m_stateNum + 1 columns)
		for ( j = 0; j <= m_stateNum; j++)
		{
			if((i==j)||(j==i+1))
				m_stateTran[i][j]=0.3;
			else
				m_stateTran[i][j] = 0.4 /m_stateNum;
		}
	}

	// Frames assigned to each state. Every pointer stored here is owned by
	// this function and released with delete[] once it has been written out.
	vector<double*> *gaussseq = new vector<double*>[m_stateNum];

	for ( i = 0; i < size; i++)
	{
		int seq_size = 0;
		sampleFile.read((char*)&seq_size, sizeof(int));  // frames in this sequence

		for ( j = 0; j < seq_size; j++)
		{
			double* x = new double[dim];
			sampleFile.read((char*)x, sizeof(double) * dim);

			// Uniform segmentation. Computing floor(j * N / T) directly is exact
			// in double for any realistic sizes and always lies in [0, N), unlike
			// dividing by a float-rounded segment length.
			int s = int(double(j) * m_stateNum / seq_size);
			gaussseq[s].push_back(x);

			if(VALIDATING)
			{
				// Each frame carries 3 trailing doubles that are not used here.
				double skipped[3];
				sampleFile.read((char*)skipped, sizeof(skipped));
			}
		}
	}

	char** stateFileName = new char*[m_stateNum];
	ofstream* stateFile = new ofstream[m_stateNum];
	int* stateDataSize = new int[m_stateNum];

	for ( i = 0; i < m_stateNum; i++)
	{
		stateFileName[i] = new char[20];
		ostrstream str(stateFileName[i], 20);
		str << "$chmm_s" << i << ".tmp" << '\0';
	}

	// Write each state's frames to its temp file and train its mixture model.
	for ( i = 0; i < m_stateNum; i++)
	{
		stateFile[i].open(stateFileName[i], ios_base::binary);
		stateDataSize[i] = (int)gaussseq[i].size();
		stateFile[i].write((char*)&stateDataSize[i], sizeof(int));
		stateFile[i].write((char*)&dim, sizeof(int));
		for( j=0;j<stateDataSize[i];j++)
		{
			double* x = gaussseq[i][j];
			stateFile[i].write((char*)x, sizeof(double) * dim);
			delete[] x;   // allocated with new[] while reading the samples
		}
		stateFile[i].close();
		m_stateModel[i]->Train(stateFileName[i]);
		gaussseq[i].clear();
	}

	for ( i = 0; i < m_stateNum; i++)
		delete[] stateFileName[i];

	delete[] stateFileName;
	delete[] stateFile;
	delete[] stateDataSize;
	delete[] gaussseq;
}

/*	SampleFile: <size><dim><seq_size><seq_data>...<seq_size>...
*/
// Train the model on a binary sample file (same layout as for Init()).
//
// After Init(), each iteration:
//   1. decodes every sequence with Decode() to obtain its best state path and
//      accumulates the average log probability,
//   2. writes the frames aligned to each state into "$chmm_s<i>.tmp",
//   3. retrains a state's mixture model when it received more than
//      2 * GetMixNum() frames,
//   4. re-estimates m_stateInit and m_stateTran from path counts (column
//      m_stateNum of m_stateTran counts paths ending in that state).
// Iteration stops after m_maxIterNum iterations. It also stops when the
// average log probability improves by less than m_endError * |previous| on
// three consecutive iterations. Finally the initial probabilities are reset
// so the first third of the states share 0.9 and the rest share 0.1.
//
// VALIDATING: non-zero if every frame is followed by 3 extra doubles to skip.
void CHMM::Train(const char* sampleFileName,int VALIDATING)
{
	Init(sampleFileName,VALIDATING);

	//--- Debug ---//
	//DumpSampleFile(sampleFileName);

	// Check the sample file
	ifstream sampleFile(sampleFileName, ios_base::binary);
	_ASSERT(sampleFile);
	int i,j;

	int size = 0;
	int dim = 0;
	sampleFile.read((char*)&size, sizeof(int));
	sampleFile.read((char*)&dim, sizeof(int));
	_ASSERT(size >= 3);
	_ASSERT(dim == m_stateModel[0]->GetDimNum());

	// Buffer for new model
	int* stateInitNum = new int[m_stateNum];        // paths starting in each state
	int** stateTranNum = new int*[m_stateNum];      // transition counts; last column = path end
	char** stateFileName = new char*[m_stateNum];
	ofstream* stateFile = new ofstream[m_stateNum];
	int* stateDataSize = new int[m_stateNum];       // frames aligned to each state

	for ( i = 0; i < m_stateNum; i++)
	{
		stateTranNum[i] = new int[m_stateNum + 1];
		stateFileName[i] = new char[20];

		ostrstream str(stateFileName[i], 20);
		str << "$chmm_s" << i << ".tmp" << '\0';
	}

	bool loop = true;
	double currL = 0;
	double lastL = 0;
	int iterNum = 0;
	int unchanged = 0;
	vector<int> state;
	vector<double*> seq;

	while (loop)
	{
		lastL = currL;
		currL = 0;

		// Clear buffers and open temp data files. The frame count written to
		// the header is a placeholder, patched after all sequences are decoded.
		for ( i = 0; i < m_stateNum; i++)
		{
			stateDataSize[i] = 0;
			stateFile[i].open(stateFileName[i], ios_base::binary);
			stateFile[i].write((char*)&stateDataSize[i], sizeof(int));
			stateFile[i].write((char*)&dim, sizeof(int));

			memset(stateTranNum[i], 0, sizeof(int) * (m_stateNum + 1));
		}
		memset(stateInitNum, 0, sizeof(int) * m_stateNum);

		// Predict: obtain the best path of every sequence
		sampleFile.seekg(sizeof(int) * 2, ios_base::beg);
		for ( i = 0; i < size; i++)
		{
			int seq_size = 0;
			sampleFile.read((char*)&seq_size, sizeof(int));

			for ( j = 0; j < seq_size; j++)
			{
				double* x = new double[dim];
				sampleFile.read((char*)x, sizeof(double) * dim);
				seq.push_back(x);

				if(VALIDATING)
				{
					// Skip the 3 trailing doubles. A fixed-size buffer cannot be
					// overrun when dim < 3 and needs no delete/delete[] pairing.
					double skipped[3];
					sampleFile.read((char*)skipped, sizeof(skipped));
				}
			}

			// An empty sequence has no path; state[0] and state[seq_size-1]
			// below would be out of range, so it contributes nothing.
			if (seq_size <= 0)
				continue;

			currL += LogProb(Decode(seq, state));

			// Accumulate statistics along the decoded path
			stateInitNum[state[0]]++;
			for ( j = 0; j < seq_size; j++)
			{
				stateFile[state[j]].write((char*)seq[j], sizeof(double) * dim);
				stateDataSize[state[j]]++;

				if (j > 0)
				{
					stateTranNum[state[j-1]][state[j]]++;
				}
			}
			stateTranNum[state[seq_size-1]][m_stateNum]++; // Final state

			for ( j = 0; j < seq_size; j++)
			{
				delete[] seq[j];
			}
			state.clear();
			seq.clear();
		}
		currL /= size;

		// Patch the real frame counts into the headers and close the files
		for ( i = 0; i < m_stateNum; i++)
		{
			stateFile[i].seekp(0, ios_base::beg);
			stateFile[i].write((char*)&stateDataSize[i], sizeof(int));
			stateFile[i].close();
		}

		// Reestimate: stateModel, stateInit, stateTran
		int count = 0;
		for ( j = 0; j < m_stateNum; j++)
		{
			// Retrain only states with enough data for their mixtures
			if (stateDataSize[j] > m_stateModel[j]->GetMixNum() * 2)
			{
				//m_stateModel[j]->DumpSampleFile(stateFileName[j]);
				m_stateModel[j]->Train(stateFileName[j]);
			}
			count += stateInitNum[j];
		}
		if (count > 0)   // no decoded path at all: keep the previous values
		{
			for ( j = 0; j < m_stateNum; j++)
			{
				m_stateInit[j] = 1.0 * stateInitNum[j] / count;
			}
		}

		for ( i = 0; i < m_stateNum; i++)
		{
			count = 0;
			for ( j = 0; j < m_stateNum + 1; j++)
			{
				count += stateTranNum[i][j];
			}
			if (count > 0)
			{
				for ( j = 0; j < m_stateNum + 1; j++)
				{
					m_stateTran[i][j] = 1.0 * stateTranNum[i][j] / count;
				}
			}
		}

		// Terminal conditions
		iterNum++;
		unchanged = (currL - lastL < m_endError * fabs(lastL)) ? (unchanged + 1) : 0;
		if (iterNum >= m_maxIterNum || unchanged >= 3)
		{
			loop = false;
		}

		//DEBUG
		//cout << "Iter: " << iterNum << ", Average Log-Probability: " << currL << endl;
	}

	for ( i = 0; i < m_stateNum; i++)
	{
		delete[] stateTranNum[i];
		delete[] stateFileName[i];
	}
	delete[] stateTranNum;
	delete[] stateFileName;
	delete[] stateFile;
	delete[] stateInitNum;
	delete[] stateDataSize;

	// Reset the initial probabilities so every state in the first third can
	// start a path. States with 3*i < m_stateNum share 0.9 and the others
	// share 0.1. The group sizes are counted exactly: m_stateNum/3 and
	// m_stateNum*2/3 miscount when m_stateNum % 3 != 0, and m_stateNum/3 is
	// zero when m_stateNum < 3. With no tail states the head takes all mass.
	int headNum = (m_stateNum + 2) / 3;   // number of i with 3*i < m_stateNum
	int tailNum = m_stateNum - headNum;
	double headMass = (tailNum > 0) ? 0.9 : 1.0;
	for ( i = 0; i < m_stateNum; i++)
	{
		if(i*3<m_stateNum)
			m_stateInit[i] = headMass / headNum;
		else
			m_stateInit[i] = 0.1 / tailNum;
	}
}
// Log10 transition probability (via LogProb) from state i to state j.
// i must be a state index in [0, m_stateNum). j may additionally equal
// m_stateNum, which selects the exit ("final state") column of m_stateTran.
// Returns -100 for out-of-range indices.
double CHMM::getTransProb(int i,int j)
{
	// m_stateTran has m_stateNum rows but m_stateNum + 1 columns, so the row
	// bound is exclusive while the column bound is inclusive.
	if(i<0||i>=m_stateNum||j<0||j>m_stateNum)
		return -100;
	return LogProb(m_stateTran[i][j]);
}

/*	SampleFile: <size><dim><seq_size><seq_data>...<seq_size>...
*/
// Debug helper that prints a binary sample file to cout.
// It prints the sequence count and the feature dimension, each on its own
// line. Then, for every sequence, it prints the sequence length followed by
// one line per frame, each value followed by a space.
// Frames are read as exactly dim doubles; trailing validation data is not
// expected.
void CHMM::DumpSampleFile(const char* fileName)
{
	ifstream in(fileName, ios_base::binary);
	_ASSERT(in);

	int seqCount = 0;
	in.read((char*)&seqCount, sizeof(int));
	cout << seqCount << endl;

	int featDim = 0;
	in.read((char*)&featDim, sizeof(int));
	cout << featDim << endl;

	double* frame = new double[featDim];

	for (int s = 0; s < seqCount; ++s)
	{
		int frameCount = 0;
		in.read((char*)&frameCount, sizeof(int));
		cout << frameCount << endl;

		int k = 0;
		while (k < frameCount)
		{
			in.read((char*)frame, sizeof(double) * featDim);
			for (int d = 0; d < featDim; ++d)
				cout << frame[d] << " ";
			cout << endl;
			++k;
		}
	}
	in.close();

	delete[] frame;
}

// Base-10 logarithm of p, floored at -20. Any p not greater than 1e-20
// (zero, negatives, NaN) yields -20, so the result is always finite.
double CHMM::LogProb(double p)
{
	if (!(p > 1e-20))
		return -20;
	return log10(p);
}

// Serialise hmm as the tagged text format parsed by operator>>:
//   <CHMM>
//   <StateNum> N </StateNum>
//   N state models (each written with its own operator<<)
//   <Init> N initial probabilities </Init>
//   <Tran> N rows of N+1 transition probabilities </Tran>
//   </CHMM>
ostream& operator<<(ostream& out, CHMM& hmm)
{
	out << "<CHMM>" << endl
	    << "<StateNum> " << hmm.m_stateNum << " </StateNum>" << endl;

	for (int s = 0; s < hmm.m_stateNum; ++s)
		out << *hmm.m_stateModel[s];

	out << "<Init> ";
	for (int s = 0; s < hmm.m_stateNum; ++s)
		out << hmm.m_stateInit[s] << " ";
	out << "</Init>" << endl;

	out << "<Tran>" << endl;
	for (int from = 0; from < hmm.m_stateNum; ++from)
	{
		// One row per source state; the extra last column is the exit column.
		for (int to = 0; to <= hmm.m_stateNum; ++to)
			out << hmm.m_stateTran[from][to] << " ";
		out << endl;
	}
	out << "</Tran>" << endl
	    << "</CHMM>" << endl;

	return out;
}

// Read one whitespace-delimited tag token into buf. At most N-1 characters
// are stored, so a malformed, over-long token cannot overflow buf.
template <size_t N>
static istream& ReadCHMMLabel(istream& in, char (&buf)[N])
{
	in.width((streamsize)N);   // reset to 0 by the extraction itself
	return in >> buf;
}

// Parse a CHMM from the tagged text format written by operator<<.
// Disposes hmm's current contents and reallocates it for the state count
// that was read. Tags are consumed but, apart from the opening <CHMM>
// (checked with _ASSERTE in debug builds), not validated.
istream& operator>>(istream& in, CHMM& hmm)
{
	char label[20];
	int i,j;
	ReadCHMMLabel(in, label);
	_ASSERTE(strcmp(label, "<CHMM>") == 0);

	hmm.Dispose();

	ReadCHMMLabel(in, label);   // "<StateNum>"
	in >> hmm.m_stateNum;
	ReadCHMMLabel(in, label);   // "</StateNum>"

	hmm.Allocate(hmm.m_stateNum);

	for ( i = 0; i < hmm.m_stateNum; i++)
	{
		in >> *hmm.m_stateModel[i];
	}

	ReadCHMMLabel(in, label);   // "<Init>"
	for ( i = 0; i < hmm.m_stateNum; i++)
	{
		in >> hmm.m_stateInit[i];
	}
	ReadCHMMLabel(in, label);   // "</Init>"

	ReadCHMMLabel(in, label);   // "<Tran>"
	for ( i = 0; i < hmm.m_stateNum; i++)
	{
		// m_stateNum + 1 columns: the last one is the exit column
		for ( j = 0; j < hmm.m_stateNum + 1; j++)
		{
			in >> hmm.m_stateTran[i][j];
		}
	}
	ReadCHMMLabel(in, label);   // "</Tran>"

	ReadCHMMLabel(in, label);   // "</CHMM>"

	return in;
}
// Convert a text sample file into the binary layout read by Init()/Train().
//   text input:    <seq_num> <dim>, then per sequence <seq_size> followed by
//                  seq_size frames of dim whitespace-separated values
//   binary output: int seq_num, int dim, then per sequence int seq_size
//                  followed by seq_size * dim raw doubles
// InputText:        path of the text file to read.
// OutputBinaryText: path of the binary file to create (overwritten).
void CHMM::TextTransform(CString InputText, CString OutputBinaryText)
{
	CString InputTextPath = InputText;
	CString OutputBinaryTextPath = OutputBinaryText;
	ifstream Input(InputTextPath);
	ofstream Output(OutputBinaryTextPath,ios_base::binary);
	int seq_num=0;  // number of sequences
	int dim=0;      // feature dimension

	Input>>seq_num;
	Input>>dim;
	Output.write((char*)&seq_num,sizeof(int));
	Output.write((char*)&dim,sizeof(int));

	double *pt_feature = new double[dim];   // one frame; released below

	for(int i=0; i<seq_num; i++)
	{
		// Reset so a truncated input cannot repeat the previous length.
		int seq_size=0;
		Input>>seq_size;
		Output.write((char*)&seq_size,sizeof(int));
		for(int j=0;j<seq_size;j++)
		{
			for(int k=0;k<dim;k++)
			{
				// Zero first so missing values become 0, not stale/uninitialised data.
				pt_feature[k]=0;
				Input>>pt_feature[k];
			}
			// Exactly dim raw doubles per frame, as Init()/Train() read them.
			// The former loop formatted text and indexed up to seq_size,
			// running past the end of pt_feature.
			Output.write((char*)pt_feature, sizeof(double) * dim);
		}
	}

	delete []pt_feature;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -