⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 fams.cpp

📁 mean-shift算法的实现算例
💻 CPP
📖 第 1 页 / 共 3 页
字号:
   int nrel=1;
   for (i=maxm-2; i>=0; i--)
   {
      if (stemp[i]>=npmin)
         nrel++;
      else
         break;
   }
   if (nrel > FAMS_PRUNE_MAXM)
      nrel = FAMS_PRUNE_MAXM;

   // rearange only relevant modes
   mcount2 = new int[nrel];
   cmodes2 = new float[d_*nrel];

   for (i=0; i<nrel; i++)
   {
      cm = istemp[maxm-i-1]; // index
      mcount2[i] = mcount[cm];
      memcpy(cmodes2+i*d_, cmodes+cm*d_, d_*sizeof(float));
      //bgLog("1: %g %g %d\n",cmodes2[i*d_+0],cmodes2[i*d_+d_-1], mcount2[i]);
   }

   delete [] cmodes;
   memset(mcount, 0, nsel_*sizeof(int));
   mcount[0]=1;
   for (i=1; i<nsel_; i+=jm)
      mcount[i] = 1;

   maxm = nrel;

   myPt = nsel_/10;
   for (cm = 1; cm<nsel_; cm++)
   {
      if((cm%myPt)==0)
         bgLog(".");

      if (mcount[cm])
         continue;

      pmodes = modes_+cm*d_;

      // compute closest mode
      cminDist = d_*1e7;
      iminDist = -1;
      for (cref = 0; cref<maxm; cref++)
      {
         cdist = 0;
         ctmodes = cmodes2+cref*d_;
         for (cd=0; cd<d_; cd++)
            cdist += fabs(ctmodes[cd]/mcount2[cref] - pmodes[cd]);
         if (cdist<cminDist)
         {
            cminDist = cdist;
            iminDist = cref;
         }
      }
      // join
      hprune = hmodes_[cm] >> FAMS_PRUNE_HDIV;
      if (cminDist < hprune)
      {
         // aready in, just add
         for (cd=0; cd<d_; cd++)
         {
            cmodes2[iminDist*d_+cd] += pmodes[cd];
         }
         mcount2[iminDist] += 1;
      } else
      {
         // new mode, but discard in second pass
      }
   }

   // put the modes in the order of importance (count)
   for (i=0; i<maxm; i++)
   {
      stemp[i] = mcount2[i];
      istemp[i] = i;
   }
   bgISort(stemp, maxm, istemp); // increasing

   // find number of relevant modes
   nrel=1;
   for (i=maxm-2; i>=0; i--)
   {
      if (stemp[i]>=npmin)
         nrel++;
      else
         break;
   }

   CleanPrunedModes();
   prunedmodes_ = new unsigned short[d_*nrel];
   nprunedmodes_ = new int[nrel];
   unsigned short* cpm;
   npm_ = nrel;

   cpm = prunedmodes_;
   for (i=0; i<npm_; i++)
   {
      nprunedmodes_[i] = stemp[maxm-i-1];
      cm = istemp[maxm-i-1];
      for (cd=0; cd<d_; cd++)
      {
         *(cpm++) = (unsigned short) (cmodes2[cm*d_+cd]/mcount2[cm]);
      }
      //bgLog("2: %d %d\n",prunedmodes_[i*d_+0],prunedmodes_[i*d_+d_-1]);
   }


   delete [] istemp;
   delete [] stemp;

   delete [] cmodes2;
   delete [] mcount2;
   delete [] mcount;


   bgLog("done\n");
   return 1;
}

