ummsearch.c

来自「基于Blas CLapck的.用过的人知道是干啥的」· C语言 代码 · 共 1,470 行 · 第 1/3 页

C
1,470
字号
   }}MULTHEAD *CreateMultNode(int imult){   MULTHEAD *mh;   mh = malloc(sizeof(MULTHEAD));   assert(mh);   mh->imult = imult;   mh->rn = NULL;   mh->next = NULL;   return(mh);}MULTHEAD *GetMultNode(int imult)/* * Finds, and if necessary, creates MULTHEAD with value imult, keeping it * in ascending order */{   MULTHEAD *mh, *mh0;   if (imhead)   {      for (mh=imhead; mh; mh = mh->next) if (mh->imult >= imult) break;      if (mh && imhead == mh && mh->imult != imult)      {         mh = CreateMultNode(imult);         mh->next = imhead;         imhead = mh;      }      else if (!mh || mh->imult != imult)      {         for (mh0=imhead; mh0->next != mh; mh0 = mh0->next);         mh0->next = CreateMultNode(imult);         mh0->next->next = mh;         mh = mh0->next;      }   }   else mh = imhead = CreateMultNode(imult);   return(mh);}void KillMultNode(MULTHEAD *mh){   KillAllRoutNodes(mh->rn);   free(mh);}void KillThisMultNode(MULTHEAD *mhD){   MULTHEAD *mh;   if (mhD == NULL) return;   if (mhD == imhead) imhead = imhead->next;   else   {      for (mh=imhead; mh && mh->next != mhD; mh = mh->next);      assert(mh);      if (mh->next) mh->next = mh->next->next;   }   KillMultNode(mhD);}void KillAllMultNodes(){   MULTHEAD *mh;   while(imhead)   {      mh = imhead->next;      KillMultNode(imhead);      imhead = mh;   }}ROUTNODE *GetRoutNode(int imult, char *rout, int icase, double mflop)/* * Finds, and if necessary creates, the desired RoutNode */{   MULTHEAD *mh;   ROUTNODE *rn;   mh = GetMultNode(imult);   for (rn = mh->rn; rn; rn = rn->next) if (rn->icase == icase) break;   if (!rn)   {      rn = CreateRoutNode(rout, icase, mflop);      if (mh->rn)      {         rn->next = mh->rn;         mh->rn = rn;      }      else mh->rn = rn;   }   return(rn);}void PrintTable(FILE *fpout){   MULTHEAD *mh;   ROUTNODE *rn;   for (mh=imhead; mh; mh = mh->next)   {      fprintf(fpout, "%3d: ", mh->imult);      for (rn=mh->rn; rn; rn = rn->next)         fprintf(fpout, "%3d:%1d,%5.1f ", rn->icase, rn->fixed, rn->mflop);      fprintf(fpout, "\n");   }   fprintf(fpout, "\n");}MULTHEAD *BuildTable(char pre, enum CLEAN_WHICH which, int nb)/* * Builds table of possible cleanup codes, depending on which: * 0 : pMB * 1 : pNB * 2 : pKB */{   ROUTNODE *rn;   int i, n, ID, NB[3];   int iin, io1, io2, iflag, muladd, lat, mu, nu, ku;   char *MCC, *MMFLAGS;   char rout[ROUTLEN], auth[AUTHLEN];   switch(which)   {   case CleanM:      iin = 0;      io1 = 1;      io2 = 2;      break;   case CleanN:      iin = 1;      io1 = 0;      io2 = 2;      break;   case CleanK:      iin = 2;      io1 = 0;      io2 = 1;      break;   }   n = NumUserCases(pre);   for (i=0; i < n; i++)   {      rn = NULL;      ID = GetUserCase(pre, -i, &iflag, NB, NB+1, NB+2, &muladd, &lat,                       &mu, &nu, &ku, rout, auth, &MCC, &MMFLAGS);      if (ATL_MMNoClean(iflag)) continue;      if (NB[io1] < 0 && NB[io1] != -nb) continue;      if (NB[io2] < 0 && NB[io2] != -nb) continue;      if (NB[io1] && (nb % NB[io1])) continue;      if (NB[io2] && (nb % NB[io2])) continue;      if (NB[iin] < 0)      {         if (-NB[iin] < nb) rn = GetRoutNode(-NB[iin], rout, ID, NOTIMED);      }      else if (NB[iin] == 0) rn = GetRoutNode(1, rout, ID, NOTIMED);      else if (NB[iin] < nb) rn = GetRoutNode(NB[iin], rout, ID, NOTIMED);      if (rn) rn->fixed = IsCaseFixed(pre, ID, which);   }   return(imhead);}int MakeMult(int nb, int mul)/* * takes nb, makes it a multiple of mul by reducing */{   return( (nb / mul) * mul );}int GetPNB(char pre, enum CLEAN_WHICH which, int icase, int NB, int imul,           int *pNB)/* * Returns number of pNB that are multiple of imul (max of 3) to be timed; * pNB contains the values to try */{   int i=1, j;   int iflag, NBs[3], muladd, lat, mu, nu, ku;   char fnam[ROUTLEN], *MCC, *MMFLAGS;   pNB[0] = pNB[1] = pNB[2] = 0;   assert(GetUserCase(pre, icase, &iflag, NBs, NBs+1, NBs+2, &muladd, &lat,                      &mu, &nu, &ku, fnam, fnam, &MCC, &MMFLAGS));   if (NBs[which] < 0) pNB[0] = -NBs[which];   else   {      j = pNB[0] = MakeMult(NB-NB/8, imul);      if (!j) pNB[0] = imul;      j = MakeMult(NB/2, imul);      if (j && j != pNB[0])      {         pNB[1] = j;         i = 2;         j = NB/8;         if (NB >= 32) j = Mmax(j, 16);         j = pNB[2] = ((j+imul-1)/imul)*imul;         if (j && j != pNB[1] && j != pNB[0]) i = 3;         else pNB[2] = 0;      }   }   return(i);}#define NO_RESULTS -88.7double GetRes(char *fnam){   FILE *fp;   double mf, mflop[3];   int i;   fp = fopen(fnam, "r");   if (!fp) return(NO_RESULTS);   for (i=0; i != 3; i++)   {      if (fscanf(fp, " %lf", &mflop[i]) != 1)      {         fclose(fp);         remove(fnam);         return(NO_RESULTS);      }   }   fclose(fp);   mf = GetAvg(3, TOLERANCE, mflop);   return(mf);}double ummcase0(   char pre,                  /* type prefix */   int M, int N, int K,       /* problem sizes to time */   int mb, int nb, int kb,    /* 0: variable NB, else fixed cpp macro of NB */   int lda, int ldb, int ldc, /* leading dims */   int muladd, int lat,       /* muladd and latency settings */   int mu, int nu, int ku,    /* unrolling factors */   char *fnam,                /* file name to compile */   char *MCC, char *MMFLAGS,  /* NULL : use defaults, else comp to use */   char *outnam               /* output name */){   char ln[512];   char ch;   int i;   double mf;   if (!FileExists(outnam))   {      if (pre == 'c' || pre == 'z')         i = sprintf(ln, "make cmmucase mmrout=CASES/%s csC=2 ", fnam);      else i = sprintf(ln, "make mmucase mmrout=CASES/%s ", fnam);      if (MCC)      {         ch = (pre == 'c' || pre == 's') ? 'S' : 'D';         i += sprintf(ln+i, "%cMC=\"%s\" %cMCFLAGS=\"%s\" ",                      ch, MCC, ch, MMFLAGS);      }      i += sprintf(ln+i, "casnam=%s ", outnam);      i += sprintf(ln+i, "pre=%c muladd=%d lat=%d M=%d N=%d K=%d mb=%d nb=%d kb=%d mu=%d nu=%d ku=%d lda=%d ldb=%d ldc=%d ",                   pre, muladd, lat, M, N, K, mb, nb, kb, mu, nu, ku,                   lda, ldb, ldc);      i += sprintf(ln+i, "\n");      fprintf(stdout, "%s", ln);      if (system(ln) != 0) return(-1.0);   }   mf = GetRes(outnam);   if (mf == NO_RESULTS) mf = -1.0;   return(mf);}int GetIflag(char pre, int icase){   int iflag, mb, nb, kb, muladd, lat, mu, nu, ku;   char fnam[ROUTLEN], *MCC, *MMFLAGS;   assert(GetUserCase(pre, icase, &iflag, &mb, &nb, &kb, &muladd, &lat,                      &mu, &nu, &ku, fnam, fnam, &MCC, &MMFLAGS));   return(iflag);}double GetCleanCase(char pre, enum CLEAN_WHICH which, int icase, int imul,                    int mb, int nb, int kb){   char cwh[3] = {'M', 'N', 'K'};   char outf[ROUTLEN], fnam[ROUTLEN], *MCC, *MMFLAGS;   int ld=kb, NB[3], NB1[3], NBs[3], nb0;   int iflag, mb1, nb1, kb1, muladd, lat, mu, nu, ku;   assert(GetUserCase(pre, icase, &iflag, NB1, NB1+1, NB1+2, &muladd, &lat,                      &mu, &nu, &ku, fnam, outf, &MCC, &MMFLAGS));   if (ATL_MMNoClean(iflag)) return(-1.0);   NBs[0] = mb;   NBs[1] = nb;   NBs[2] = kb;   nb0 = kb;   if (which == CleanK)   {      nb0 = nb;      if (ATL_MMVarLda(iflag)) ld = 0;   }   NB[0] = NB[1] = NB[2] = nb0;   if (NB1[which]) NB[which] = NBs[which];   else NB[which] = 0;   sprintf(outf, "res/%cup%cB%d_%d_%dx%dx%d", pre, cwh[which], icase, imul,           mb, nb, kb);   return(ummcase0(pre, mb, nb, kb, NB[0], NB[1], NB[2], ld, ld, 0, muladd, lat,                   mu, nu, ku, fnam, MCC, MMFLAGS, outf));}double GetCleanCases0(char pre, enum CLEAN_WHICH which, int nb, int imul,                      int icase, int n, int *pNB){   int i, NB[3];   double mf0, mf=0.0;   NB[0] = NB[1] = NB[2] = nb;   for (i=0; i < n; i++)   {      NB[which] = pNB[i];      mf0 = GetCleanCase(pre, which, icase, imul, NB[0], NB[1], NB[2]);      if (mf0 <= 0.0) return(-1.0); /* reject if it fails to run */      mf += mf0;   }   return(mf / n);}double GetCleanCases(char pre, enum CLEAN_WHICH which, int nb, int imul,                     int icase){   int n, pNB[3];   n = GetPNB(pre, which, icase, nb, imul, pNB);   return(GetCleanCases0(pre, which, nb, imul, icase, n, pNB));}void TimeRouts(char pre, enum CLEAN_WHICH which, int nb, int imul,               ROUTNODE *rn0){   ROUTNODE *rn;   for (rn=rn0; rn; rn = rn->next)      if (rn->mflop == NOTIMED)         rn->mflop = GetCleanCases(pre, which, nb, imul, rn->icase);}void TimeTable(char pre, enum CLEAN_WHICH which, int nb)/* * Times table of cleanup codes, depending on which */{   MULTHEAD *mh;   for (mh=imhead; mh; mh = mh->next)      TimeRouts(pre, which, nb, mh->imult, mh->rn);}ROUTNODE *FindBestRout(ROUTNODE *rn0, double adv)/* * Excludes those of fixed size */{   ROUTNODE *rnB=NULL, *rn;   double mfB=0.0, ad;   for (rn=rn0; rn; rn = rn->next)   {      if (rn->fixed != 2)      {         ad = (rn->fixed == 0)*adv*rn->mflop;         if (rn->mflop+ad > mfB)         {            rnB = rn;            mfB = rn->mflop+ad;         }      }   }   return(rnB);}ROUTNODE *FindBestFixed2(ROUTNODE *rn0){   ROUTNODE *rnB=NULL, *rn;   double mfB=0.0;   for (rn=rn0; rn; rn = rn->next)   {      if (rn->fixed == 2 && rn->mflop > mfB)      {         rnB = rn;         mfB = rn->mflop;      }   }   return(rnB);}void ReduceRouts(char pre, enum CLEAN_WHICH which, MULTHEAD *mh,                 double adv, int nb)/* * reduces routs to best, giving adv advantage to non-fixed routs */{   double mf;   ROUTNODE *rn, *rnN, *rnF, *rnNF=NULL;   int n, NB[3];   rn = FindBestRout(mh->rn, adv);   rnF = FindBestFixed2(mh->rn);   if (rn || rnF) /* some case actually compiled and ran */   {      if (rnF && rn)      {         n = GetPNB(pre, which, rnF->icase, nb, mh->imult, NB);         mf = GetCleanCases0(pre, which, nb, mh->imult, rn->icase, n, NB);         mf += adv*mf;         if (mf < rnF->mflop)         {            rnNF = CreateRoutNode(rnF->rout, rnF->icase, rnF->mflop);            rnNF->fixed = 2;         }      }      else if (rnF)      {         rnNF = CreateRoutNode(rnF->rout, rnF->icase, rnF->mflop);         rnNF->fixed = 2;         KillAllRoutNodes(mh->rn);         mh->rn = rnNF;      }      if (rn)      {         rnN = CreateRoutNode(rn->rout, rn->icase, rn->mflop);         rnN->fixed = rn->fixed;         KillAllRoutNodes(mh->rn);         mh->rn = rnN;         if (rnNF) mh->rn->next = rnNF;         else mh->rn->next = NULL;      }   }   else KillThisMultNode(mh);}void ReduceMults(char pre, enum CLEAN_WHICH which, double adv, int nb)/* * Finds mults that are multiples of each other, and takes best, * giving adv advantage to non-fixed routs, and .5 adv to lower imults */{   MULTHEAD *mh, *mh0, *mh1;   double mf0, mf1;   int imult;   int NB[3];   NB[0] = NB[1] = NB[2] = nb;   for (mh0=imhead; mh0; mh0 = mh0->next)   {      if (mh0->rn->fixed == 2) continue;      for (mh1=mh0, mh=mh0->next; mh; mh = mh->next)      {         if (mh->imult % mh0->imult == 0) /* higher mult is mult of lower */         {            imult = mh->imult;            if (mh->rn->fixed == 2)            {               NB[which] = imult;               mf1 = mh->rn->mflop;            }            else            {               NB[which] = imult <= 4 ? ((nb-imult)/imult)*imult :                                         ((nb-1)/imult)*imult;               mf1 = GetCleanCase(pre, which, mh->rn->icase, imult,                                  NB[0], NB[1], NB[2]);            }            mf0 = GetCleanCase(pre, which, mh0->rn->icase, imult,                               NB[0], NB[1], NB[2]);            if (((mh0->rn->fixed == 0)+0.5)*adv*mf0+mf0 >                (mh->rn->fixed == 0)*adv*mf1+mf1)            {               mh1->next = mh->next;               KillMultNode(mh);               mh = mh1;            }         }         mh1 = mh;      }   }}void ReduceTable(char pre, enum CLEAN_WHICH which, int nb)/*

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?