// File: chmm.cpp
// (Code-viewer "font size" control label; not part of the source.)
fprintf(fp,"\n");
fclose(fp);
// Clean up
delete[] lastLogP;
delete[] currLogP;
for ( i = 0; i < size; i++)
{
delete[] path[i];
}
delete[] path;
prob = exp(prob / size);
return prob;
}
/* SampleFile: <size><dim><seq_size><seq_data>...<seq_size>...
*/
// Builds a starting model from a sample file before iterative training.
//
// 1. Fills m_stateInit / m_stateTran with fixed starting values.
// 2. Cuts every training sequence into m_stateNum equal-length segments and
//    assigns the frames of segment k to state k.
// 3. Writes each state's frames to a temporary file "$chmm_s<k>.tmp" and
//    trains that state's model on it via m_stateModel[k]->Train().
//
// sampleFileName: binary file laid out as <size><dim>{<seq_size><seq_data>}...
// VALIDATING:     when non-zero, every feature vector in the file is followed
//                 by 3 extra doubles, which are read and discarded.
void CHMM::Init(const char* sampleFileName,int VALIDATING)
{
    //--- Debug ---//
    //DumpSampleFile(sampleFileName);

    // Check the sample file
    ifstream sampleFile(sampleFileName, ios_base::binary);
    _ASSERT(sampleFile);

    int i, j;
    int size = 0;
    int dim = 0;
    sampleFile.read((char*)&size, sizeof(int)); // number of sequences
    sampleFile.read((char*)&dim, sizeof(int));  // feature dimension
    _ASSERT(size >= 3);
    _ASSERT(dim == m_stateModel[0]->GetDimNum());

    // Starting probabilities: the first three states share 0.5 of the
    // initial mass and the remaining states share the other 0.5. A state's
    // self transition and its transition to the next state both start at
    // 0.3; every other transition starts at 0.4 / m_stateNum. These values
    // only seed the model; Train() re-estimates them.
    for (i = 0; i < m_stateNum; i++)
    {
        // The initial probabilities
        if (i <= 2)
            m_stateInit[i] = 0.5 / 3.0;
        else
            m_stateInit[i] = 0.5 / float(m_stateNum - 3);

        // The transition probabilities. Column m_stateNum is included: it is
        // the extra "final state" column that Train() also counts into.
        for (j = 0; j <= m_stateNum; j++)
        {
            if ((i == j) || (j == i + 1))
                m_stateTran[i][j] = 0.3;
            else
                m_stateTran[i][j] = 0.4 / m_stateNum;
        }
    }

    // gaussseq[k] collects (and owns) the feature vectors assigned to state k.
    vector<double*>* gaussseq = new vector<double*>[m_stateNum];
    for (i = 0; i < size; i++)
    {
        int seq_size = 0;
        sampleFile.read((char*)&seq_size, sizeof(int)); // frames in this sequence

        // Frames per state for this sequence.
        double r = float(seq_size) / float(m_stateNum);
        for (j = 0; j < seq_size; j++)
        {
            double* x = new double[dim];
            sampleFile.read((char*)x, sizeof(double) * dim);

            // Spread the frames evenly over the states. The clamp protects
            // against floating-point rounding pushing the last frames of a
            // very long sequence one past the final state.
            int bucket = int(j / r);
            if (bucket >= m_stateNum)
                bucket = m_stateNum - 1;
            gaussseq[bucket].push_back(x);

            if (VALIDATING)
            {
                // Skip the 3 extra doubles stored after each feature vector.
                // A stack buffer replaces the former new[] / scalar-delete pair.
                double skipped[3];
                sampleFile.read((char*)skipped, sizeof(double) * 3);
            }
        }
    }

    for (i = 0; i < m_stateNum; i++)
    {
        // "$chmm_s" + index + ".tmp" + NUL fits comfortably in 20 chars.
        char stateFileName[20];
        ostrstream str(stateFileName, sizeof(stateFileName));
        str << "$chmm_s" << i << ".tmp" << '\0';

        int stateDataSize = (int)gaussseq[i].size();

        ofstream stateFile(stateFileName, ios_base::binary);
        stateFile.write((char*)&stateDataSize, sizeof(int));
        stateFile.write((char*)&dim, sizeof(int));
        for (j = 0; j < stateDataSize; j++)
        {
            double* x = gaussseq[i][j];
            stateFile.write((char*)x, sizeof(double) * dim);
            // Each vector was allocated with new[] above and is no longer
            // needed once written, so release it here with delete[].
            delete[] x;
        }
        stateFile.close();
        gaussseq[i].clear();

        m_stateModel[i]->Train(stateFileName);
    }

    delete[] gaussseq;
}
/* SampleFile: <size><dim><seq_size><seq_data>...<seq_size>...
*/
// Trains the model on a sample file by repeated decode / re-estimate passes.
//
// After Init() seeds the model, each iteration:
//   1. decodes every sequence with Decode() to get one state index per frame,
//   2. appends each frame to the temporary file of the state it was assigned
//      to and counts initial states, transitions and final states,
//   3. retrains each state model that received more than 2 * GetMixNum()
//      frames and re-estimates m_stateInit / m_stateTran from the counts.
// Iteration stops after m_maxIterNum passes, or once the improvement of the
// average log-probability has stayed below m_endError * |previous value| for
// three consecutive passes. Finally m_stateInit is overwritten with a fixed
// distribution that favours the first third of the states.
//
// sampleFileName: binary file laid out as <size><dim>{<seq_size><seq_data>}...
// VALIDATING:     when non-zero, every feature vector in the file is followed
//                 by 3 extra doubles, which are read and discarded.
void CHMM::Train(const char* sampleFileName,int VALIDATING)
{
    Init(sampleFileName,VALIDATING);

    //--- Debug ---//
    //DumpSampleFile(sampleFileName);

    // Check the sample file
    ifstream sampleFile(sampleFileName, ios_base::binary);
    _ASSERT(sampleFile);

    int i, j;
    int size = 0;
    int dim = 0;
    sampleFile.read((char*)&size, sizeof(int));
    sampleFile.read((char*)&dim, sizeof(int));
    _ASSERT(size >= 3);
    _ASSERT(dim == m_stateModel[0]->GetDimNum());

    // Buffer for new model
    int* stateInitNum = new int[m_stateNum];        // sequences starting in each state
    int** stateTranNum = new int*[m_stateNum];      // [from][to]; column m_stateNum = final state
    char** stateFileName = new char*[m_stateNum];   // per-state temporary file names
    ofstream* stateFile = new ofstream[m_stateNum]; // per-state temporary files
    int* stateDataSize = new int[m_stateNum];       // frames written to each state file
    for (i = 0; i < m_stateNum; i++)
    {
        stateTranNum[i] = new int[m_stateNum + 1];
        stateFileName[i] = new char[20];
        ostrstream str(stateFileName[i], 20);
        str << "$chmm_s" << i << ".tmp" << '\0';
    }

    bool loop = true;
    double currL = 0;
    double lastL = 0;
    int iterNum = 0;
    int unchanged = 0;
    vector<int> state;
    vector<double*> seq;

    while (loop)
    {
        lastL = currL;
        currL = 0;

        // Clear buffer and open temp data files. The frame count written
        // here is a placeholder; it is patched once the real count is known.
        for (i = 0; i < m_stateNum; i++)
        {
            stateDataSize[i] = 0;
            stateFile[i].open(stateFileName[i], ios_base::binary);
            stateFile[i].write((char*)&stateDataSize[i], sizeof(int));
            stateFile[i].write((char*)&dim, sizeof(int));
            memset(stateTranNum[i], 0, sizeof(int) * (m_stateNum + 1));
        }
        memset(stateInitNum, 0, sizeof(int) * m_stateNum);

        // Predict: obtain the best path.
        // Rewind to just past the <size><dim> header. Clear the stream state
        // first so a short read in an earlier pass cannot make the seek fail.
        sampleFile.clear();
        sampleFile.seekg(sizeof(int) * 2, ios_base::beg);
        for (i = 0; i < size; i++)
        {
            int seq_size = 0;
            sampleFile.read((char*)&seq_size, sizeof(int));
            if (seq_size <= 0)
            {
                // Empty or unreadable sequence: there is no first or last
                // frame to count, and indexing state[0] / state[-1] below
                // would be undefined behaviour.
                continue;
            }

            for (j = 0; j < seq_size; j++)
            {
                double* x = new double[dim];
                sampleFile.read((char*)x, sizeof(double) * dim);
                seq.push_back(x);
                if (VALIDATING)
                {
                    // Skip the 3 extra doubles stored after each feature
                    // vector. A fixed stack buffer avoids both the overflow
                    // the old dim-sized buffer had for dim < 3 and its
                    // new[] / scalar-delete mismatch.
                    double skipped[3];
                    sampleFile.read((char*)skipped, sizeof(double) * 3);
                }
            }

            // Decode is expected to leave one state index per frame in
            // `state`; it is indexed that way below.
            currL += LogProb(Decode(seq, state));

            stateInitNum[state[0]]++;
            for (j = 0; j < seq_size; j++)
            {
                stateFile[state[j]].write((char*)seq[j], sizeof(double) * dim);
                stateDataSize[state[j]]++;
                if (j > 0)
                {
                    stateTranNum[state[j-1]][state[j]]++;
                }
            }
            stateTranNum[state[seq_size - 1]][m_stateNum]++; // Final state

            for (j = 0; j < seq_size; j++)
            {
                delete[] seq[j];
            }
            state.clear();
            seq.clear();
        }
        currL /= size;

        // Close temp data files, patching the real frame count into the header
        for (i = 0; i < m_stateNum; i++)
        {
            stateFile[i].seekp(0, ios_base::beg);
            stateFile[i].write((char*)&stateDataSize[i], sizeof(int));
            stateFile[i].close();
        }

        // Reestimate: stateModel, stateInit, stateTran
        int count = 0;
        for (j = 0; j < m_stateNum; j++)
        {
            // Only retrain a state that received enough frames.
            if (stateDataSize[j] > m_stateModel[j]->GetMixNum() * 2)
            {
                //m_stateModel[j]->DumpSampleFile(stateFileName[j]);
                m_stateModel[j]->Train(stateFileName[j]);
            }
            count += stateInitNum[j];
        }
        if (count > 0) // no usable sequence: keep the previous estimate
        {
            for (j = 0; j < m_stateNum; j++)
            {
                m_stateInit[j] = 1.0 * stateInitNum[j] / count;
            }
        }
        for (i = 0; i < m_stateNum; i++)
        {
            count = 0;
            for (j = 0; j < m_stateNum + 1; j++)
            {
                count += stateTranNum[i][j];
            }
            if (count > 0) // a state that was never visited keeps its old row
            {
                for (j = 0; j < m_stateNum + 1; j++)
                {
                    m_stateTran[i][j] = 1.0 * stateTranNum[i][j] / count;
                }
            }
        }

        // Terminal conditions
        iterNum++;
        unchanged = (currL - lastL < m_endError * fabs(lastL)) ? (unchanged + 1) : 0;
        if (iterNum >= m_maxIterNum || unchanged >= 3)
        {
            loop = false;
        }
        //DEBUG
        //cout << "Iter: " << iterNum << ", Average Log-Probability: " << currL << endl;
    }

    for (i = 0; i < m_stateNum; i++)
    {
        delete[] stateTranNum[i];
        delete[] stateFileName[i];
    }
    delete[] stateTranNum;
    delete[] stateFileName;
    delete[] stateFile;
    delete[] stateInitNum;
    delete[] stateDataSize;

    // Re-adjust the initial probabilities so that every state in the first
    // third of the chain can start a sequence: those states share 0.9 of the
    // mass and the remaining states share 0.1.
    // The shares are divided by the number of states that actually fall in
    // each group. The previous code divided by m_stateNum/3 and
    // m_stateNum*2/3, which is a division by zero for fewer than 3 states and
    // does not sum to 1 unless m_stateNum is a multiple of 3. For multiples
    // of 3 the result is unchanged.
    const int headCount = (m_stateNum + 2) / 3;     // states with i*3 < m_stateNum
    const int tailCount = m_stateNum - headCount;
    const double headMass = (tailCount > 0) ? 0.9 : 1.0;
    for (i = 0; i < m_stateNum; i++)
    {
        if (i < headCount)
            m_stateInit[i] = headMass / headCount;
        else
            m_stateInit[i] = 0.1 / tailCount;
    }
}
// Returns the log10 probability (see LogProb) of the transition from state i
// to state j, or -100 when either index is out of range.
//
// Everywhere else in this file m_stateTran is indexed with rows
// 0 .. m_stateNum-1 and columns 0 .. m_stateNum (the extra column being the
// final-state column), so j == m_stateNum is accepted while i == m_stateNum
// is rejected. The previous check let i == m_stateNum through, which read
// one row past the table.
double CHMM::getTransProb(int i,int j)
{
    if (i < 0 || i >= m_stateNum || j < 0 || j > m_stateNum)
        return -100;
    return LogProb(m_stateTran[i][j]);
}
/* SampleFile: <size><dim><seq_size><seq_data>...<seq_size>...
*/
// Debug helper: prints a binary sample file to cout as text.
//
// Output is the sequence count, then the feature dimension, then for every
// sequence its length followed by one line per frame holding that frame's
// values separated (and terminated) by a single space.
//
// fileName: binary file laid out as <size><dim>{<seq_size><seq_data>}...
void CHMM::DumpSampleFile(const char* fileName)
{
    ifstream sampleFile(fileName, ios_base::binary);
    _ASSERT(sampleFile);

    // Header: number of sequences, then feature dimension.
    int seqCount = 0;
    sampleFile.read((char*)&seqCount, sizeof(int));
    cout << seqCount << endl;

    int featureDim = 0;
    sampleFile.read((char*)&featureDim, sizeof(int));
    cout << featureDim << endl;

    // One frame buffer, reused for every frame in the file.
    double* frame = new double[featureDim];

    for (int seqIdx = 0; seqIdx < seqCount; seqIdx++)
    {
        int frameCount = 0;
        sampleFile.read((char*)&frameCount, sizeof(int));
        cout << frameCount << endl;

        for (int frameIdx = 0; frameIdx < frameCount; frameIdx++)
        {
            sampleFile.read((char*)frame, sizeof(double) * featureDim);

            int component = 0;
            while (component < featureDim)
            {
                cout << frame[component] << " ";
                component++;
            }
            cout << endl;
        }
    }

    sampleFile.close();
    delete[] frame;
}
// Converts a probability to its base-10 logarithm, floored at -20.
// Any p that is not strictly greater than 1e-20 (including zero, negative
// values and NaN) maps to the floor value instead of being passed to log10.
double CHMM::LogProb(double p)
{
    if (p > 1e-20)
    {
        return log10(p);
    }
    return -20;
}
// Writes the model as tagged text:
//   <CHMM>
//   <StateNum> n </StateNum>
//   ...each state model, streamed with its own operator<<...
//   <Init> p0 p1 ... </Init>
//   <Tran>
//   one line per state holding m_stateNum + 1 transition values
//   </Tran>
//   </CHMM>
// Returns the stream so calls can be chained.
ostream& operator<<(ostream& out, CHMM& hmm)
{
    out << "<CHMM>" << endl;
    out << "<StateNum> " << hmm.m_stateNum << " </StateNum>" << endl;

    // Per-state models
    for (int modelIdx = 0; modelIdx < hmm.m_stateNum; modelIdx++)
    {
        out << *hmm.m_stateModel[modelIdx];
    }

    // Initial probabilities, all on one line
    out << "<Init> ";
    for (int initIdx = 0; initIdx < hmm.m_stateNum; initIdx++)
    {
        out << hmm.m_stateInit[initIdx] << " ";
    }
    out << "</Init>" << endl;

    // Transition table: one row per source state; the last column is the
    // extra entry at index m_stateNum.
    out << "<Tran>" << endl;
    for (int from = 0; from < hmm.m_stateNum; from++)
    {
        int to = 0;
        while (to < hmm.m_stateNum + 1)
        {
            out << hmm.m_stateTran[from][to] << " ";
            to++;
        }
        out << endl;
    }
    out << "</Tran>" << endl;

    out << "</CHMM>" << endl;
    return out;
}
istream& operator>>(istream& in, CHMM& hmm)
{
char label[20];
int i,j,t;
in >> label;
_ASSERTE(strcmp(label, "<CHMM>") == 0);
hmm.Dispose();
in >> label >> hmm.m_stateNum >> label; // "<StateNum>"
hmm.Allocate(hmm.m_stateNum);
for ( i = 0; i < hmm.m_stateNum; i++)
{
in >> *hmm.m_stateModel[i];
}
in >> label; // "<Init>"
for ( i = 0; i < hmm.m_stateNum; i++)
{
in >> hmm.m_stateInit[i];
}
in >> label;
in >> label; // "<Tran>"
for ( i = 0; i < hmm.m_stateNum; i++)
{
for ( j = 0; j < hmm.m_stateNum + 1; j++)
{
in >> hmm.m_stateTran[i][j];
}
}
in >> label;
in >> label; // "</CHMM>"
return in;
}
// Converts a text sample file into the binary sample-file layout used by
// Init(), Train() and DumpSampleFile():
//   <size:int><dim:int>{<seq_size:int><seq_size * dim doubles>}...
//
// InputText:        path of the text file. It holds the same values separated
//                   by whitespace: sequence count, dimension, then for each
//                   sequence its length followed by its feature values.
// OutputBinaryText: path of the binary file to create.
//
// Each feature vector is written as dim raw doubles. The previous code
// emitted the values as formatted text and looped seq_size times over a
// dim-element buffer, which both broke the binary layout the readers expect
// and indexed out of bounds whenever seq_size exceeded dim.
void CHMM::TextTransform(CString InputText, CString OutputBinaryText)
{
    ifstream Input(InputText);
    ofstream Output(OutputBinaryText, ios_base::binary);

    int seq_num = 0; // number of sequences
    int dim = 0;     // feature dimension
    Input >> seq_num;
    Input >> dim;
    Output.write((char*)&seq_num, sizeof(int));
    Output.write((char*)&dim, sizeof(int));

    if (dim <= 0)
    {
        // Missing or invalid header: there is no feature data to convert.
        return;
    }

    // One feature vector, reused for every frame; released automatically.
    vector<double> feature;

    for (int i = 0; i < seq_num; i++)
    {
        int seq_size = 0; // number of feature vectors in this sequence
        Input >> seq_size;
        Output.write((char*)&seq_size, sizeof(int));

        for (int j = 0; j < seq_size; j++)
        {
            // Start from zeros so a short or damaged input never leaks the
            // previous frame's values into this one.
            feature.assign(dim, 0.0);
            for (int k = 0; k < dim; k++)
            {
                Input >> feature[k];
            }
            Output.write((char*)&feature[0], sizeof(double) * dim);
        }
    }
}
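// Code-viewer keyboard shortcuts (page residue, not part of chmm.cpp):
//   Copy code:           Ctrl + C
//   Search code:         Ctrl + F
//   Full-screen mode:    F11
//   Toggle theme:        Ctrl + Shift + D
//   Show shortcuts:      ?
//   Increase font size:  Ctrl + =
//   Decrease font size:  Ctrl + -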