// Loads per-point bandwidths from a text file in the format written by
// SaveBandwidths: the point count, followed by one bandwidth per point in
// data units. Each value is rescaled by 65535/(maxVal_-minVal_) and stored
// in points_[i].window_.
//
// Returns 1 on success and 0 on failure. Failure means the file cannot be
// opened, the count does not match n_, or a value is missing, malformed or
// negative. On failure no window_ is modified.
int FAMS::LoadBandwidths(char* fn)
{
   FILE* fd = fopen(fn, "rb");
   if (fd == NULL)
      return 0;

   // fscanf leaves n untouched when it fails, so its result must be checked
   // before n is compared.
   int n = 0;
   if ((fscanf(fd, "%d", &n) != 1) || (n != n_))
   {
      fclose(fd);
      return 0;
   }

   // Parse every value first, so a truncated or corrupt file cannot leave
   // the points with a mix of old and new bandwidths.
   float* bws = new float[n_];
   int i;
   for (i=0; i<n_; i++)
   {
      // !(x >= 0) also rejects NaN; converting a negative value to unsigned
      // below would be undefined behavior.
      if ((fscanf(fd, "%g", &bws[i]) != 1) || !(bws[i] >= 0))
         break;
   }
   fclose(fd);
   if (i < n_)
   {
      delete [] bws;
      return 0;
   }

   float deltaVal = maxVal_-minVal_;
   for (i=0; i<n_; i++)
      points_[i].window_ = (unsigned int) (65535.0*(bws[i])/deltaVal);
   delete [] bws;
   return 1;
}

// Writes the current per-point bandwidths to a text file.
// The first line is the point count. Each following line is one bandwidth,
// converted back from the window_ scale into data units as
// window_ * (maxVal_ - minVal_) / 65535.
// Does nothing if the file cannot be created; write errors are not reported.
void FAMS::SaveBandwidths(char* fn)
{
   FILE* out = fopen(fn, "wb");
   if (out == NULL)
      return;

   fprintf(out, "%d\n", n_);
   const float range = maxVal_ - minVal_;
   for (int idx = 0; idx < n_; ++idx)
   {
      const float bandwidth = (float) (points_[idx].window_ * range / 65535.0);
      fprintf(out, "%g\n", bandwidth);
   }
   fclose(out);
}


