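// Listing captured from a "Chongchong Download Station" resource page.
// File: pdfbthmt.cc (C/C++ source; page 1 of 3 of the web listing)
// Collection: contourlet transform source code for image fusion,
//             from the official MATLAB website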
/**
 * Draw one random tree from the hidden Markov tree model.
 *
 * The root coefficient is fixed to @p initval and its hidden state is chosen
 * deterministically by the two-state heuristic below. Every other node's
 * state is sampled (ranind) from the transition probabilities conditioned on
 * its parent's state. Its coefficient is then set to mean + stdv * rangas(idum),
 * using that state's parameters; rangas is presumably a standard normal
 * deviate — TODO confirm.
 *
 * @param aTree    Output tree; its per-level sizes determine how many nodes
 *                 are generated. aTree[0][0] and every aTree[J][i] with J >= 1
 *                 are written.
 * @param idum     Random-number generator state passed to ranind()/rangas().
 * @param initval  Value assigned to the root coefficient.
 *
 * Raises a MATLAB error (mexErrMsgTxt) unless the model has exactly 2 states.
 */
void THMT::generate_one(tree<double> &aTree, int& idum, double initval)
{
    // NOTE(review): `states` holds a single tree (one root), so this assumes
    // aTree also describes a single tree; confirm against the caller.
    tree<int> states(1, nCh, nLev);
    vector<double> vprob(M);

    aTree[0][0] = initval;

    // Root state selection.
    //
    // An earlier version (marked "WRONG" by the original author) picked the
    // state m maximizing model_trans[0][m][0] * compute_g(initval, mean, stdv).
    // It was replaced by the two-state heuristic below.
    if (M != 2)
      mexErrMsgTxt("Only works for 2 states");

    // The state whose root standard deviation is the smaller of the two.
    const int smallState =
        (model_stdv[0][0][0] < model_stdv[0][0][1]) ? 0 : 1;

    // Cumulative value: Psi(|initval - mean| / stdv) per state, weighted by
    // the root state probabilities model_trans[0][0][m][0].
    double cumprob = 0.0;
    for (int m = 0; m < M; m++)
	cumprob += model_trans[0][0][m][0] *
	    Psi(fabs(initval - model_mean[0][0][m]) / model_stdv[0][0][m]);

    if (cumprob > (0.5 + 0.5 * model_trans[0][0][smallState][0]))
	states[0][0] = smallState;
    else
	states[0][0] = 1 - smallState;	// largeState

    // All other levels: sample each node's state given its parent's state,
    // then draw the node's coefficient from that state's parameters.
    for (int J = 1; J < nLev; J++)
    {
	const int nNode = aTree[J].size();
	for (int i = 0; i < nNode; i++)
	{
	  const int mm = subbandtree[J][i];   // subband of node i on level J

	  // Distribution of this node's state given its parent's state; the
	  // parent of node i is node i/nCh on level J-1.
	  for (int m = 0; m < M; m++)
	    vprob[m] = model_trans[J][mm][m][states[J-1][i/nCh]];

	  states[J][i] = ranind(vprob, idum);
	  const double mean = model_mean[J][mm][states[J][i]];
	  const double stdv = model_stdv[J][mm][states[J][i]];
	  aTree[J][i] = mean + stdv * rangas(idum);
	}
    }
}


//-----------------------------------------------------------------------------
/**
 * Shrink the coefficients of @p source using externally supplied state
 * probabilities.
 *
 * Each coefficient is multiplied by
 *     sum_m p(m) * stdv_m^2 / (nvar + stdv_m^2),
 * where p(m) is the node's probability of state m (read from @p stateprob)
 * and stdv_m is the model standard deviation of state m in the node's
 * subband. Model means are not used here.
 *
 * @param nvar       Noise variance.
 * @param source     Tree(s) to denoise; copied, not modified.
 * @param stateprob  MATLAB cell array with one cell per level. Element
 *                   [m*nNode + i] of cell J is the probability of state m for
 *                   node i on level J.
 * @return           Newly allocated denoised tree, also stored in member `obs`.
 *
 * Raises a MATLAB error (mexErrMsgTxt) if a level's cell is missing.
 */
tree<double>* THMT::denoise(double nvar, tree<double>* source, 
			    const mxArray* stateprob)
{
  // NOTE(review): any tree previously held in `obs` is overwritten without
  // being freed here; confirm who owns it.
  obs = new tree<double> (*source);
  nObs = obs->nrt();

  // Load the per-node state probabilities, one cell per level.
  state_prob = tree< vector<double> >(nObs, nCh, nLev, 
				      vector<double>(M));
  for (int J = 0; J < nLev; J++) {
    const mxArray* stateprobcell = mxGetCell(stateprob, J);
    if (stateprobcell == NULL)   // mxGetPr(NULL) would crash MATLAB
      mexErrMsgTxt("denoise: state probability cell is empty");
    const double* doubleptr = mxGetPr(stateprobcell);
    const int nNode = state_prob[J].size();
    for (int i = 0; i < nNode; i++)
      for (int m = 0; m < M; m++)
	state_prob[J][i][m] = doubleptr[m*nNode+i];
  }

  // Apply the probability-weighted shrinkage factor to every coefficient.
  for (int J = 0; J < nLev; J++) {
    // ipow(nCh, J) = nodes per tree on level J (assuming one root per tree
    // and the trees stored one after another). It does not depend on i,
    // so it is computed once per level.
    const int nTreeNodes = ipow(nCh, J);
    const int nNode = state_prob[J].size();
    for (int i = 0; i < nNode; i++) {
      const int mm = subbandtree[J][i % nTreeNodes];
      double gain = 0.0;
      for (int m = 0; m < M; m++) {
	const double sd = model_stdv[J][mm][m];
	gain += state_prob[J][i][m]*sd*sd/(nvar+sd*sd);
      }
      (*obs)[J][i] = (*obs)[J][i]*gain;
    }
  }

  return obs;
}

//-----------------------------------------------------------------------------
/**
 * Append a plain-text dump of the model parameters to @p filename.
 *
 * Written in order:
 *   - the number of states, the number of levels and the zero-mean flag;
 *   - the root state probabilities;
 *   - the transition probabilities for levels 1..nLev-1;
 *   - the per-level means (only when the model is not zero-mean);
 *   - the per-level standard deviations.
 *
 * Raises a MATLAB error (mexErrMsgTxt) if the file cannot be opened.
 */
void THMT::dump_model(char *filename)
{
    FILE *fp = fopen(filename,"a");
    if (!fp) 
      mexErrMsgTxt("ERROR: can not open for writing");

    fprintf(fp, "nStates: %d\n", M);

    fprintf(fp, "nLevels: %d\n", nLev);

    if (zeromean)
	fprintf(fp, "zeroMean: yes\n");
    else 
	fprintf(fp, "zeroMean: no\n");

    // Root (initial) state probabilities.
    fprintf(fp, "\n");
    for (int m = 0; m < M; m++)
	fprintf(fp, "%f ", model_trans[0][0][m][0]);
    fprintf(fp, "\n\n");

    // Transition probabilities. For each level J >= 1, one line per child
    // state m lists model_trans[J][mm][m][n] over subbands mm and parent
    // states n.
    for (int J = 1; J < nLev; J++)
    { 
      for (int m = 0; m < M; m++) {
	for (int mm = 0; mm < model_trans[J].size(); mm++) 
	    for (int n = 0; n < M; n++)
	      fprintf(fp, "%f ", model_trans[J][mm][m][n]);
	fprintf(fp, "\n");
      }
      fprintf(fp, "\n");
    }
    fprintf(fp, "\n");

    // Means: one line per level, omitted for zero-mean models.
    if (!zeromean) {
	for (int J = 0; J < nLev; J++)
	{
	  for (int mm = 0; mm < model_mean[J].size(); mm++)
	    for (int m = 0; m < M; m++)
	      fprintf(fp, "%f ", model_mean[J][mm][m]);

	  fprintf(fp, "\n");
	}
	fprintf(fp, "\n");
    }

    // Standard deviations: one line per level.
    for (int J = 0; J < nLev; J++)
    {
      for (int mm = 0; mm < model_stdv[J].size(); mm++)
	for (int m = 0; m < M; m++) 
	    fprintf(fp, "%f ", model_stdv[J][mm][m]);

      fprintf(fp, "\n");
    }
    fprintf(fp, "\n");

    fclose(fp);
}

//-----------------------------------------------------------------------------
void THMT::dump_model_struct(const mxArray* model)
{
    double* doubleptr;
    mxArray* pointer, *pointer2, *pointer3, *assignptr;
    int numfields, numels;
    register int J, MM, m, mm, n;

    numfields = mxGetNumberOfFields(model);
    numels = mxGetNumberOfElements(model);
    if (((numfields != 6) && (zeromean)) || 
	((numfields != 7) && (!zeromean)))
      mexErrMsgTxt("ERROR: number of fields in struct model is incorrect");

    if (numels != 1)
      mexErrMsgTxt("ERROR: Too many elements");

    if (strcmp(mxGetFieldNameByNumber(model, 0), "nstates") != 0)
      mexErrMsgTxt("Field 0 has wrong name");

    pointer = mxGetFieldByNumber(model, 0, 0);
    doubleptr = mxGetPr(pointer);
    *doubleptr = (double)M;

    if (strcmp(mxGetFieldNameByNumber(model, 1), "nlevels") != 0)
      mexErrMsgTxt("Field 1 has wrong name");    

    pointer = mxGetFieldByNumber(model, 0, 1);
    doubleptr = mxGetPr(pointer);
    *doubleptr = (double)nLev;

    if (strcmp(mxGetFieldNameByNumber(model, 2), "zeromean") != 0)
      mexErrMsgTxt("Field 2 has wrong name");
    if (zeromean){
        assignptr = mxCreateString("yes");
        mxSetFieldByNumber((mxArray*)model, 0, 2, assignptr);
    }
    else {
        assignptr = mxCreateString("no");
        mxSetFieldByNumber((mxArray*)model, 0, 2, assignptr);
    }

    if (strcmp(mxGetFieldNameByNumber(model, 3), "rootprob") != 0)
      mexErrMsgTxt("Field 3 has wrong name");
    pointer = mxGetFieldByNumber(model, 0, 3);
    if (pointer == NULL) {
      mexPrintf("%s%d\n",
    		"FIELD:", 3);
      mexErrMsgTxt("Above field is empty!"); 
    }
    doubleptr = mxGetPr(pointer);
    for (m = 0; m < M; m++)
        doubleptr[m] = model_trans[0][0][m][0];
 
    if (strcmp(mxGetFieldNameByNumber(model, 4), "transprob") != 0)
      mexErrMsgTxt("Field 4 has wrong name");
    pointer = mxGetFieldByNumber(model, 0, 4);
    if (pointer == NULL) {
      mexPrintf("%s%d\n",
    		"FIELD:", 4);
      mexErrMsgTxt("Above field is empty!"); 
    }

    // Trans probs
    for (J =  1; J < nLev; J++)
    { 
      pointer2 = mxGetCell(pointer,J-1);
      for (m = 0; m < M; m++){
	for (mm = 0; mm < model_trans[J].size(); mm++) {
	    pointer3 = mxGetCell(pointer2, mm);
	    doubleptr = mxGetPr(pointer3);
	    for (n = 0; n < M; n++)
	      doubleptr[n*M+m] = model_trans[J][mm][m][n];
	}
      }
    }

    // Mean
    if (!zeromean)
    {
        if (strcmp(mxGetFieldNameByNumber(model, 5), "mean") != 0)
	  mexErrMsgTxt("Field 5 has wrong name");
        pointer = mxGetFieldByNumber(model, 0, 5);
	if (pointer == NULL) {
	  mexPrintf("%s%d\n",
		    "FIELD:", 5);
	  mexErrMsgTxt("Above field is empty!");
	}
	for (J = 0; J < nLev; J++)
	{
	  pointer2 = mxGetCell(pointer, J);
	  for (mm=0; mm < model_mean[J].size(); mm++){
	    pointer3 = mxGetCell(pointer2, mm);
	    doubleptr = mxGetPr(pointer3);
	    for (m=0; m<M; m++)
	      doubleptr[m] = model_mean[J][mm][m];
	  }
	}
    }

    if(!zeromean){
      if (strcmp(mxGetFieldNameByNumber(model, 6), "stdv") != 0)
	mexErrMsgTxt("Field 6 has wrong name");
      pointer = mxGetFieldByNumber(model, 0, 6);
      if (pointer == NULL) {
	mexPrintf("%s%d\n",
		  "FIELD:", 6);
	mexErrMsgTxt("Above field is empty!");
      }
    }
    else{
      if (strcmp(mxGetFieldNameByNumber(model, 5), "stdv") != 0)
	mexErrMsgTxt("Field 5 has wrong name");
      pointer = mxGetFieldByNumber(model, 0, 5);
      if (pointer == NULL) {
	mexPrintf("%s%d\n",
		  "FIELD:", 5);
	mexErrMsgTxt("Above field is empty!");
      } 
    }

    // Standard deviation
    for (J =  0; J < nLev; J++)
    {
      pointer2 = mxGetCell(pointer, J);
      for (mm=0; mm < model_stdv[J].size(); mm++){
	pointer3 = mxGetCell(pointer2, mm);
	doubleptr = mxGetPr(pointer3);
	for (m = 0; m < M; m++) 
	    doubleptr[m] = model_stdv[J][mm][m];
      }
    }
}

//-----------------------------------------------------------------------------
/**
 * Write the per-node state probabilities into a MATLAB cell array.
 *
 * For every level, the array held in cell `level` of @p stateprob receives
 * state_prob[level][node][s] at linear index s*nodes + node. This is the
 * column-major layout of a nodes-by-M matrix. The arrays must already be
 * allocated with enough room — TODO confirm the caller guarantees this.
 */
void THMT::dump_state_prob(const mxArray* stateprob)
{
    for (int level = 0; level < nLev; ++level) {
      double* out = mxGetPr(mxGetCell(stateprob, level));
      const int nodes = state_prob[level].size();

      for (int node = 0; node < nodes; ++node) {
	for (int s = 0; s < M; ++s) {
	  out[s * nodes + node] = state_prob[level][node][s];
	}
      }
    }
}
//-----------------------------------------------------------------------------
/**
 * Sample-based Kullback-Leibler estimate between two THMT models.
 *
 * Draws @p nObservations trees from @p model1 via generate_data(). It then
 * returns model1.batch_test(samples) - model2.batch_test(samples).
 * NOTE(review): this is a proper KL estimate only if batch_test returns an
 * average log-likelihood per observation — confirm.
 *
 * Both models are taken by value, so the callers' models are not modified.
 */
double KLD_est(THMT model1, THMT model2, int nObservations)
{
    tree<double> samples;
    model1.generate_data(samples, nObservations);

    // Evaluate sequentially: model1 first, then model2.
    const double loglik_model1 = model1.batch_test(&samples);
    const double loglik_model2 = model2.batch_test(&samples);

    return loglik_model1 - loglik_model2;
}


//-----------------------------------------------------------------------------
// Compute the Kullback-Leibler divergence D(prob1 || prob2) between two
// discrete probability mass functions:
//     sum_i prob1[i] * log(prob1[i] / prob2[i]).
// Terms with prob1[i] == 0 contribute 0 (convention 0 * log 0 = 0).
// Terms with prob2[i] == 0 are also skipped, so the result stays finite
// rather than +infinity.
// If the vectors differ in length, a warning is printed to cerr and only
// their common prefix is used.
double KLD_disc(const vector<double>& prob1, const vector<double>& prob2)
{
    vector<double>::size_type len = prob1.size();
    if (prob1.size() != prob2.size())
    {
	cerr << "KLD_disc: Two probability vectors have different length"
	     << endl;
	// Restrict to the common prefix so a shorter prob2 is never read
	// out of bounds.
	if (prob2.size() < len)
	    len = prob2.size();
    }

    double d = 0.0;

    for (vector<double>::size_type i = 0; i < len; i++)
	// Guard each term on its own prob2[i]. Testing prob2[0] here would
	// drop every term whenever prob2[0] == 0, and would let log(x / 0)
	// through for other zero entries.
	if ((prob1[i] != 0.0) && (prob2[i] != 0.0))
	    d += prob1[i] * log(prob1[i] / prob2[i]);

    return d;
}


//-----------------------------------------------------------------------------
// Kullback-Leibler divergence D(N1 || N2) between two univariate Gaussian
// densities N1 = N(mean1, stdv1^2) and N2 = N(mean2, stdv2^2):
//     0.5 * (-2*log(stdv1/stdv2) - 1 + (stdv1/stdv2)^2
//            + ((mean1 - mean2)/stdv2)^2).
double KLD_gauss(double mean1, double stdv1, double mean2, double stdv2)
{
    const double ratio = stdv1 / stdv2;            // sigma1 / sigma2
    const double shift = (mean1 - mean2) / stdv2;  // standardized mean gap

    const double twice_kld = -2 * log(ratio) - 1 + ratio * ratio + shift * shift;
    return 0.5 * twice_kld;
}


//-----------------------------------------------------------------------------
/**
 * Bound (an upper bound, per the function name) on the Kullback-Leibler
 * divergence D(model1 || model2) between two THMT models.
 *
 * The models must have the same number of states, children per node and
 * levels. The computation runs bottom-up: D[dir][m] is the bound for a
 * subtree whose root lies in subband `dir` and is in state m.
 *   - Finest level: the Gaussian KLD of the node's parameters.
 *   - Each coarser level: the node's own Gaussian KLD, plus for each child
 *     subband the child-count-weighted KLD of the child-state distributions
 *     given state m, plus the child-count-weighted, model1-transition-weighted
 *     child bounds.
 *   - Final result: the KLD of the root state distributions plus the
 *     root-level bounds weighted by model1's root state probabilities.
 *
 * Raises a MATLAB error (mexErrMsgTxt) if the models are incompatible, or if
 * a level has neither the same number of directions as its parent level nor
 * twice as many.
 *
 * NOTE(review): the models are taken by value (copied); kept for interface
 * compatibility.
 */
double KLD_upb(THMT model1, THMT model2)
{
    int J, m, n, dir, maxdir;
    int M, nCh, nLev;

    M = model1.M;
    if (model2.M != M)
	mexErrMsgTxt("KLD_upb: Incompatible models.");

    nCh = model1.nCh;
    if (model2.nCh != nCh)
	mexErrMsgTxt("KLD_upb: Incompatible models.");

    nLev = model1.nLev;
    if (model2.nLev != nLev)
	mexErrMsgTxt("KLD_upb: Incompatible models.");

    // Assume the finest level has the largest number of directions.
    maxdir = model1.model_stdv[nLev-1].size();

    vector<vector<double> > D(maxdir, vector<double>(M));  // child-level bounds
    vector<vector<double> > d(maxdir, vector<double>(M));  // parent-level bounds
    vector<double> trans1(M);
    vector<double> trans2(M);

    // Initial: finest level, Gaussian term only.
    J = nLev - 1;
    for (dir = 0; dir < model1.model_stdv[J].size(); dir++) {
      for (m = 0; m < M; m++)
	D[dir][m] = KLD_gauss(model1.model_mean[J][dir][m], 
			      model1.model_stdv[J][dir][m],
			      model2.model_mean[J][dir][m], 
			      model2.model_stdv[J][dir][m]);
    }

    // Induction: combine level J (children) into level J-1 (parents).
    for (J = nLev-1; J > 0; J--) {
      const int nParentDirs = model1.model_stdv[J-1].size();
      const int nChildDirs = model1.model_stdv[J].size();

      // Either every parent subband has one child subband (same count), or
      // two child subbands 2*dir and 2*dir+1 (count doubles). This does not
      // depend on dir or m, so it is decided once per level.
      const bool sameDirs = (nChildDirs == nParentDirs);
      if (!sameDirs && nChildDirs != 2*nParentDirs)
	mexErrMsgTxt("Error: Multiple parents for one child");

      for (dir = 0; dir < nParentDirs; dir++) {
	for (m = 0; m < M; m++) {

	  // The parent node's own Gaussian term.
	  d[dir][m] = KLD_gauss(model1.model_mean[J-1][dir][m], 
				model1.model_stdv[J-1][dir][m],
				model2.model_mean[J-1][dir][m], 
				model2.model_stdv[J-1][dir][m]);

	  // KLD of the child-state distributions given parent state m.
	  if (sameDirs) {
	    for (n = 0; n < M; n++) {
	      trans1[n] = model1.model_trans[J][dir][n][m];
	      trans2[n] = model2.model_trans[J][dir][n][m];
	    }
	    d[dir][m] += nCh * KLD_disc(trans1, trans2);
	  }
	  else {
	    // nCh/2 is integer division: assumes nCh is even, with the children
	    // split evenly between the two child subbands.
	    for (n = 0; n < M; n++) {
	      trans1[n] = model1.model_trans[J][dir*2][n][m];
	      trans2[n] = model2.model_trans[J][dir*2][n][m];
	    }
	    d[dir][m] += nCh/2 * KLD_disc(trans1, trans2);

	    for (n = 0; n < M; n++) {
	      trans1[n] = model1.model_trans[J][dir*2+1][n][m];
	      trans2[n] = model2.model_trans[J][dir*2+1][n][m];
	    }
	    d[dir][m] += nCh/2 * KLD_disc(trans1, trans2);
	  }

	  // Child bounds, weighted by model1's transition probabilities.
	  for (n = 0; n < M; n++) {
	    if (sameDirs)
	      d[dir][m] += nCh * model1.model_trans[J][dir][n][m] 
		* D[dir][n]; 
	    else {
	      d[dir][m] += nCh/2 * model1.model_trans[J][dir*2][n][m] 
		* D[dir*2][n];
	      d[dir][m] += nCh/2 * model1.model_trans[J][dir*2+1][n][m] 
		* D[dir*2+1][n];
	    }
	  }
	}
      }

      // The parent-level bounds become the child-level bounds for the next pass.
      D = d;
    }

    // Final: root state distributions plus the weighted root-level bounds.
    for (n = 0; n < M; n++) {
      trans1[n] = model1.model_trans[0][0][n][0];
      trans2[n] = model2.model_trans[0][0][n][0];
    }

    double dist = KLD_disc(trans1, trans2);

    for (m = 0; m < M; m++)
	dist += model1.model_trans[0][0][m][0] * D[0][m];

    return dist;
}

/*********************************************************************
    TEMPLATE INSTANTIATIONS (currently disabled)
*************************************************************/
/*template class matrix<float>;

template class vector<double>;
template class tree<double>;
template class tree<int>;
template class tree<vector<double> >; */

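/* Web-page footer captured with this listing ("Keyboard shortcuts"):
   Copy code: Ctrl + C | Search code: Ctrl + F | Full screen: F11 |
   Toggle theme: Ctrl + Shift + D | Show shortcuts: ? |
   Increase font size: Ctrl + = | Decrease font size: Ctrl + - */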