// main function to find K and L
//
// Searches for LSH parameters: K is the number of cut results hashed per
// partition, and L is the number of partitions.
//
// Phase 1: for each K from Kmax down to Kmin (step Kjump), call
// DoFindKLIteration FAMS_FKL_TIMES times with L = Lcrt. For every L, keep
// the maximum score over the runs. The smallest L whose score is below
// 1 + epsilon becomes that K's candidate.
//
// Phase 2: time each candidate pair FAMS_FKL_TIMES times and keep the pair
// with the smallest middle value after bgSort (the median, assuming bgSort
// sorts ascending).
//
// width > 0 selects a fixed bandwidth, scaled like in RunFAMS; otherwise 0
// is passed to ComputeRealBandwidths. k is stored in k_.
//
// Returns 0 and writes the chosen pair into K and L on success. Returns 1,
// leaving K and L untouched, if no points are loaded, the search range is
// invalid, or no pair reaches the error bound.
int FAMS::FindKL(int Kmin, int Kmax, int Kjump, int Lmax, int k, float width, float epsilon, int &K, int &L)
{
   // A non-positive step would never terminate the K loop below.
   if (Kjump < 1)
      Kjump = 1;

   bgLog("Find optimal K and L, K=%d:%d:%d, Lmax=%d, k=%d, Err=%.2g\n", Kmin, Kjump, Kmax, Lmax,
      k, epsilon);

   if (hasPoints_==0)
   {
      bgLog("Load points first\n");
      return 1;
   }

   // The cut_res buffers used while hashing hold FAMS_MAX_K entries, so
   // reject K ranges that could overflow them, as well as an empty L range.
   if ((Kmin < 1) || (Kmax > FAMS_MAX_K) || (Lmax < 1))
   {
      bgLog("Invalid search range (need Kmin >= 1, Kmax <= %d, Lmax >= 1)\n", (int) FAMS_MAX_K);
      return 1;
   }

   int hWidth=0;
   if (width>0)
      hWidth = (int) (65535.0*(width)/(maxVal_-minVal_));
   k_=k;
   epsilon += 1;

   // select points on which test is run
   SelectMsPoints(FAMS_FKL_NEL*100.0/n_, 0);

   // compute bandwidths for selected points
   ComputeRealBandwidths(hWidth);

   // scores holds FAMS_FKL_TIMES rows of Lcrt entries each, so Lcrt must
   // never exceed FAMS_MAX_L. It also never needs to exceed Lmax.
   float scores[FAMS_FKL_TIMES*FAMS_MAX_L];
   const int Lcap = (Lmax < FAMS_MAX_L) ? Lmax : FAMS_MAX_L;
   int Lcrt = Lcap;
   int Kcrt;

   int nBest;
   int LBest[FAMS_MAX_K];
   int KBest[FAMS_MAX_K];

   int ntimes, is;
   bgLog(" find valid pairs");
   // nBest is bounded so KBest/LBest can never be written out of range.
   for (Kcrt = Kmax, nBest=0; (Kcrt >= Kmin) && (nBest < FAMS_MAX_K); Kcrt -= Kjump, nBest++)
   {
      // do iterations for crt K and L = 1...Lcrt
      for (ntimes=0; ntimes<FAMS_FKL_TIMES; ntimes++)
         DoFindKLIteration(Kcrt, Lcrt, &scores[ntimes*Lcrt]);

      // get correct for this k
      KBest[nBest]=Kcrt;
      LBest[nBest]=-1;
      for (is=0; (LBest[nBest]==-1) && (is<Lcrt); is++)
      {
         // find max on this column
         for (ntimes=1; ntimes<FAMS_FKL_TIMES; ntimes++)
         {
            if (scores[is]<scores[ntimes*Lcrt+is])
               scores[is] = scores[ntimes*Lcrt+is];
         }
         if (scores[is]<epsilon)
            LBest[nBest]=is+1;
      }
      bgLog(".");

      // Update Lcrt to reduce running time. LBest+2 can exceed the
      // previous Lcrt, so clamp it to stay within Lmax and the scores buffer.
      if (LBest[nBest]>0)
      {
         Lcrt = LBest[nBest]+2;
         if (Lcrt > Lcap)
            Lcrt = Lcap;
      }
   }
   bgLog("done\n");

   // start finding the pair with best running time
   double run_times[FAMS_FKL_TIMES];
   int iBest = -1;            // -1 means no valid pair has been timed yet
   double timeBest = -1;
   int i;
   bgLog(" select best pair\n");
   for (i=0; i<nBest; i++)
   {
      if (LBest[i]<=0)
         continue;
      // The scores are not used here. Each run writes LBest[i] <= Lcap
      // entries, so the start of the buffer is always large enough.
      for (ntimes=0; ntimes<FAMS_FKL_TIMES; ntimes++)
         run_times[ntimes] = DoFindKLIteration(KBest[i], LBest[i], scores);
      bgSort(run_times, FAMS_FKL_TIMES);
      if ((iBest == -1) || (timeBest > run_times[FAMS_FKL_TIMES/2]))
      {
         iBest = i;
         timeBest = run_times[FAMS_FKL_TIMES/2];
      }
      bgLog("  K=%d L=%d time: %g\n", KBest[i], LBest[i], run_times[FAMS_FKL_TIMES/2]);
   }
   bgLog("done\n");

   if (iBest == -1)
   {
      bgLog("No valid (K,L) pair found, try a larger Lmax or epsilon\n");
      return 1;
   }
   K = KBest[iBest];
   L = LBest[iBest];

   return 0;
}


double FAMS::DoFindKLIteration(int K,int L, float* scores)
{
   K_ = K;
   L_ = L;
   int i, j;

   // Allocate memory for the hash table
   M_ = GetPrime(3*n_*L_/(Bs));
   block *HT = new block[M_];
   int   *hs = new int[M_];
   InitHash(K_+L_);
   
   memset(hs,0,sizeof(int)*M_);

   // Build partitions
   fams_partition *cuts = new fams_partition[L_];
   for(i=0; i<20; i++)
      rand();
   int cut_res[FAMS_MAX_K];
   MakeCuts(cuts);

   //Insert data into partitions
   for(j=0; j<n_; j++)
   { 
      for(i=0; i<L_; i++)
      {
         EvalCutRes(points_[j],cuts[i],cut_res);
         int hjump;
         int m = HashFunction(cut_res,i,K_,M_,&hjump);
         int m2 = HashFunction(&cut_res[1],i,K_-1);
         AddDataToHash(HT,hs,points_[j],m,Bs,M_,i,m2,hjump);
      }
   }
   
   //Compute Scores
   timer_start();
   ComputeScores(HT, hs, cuts, scores);
   double run_time=timer_elapsed(0);

   // clean
   delete [] cuts;
   delete [] hs;
   delete [] HT;

   return run_time;
}

// main function to run FAMS
int FAMS::RunFAMS(int K, int L, int k, double percent, int jump, float width, char* pilot_fn)
{

   bgLog("Running FAMS with K=%d L=%d\n", K, L);
   if (hasPoints_==0)
   {
      bgLog("Load points first\n");
      return 1;
   }
   int i, j;

   int adaptive = 1;
   int hWidth;
   if (width>0)
   {
      adaptive = 0;
      hWidth = (int) (65535.0*(width)/(maxVal_-minVal_));
   }

   K_=K; L_=L; k_=k;
   SelectMsPoints(percent, jump);

   // Allocate memory for the hash table
   M_ = GetPrime(3*n_*L_/(Bs));
   M2_ = GetPrime((int)(nsel_*20*3/Bs2));
   block *HT = new block[M_];
   int   *hs = new int[M_];
   block2 *HT2 = new block2[M2_];
   int   * hs2 = new int[M2_];
   
   memset(hs,0,sizeof(int)*M_);

   // Build partitions
   fams_partition *cuts = new fams_partition[L];
   for(i=0; i<20; i++)
      rand();
   int cut_res[FAMS_MAX_K];
   MakeCuts(cuts);

   InitHash(K_+L_);

   //Insert data into partitions
   for(j=0; j<n_; j++)
   { 
      for(i=0; i<L_; i++)
      {
         EvalCutRes(points_[j],cuts[i],cut_res);
         int hjump;
         int m = HashFunction(cut_res,i,K_,M_,&hjump);
         int m2 = HashFunction(&cut_res[1],i,K_-1);
         AddDataToHash(HT,hs,points_[j],m,Bs,M_,i,m2,hjump);
      }
   }	

   //Compute pilot if necessary
   bgLog(" Run pilot ");
   if(adaptive)
   {
      bgLog("adaptive...");
      ComputePilot(HT,hs,cuts,pilot_fn);
   }
   else
   {
      bgLog("fixed bandwith...");
      unsigned int hwd = (unsigned int)(hWidth*d_);
      for(i=0; i<n_; i++)
      {
         points_[i].window_=hwd;
         points_[i].weightdp2_=1;
      }
   }
   bgLog("done.\n");

   DoFAMS(HT, hs, cuts, HT2, hs2);

   // join modes
   PruneModes(FAMS_PRUNE_WINDOW, FAMS_PRUNE_MINN);

   // clean
   delete [] cuts;
   delete [] hs2;
   delete [] HT2;
   delete [] hs;
   delete [] HT;

   return 0;
}

// Prints the command-line synopsis accepted by main().
// main() parses -j, -p, -h and -f independently, so any combination is
// allowed; -j and -p both feed the point selection. K or L <= 0 disables
// LSH. With -f, K and L act as the upper bounds (Kmax, Lmax) of the search.
// input_dir is prepended verbatim to the file names, so it must end with a
// path separator.
void usage()
{
   bgLog("usage: K L k file_name input_dir [-j jump | -p percent] [-h width] [-f epsilon Kmin Kjump]\n"
         "   K or L <= 0 disables LSH; with -f, K and L are the upper bounds of the search\n"
         "   input_dir must end with a path separator; input is read from input_dir/file_name.txt\n");
}


// Returns 1 if an snprintf result `len` fits, untruncated, in a buffer of
// `size` bytes. Otherwise logs an error and returns 0.
static int FamsNameFits(int len, size_t size)
{
   if ((len >= 0) && (static_cast<size_t>(len) < size))
      return 1;
   bgLog("File name too long (max %d characters)\n", static_cast<int>(size) - 1);
   return 0;
}

// Exits with the usage message unless option argv[i] is followed by at
// least `count` more arguments (its values).
static void FamsRequireValues(int i, int count, int argc, char** argv)
{
   if (i + count >= argc)
   {
      bgLog("Missing value for param %s\n", argv[i]);
      usage();
      exit(1);
   }
}

// Command-line driver (see usage()).
// Loads <input_dir><file_name>.txt and, with -f, searches for K and L and
// asks for confirmation. It then runs FAMS, using
// <input_dir>pilot_<k>_<file_name>.txt as the pilot file, and saves
// out_<file_name>.txt and modes_<file_name>.txt into input_dir.
// Returns 0 on success (or when the user declines the found pair), 1 on error.
int main(int argc,char** argv)
{
   if(argc < 6)
   {
      usage();
      exit(1);
   }
   int no_lsh, find_kl;
   int K, L, k_neigh;
   char *data_file_name, *input_path;
   char fdata_file_name[200];
   K = atoi(argv[1]);
   L = atoi(argv[2]);
   no_lsh = (K<=0) || (L <= 0);
   find_kl = 0;

   // Only meaningful with -f; initialized so they are never read indeterminate.
   float epsilon = 0;
   int Kmin = 0, Kjump = 1, Kmax = K;
   int Lmax = L;

   float width = -1;

   k_neigh = atoi(argv[3]);
   data_file_name = argv[4];
   input_path = argv[5];
   // snprintf + length check: long paths must not overflow fdata_file_name
   if (!FamsNameFits(snprintf(fdata_file_name, sizeof(fdata_file_name), "%s%s.txt",
         input_path, data_file_name), sizeof(fdata_file_name)))
      return 1;
   int jump=1;
   double percent=0.0;
   for (int i=6; i<argc; i++)
   {
      if (argv[i][0] != '-')
      {
         bgLog("Error in param %s\n", argv[i]);
         usage();
         exit(1);
      }
      // Each option checks that its values exist before reading them;
      // otherwise argv[argc] (NULL) would be passed to atoi/atof.
      switch(argv[i][1])
      {
      case 'j':
         FamsRequireValues(i, 1, argc, argv);
         jump = atoi(argv[++i]);
         if (jump<1) jump=1;
         break;
      case 'p':
         FamsRequireValues(i, 1, argc, argv);
         percent = atof(argv[++i]);
         if ((percent<0) || (percent>1)) percent = 0;
         break;
      case 'h':
         FamsRequireValues(i, 1, argc, argv);
         width = (float) atof(argv[++i]);
         break;
      case 'f':
         FamsRequireValues(i, 3, argc, argv);
         epsilon = (float) atof(argv[++i]);
         Kmin = atoi(argv[++i]);
         Kjump = atoi(argv[++i]);
         Lmax = L; Kmax = K;
         find_kl=1;
         break;
      default:
         bgLog("Error in param %s\n", argv[i]);
         usage();
         exit(1);
         break;
      }
   }

   // load points
   FAMS cfams(no_lsh);
   if (cfams.LoadPoints(fdata_file_name))
      return 1;

   // find K L (if necessary)
   if (find_kl)
   {
      if (cfams.FindKL(Kmin, Kmax, Kjump, Lmax, k_neigh, width, epsilon, K, L) != 0)
      {
         bgLog("Could not find K and L\n");
         return 1;
      }
      bgLog("Found K = %d L = %d (write them down)\n", K, L);
      for (;;)
      {
         bgLog("Do you want to run FAMS with this (K=%d,L=%d) pair? (y/n)",K,L);
         int ch = getchar();
         // Consume the rest of the line so each prompt reads one answer.
         int rest = ch;
         while ((rest != '\n') && (rest != EOF))
            rest = getchar();
         // On EOF no answer can ever arrive: treat it as "no" instead of
         // prompting forever.
         if ((ch == EOF) || (ch == 'n') || (ch == 'N'))
            return 0;
         if ((ch == 'y') || (ch == 'Y'))
            break;
      }
   }
   if (!FamsNameFits(snprintf(fdata_file_name, sizeof(fdata_file_name), "%spilot_%d_%s.txt",
         input_path, k_neigh, data_file_name), sizeof(fdata_file_name)))
      return 1;
   if (cfams.RunFAMS(K, L, k_neigh, percent, jump, width, fdata_file_name) != 0)
      return 1;

   // save the data
   if (!FamsNameFits(snprintf(fdata_file_name, sizeof(fdata_file_name), "%sout_%s.txt",
         input_path, data_file_name), sizeof(fdata_file_name)))
      return 1;
   cfams.SaveModes(fdata_file_name);

   // save pruned modes modes
   if (!FamsNameFits(snprintf(fdata_file_name, sizeof(fdata_file_name), "%smodes_%s.txt",
         input_path, data_file_name), sizeof(fdata_file_name)))
      return 1;
   cfams.SavePrunedModes(fdata_file_name);

   return 0;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -