findce.c

来自「基于Blas CLapck的.用过的人知道是干啥的」· C语言 代码 · 共 474 行

C
474
字号
/* *             Automatically Tuned Linear Algebra Software v3.8.0 *                    (C) Copyright 1997 R. Clint Whaley * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: *   1. Redistributions of source code must retain the above copyright *      notice, this list of conditions and the following disclaimer. *   2. Redistributions in binary form must reproduce the above copyright *      notice, this list of conditions, and the following disclaimer in the *      documentation and/or other materials provided with the distribution. *   3. The name of the ATLAS group or the names of its contributers may *      not be used to endorse or promote products derived from this *      software without specific written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. * */#include "atlas_misc.h"#include "atlas_lvl3.h"#include Mstr(Mjoin(Mjoin(atlas_,PRE),sysinfo.h))#define dumb_seed(iseed_) srand(iseed_)#ifndef RAND_MAX  /* rather dangerous non-ansi workaround */   #define RAND_MAX ((unsigned long)(1<<30))#endif#define dumb_rand() ( 0.5 - ((double)rand())/((double)RAND_MAX) )int FoundCE;int CompCE=0;double time00();void Mjoin(PATL,FindCE_mm)(enum ATLAS_TRANS TA, enum ATLAS_TRANS TB,                           const int M, const int N, const int K,                           const SCALAR alpha, const TYPE *A, const int lda,                           const TYPE *B, const int ldb, const SCALAR beta,                           const TYPE *C, const int ldc);void matgen(int M, int N, TYPE *A, int lda, int seed){   int i, j;#ifdef TCPLX   M *= 2;   lda *= 2;#endif   dumb_seed(seed);   for (j=N; j; j--)   {      for (i=0; i != M; i++) A[i] = dumb_rand();      A += lda;   }}double mmcase(enum ATLAS_TRANS TA, enum ATLAS_TRANS TB, int M, int N, int K,              SCALAR alpha, SCALAR beta, int CE){   char cTA, cTB;   int nL2 = (1.3*L2SIZE)/sizeof(int);   int *iL2=NULL, j=0, i, n;   int lda, ldb, ldc=M;   void *vA, *vB, *vC;   TYPE *A, *B, *C;   double t0, t1;/* * Make sure CE will be different than 0, if CE is not 0 */   if (CE)   {      FoundCE = CE;      CompCE = 1;      Mjoin(PATL,FindCE_mm)(TA, TB, M, N, K, alpha, NULL, lda, NULL, ldb, beta,                            NULL, ldc);      if (CompCE < KB)         return(-2.0);   }   CompCE = 0;/* * Blow off cache flushing if C is already twice as large as L2 */   if (M*N*sizeof(TYPE) >= 2*L2SIZE) nL2 = 0;   if (nL2) iL2 = malloc(nL2 * sizeof(int));   vA = malloc(ATL_Cachelen+ATL_MulBySize(M)*K);   vB = malloc(ATL_Cachelen+ATL_MulBySize(N)*K);   vC = malloc(ATL_Cachelen+ATL_MulBySize(M)*N);   if (!vA || !vB || !vC || (nL2 && !iL2))   {      if (iL2) free(iL2);      if (vA) free(vA);      if (vB) free(vB);      if (vC) free(vC);      return(-1.0);   }   ATL_assert(vA && vB && vC);   if (nL2) ATL_assert(iL2);   A = ATL_AlignPtr(vA);   B = ATL_AlignPtr(vB);   C = ATL_AlignPtr(vC);   if (TA == AtlasNoTrans)   {      lda = M;      matgen(M, K, A, lda, 271*M*K);      cTA = 'N';   }   else   {      lda = K;      matgen(K, M, A, lda, 271*M*K);      if (TA == AtlasTrans) cTA = 'T';      else cTA = 'C';   }   if (TB == AtlasNoTrans)   {      ldb = K;      matgen(K, N, B, ldb, 99876*N*K);      cTB = 'N';   }   else   {      ldb = N;      matgen(N, K, B, ldb, 99876*N*K);      if (TB == AtlasTrans) cTB = 'T';      else cTB = 'C';   }   matgen(M, N, C, M, 81*M*N);/* * invalidate L2 cache */   if (nL2)   {      for (i=0; i != nL2; i++) iL2[i] = 0.0;      for (i=0; i != nL2; i++) j += iL2[i];   }   FoundCE = CE;   t0 = time00();   Mjoin(PATL,FindCE_mm)(TA, TB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);   t1 = time00() - t0;   if (iL2) free(iL2);   free(vA);   free(vB);   free(vC);   return(t1);}void PrintUsage(char *nam){   fprintf(stderr, "USAGE: %s -n <N> -m <M> -k <K> -f <include filename>\n",           nam);   exit(-1);}void GetFlags(int nargs, char *args[], enum ATLAS_TRANS *TA, enum ATLAS_TRANS *TB,              int *M, int *N, int *K, TYPE *alpha, TYPE *beta,              int *CE0, int *CEN, int *incCE, FILE **fpout){   int i, n;   *TB = *TA = AtlasNoTrans;   *TA = AtlasTrans;   #ifdef TREAL      *M = *N = *K = (1600/NB)*NB;   #else      *M = *N = *K = (1200/NB)*NB;   #endif   *CE0 = -1;   *alpha = *beta = 1.0;   #ifdef TCPLX      beta[1] = alpha[1] = 0.0;   #endif   *fpout = NULL;   for (i=1; i < nargs; i++)   {      if (*args[i] != '-') PrintUsage(args[0]);      switch(args[i][1])      {      case 'f':         *fpout = fopen(args[++i], "w");         ATL_assert(fpout);         break;      case 'n':         *N = atoi(args[++i]);         break;      case 'k':         *K = atoi(args[++i]);         break;      case 'm':         *M = atoi(args[++i]);         break;      default :         PrintUsage(args[0]);      }   }   if (*N != -1)   {      if (*M == -1) *M = *N;      if (*K == -1) *K = *N;   }   if (*CE0 == -1)   {      n = ATL_MulBySize(NBNB);      for (i=1; i < n; i <<= 1);      i = i / 1024;      if (!i) i = 1;      *CE0 = i;      *CEN = (2*L2SIZE) / 1024;      if (*CEN > 2048) *CEN = 2048;      *incCE = -2;   }}double tloop(enum ATLAS_TRANS TA, enum ATLAS_TRANS TB, int M, int N, int K,             SCALAR alpha, SCALAR beta, int CE00, int CE0, int CEN, int incCE,             int *CEout, double *mflop){   char cTA, cTB;   int i, CE;   double t1, mf, mmf=0.0, mf0;   if (TA == AtlasNoTrans) cTA = 'N';   else if (TA == AtlasTrans) cTA = 'T';   else cTA = 'C';   if (TB == AtlasNoTrans) cTB = 'N';   else if (TB == AtlasTrans) cTB = 'T';   else cTB = 'C';   i = CE = CE00;   do   {      t1 = mmcase(TA, TB, M, N, K, alpha, beta, i*1024);      if (t1 == -2.0)         mf = 0.0;      else      {         ATL_assert(t1 > 0.0);         #ifdef TREAL            mf = (((2.0*M)*N)*K) / (1000000.0 * t1);         #else            mf = (((8.0*M)*N)*K) / (1000000.0 * t1);         #endif      }      if (mf > mmf)      {         CE = i;         mmf = mf;         if (CE00 == i) mf0 = mf;      }      #ifdef TREAL         fprintf(stdout,                 " %c   %c %7d %7d %7d  %6.2f  %6.2f  %9d %10.3f %9.2f\n",                 cTA, cTB, M, N, K, alpha, beta, i, t1, mf);      #else         fprintf(stdout,            " %c  %c %6d %6d %6d %5.1f %5.1f %5.1f %5.1f %9d %9.3f %8.2f\n",                 cTA, cTB, M, N, K, *alpha, alpha[1], *beta, beta[1],                 i, t1, mf);      #endif      if (i == CE00 && CE0 != CE00) i = CE0;      else if (incCE == -2) i <<= 1;      else i += incCE;   }   while (i <= CEN);   *CEout = CE;   *mflop = mmf;   return(mf0);}void refineCE(enum ATLAS_TRANS TA, enum ATLAS_TRANS TB, int M, int N, int K,              SCALAR alpha, SCALAR beta, int prevCE, int bestCE, int nextCE,              double bestMF, int *CEout, double *mflop){   int newCE, CE;   double mf;/* * See if true max is less than one we have so far */   if (bestCE != prevCE)   {      newCE = bestCE - (bestCE - prevCE)/2;      if (bestCE - newCE <= 1)  /* return if we've found CE within 1K */      {         *CEout = bestCE;         *mflop = bestMF;         return;      }      tloop(TA, TB, M, N, K, alpha, beta, newCE, newCE, newCE, 1, &CE, &mf);      if (mf > bestMF)      {         refineCE(TA, TB, M, N, K, alpha, beta, prevCE, newCE, bestCE, mf,                  CEout, mflop);         return;      }   }/* * See if best CE is greater than what has been tried so far */   if (bestCE != nextCE)   {      newCE = bestCE + (nextCE - bestCE)/2;      if (newCE - bestCE <= 1)  /* return if we've found CE within 1K */      {         *CEout = bestCE;         *mflop = bestMF;         return;      }      tloop(TA, TB, M, N, K, alpha, beta, newCE, newCE, newCE, 1, &CE, &mf);      if (mf >= bestMF)      {         refineCE(TA, TB, M, N, K, alpha, beta, bestCE, newCE, nextCE, mf,                  CEout, mflop);         return;      }   }   *CEout = bestCE;   *mflop = bestMF;   return;}main(int nargs, char *args[]){   enum ATLAS_TRANS TA, TB;   int i, M, N, K, CE0, CEN, incCE, CE=0, nextCE, prevCE;   #ifdef TREAL      TYPE alpha, beta;   #else      TYPE alpha[2], beta[2];   #endif   char *sp;   double mf, mf0;   FILE *fpout;   GetFlags(nargs, args, &TA, &TB, &M, &N, &K, SADD alpha, SADD beta, &CE0, &CEN, &incCE,            &fpout);   if (M == -1)   {/* *    Blocking for very large caches problematic due to line conflicts, so no *    use going above 1MB or so . . . */      K = 1024*1024;      if (L2SIZE < K) K = L2SIZE;      K /= ATL_sizeof;      K = 1.15*((K-NBNB)/(2.0*NB));      K = ((K+NB-1)/NB)*NB;   }   #ifdef TREAL      fprintf(stdout, "TA  TB       M       N       K   alpha    beta  CacheEdge       TIME    MFLOPS\n");      fprintf(stdout, "==  ==  ======  ======  ======  ======  ======  =========  =========  ========\n\n");   #else      fprintf(stdout, "TA TB      M      N      K    alpha       beta     CacheEdge      TIME   MFLOPS\n");      fprintf(stdout, "== == ====== ====== ====== ===== ===== ===== ===== ========= ========= ========\n\n");   #endif/* * Determine rough flop rate, so we can see how big a problem to do */   if (M == -1)   {      #ifdef ATL_nkflop         mf = ATL_nkflop * 1000.0;      #else         #ifdef TREAL            mf = mmcase(TA, TB, 450, 450, 450, alpha, beta, 0);            mf = (2.0*450.0*450.0*450.0) / mf;         #else            mf = mmcase(TA, TB, 200, 200, 200, alpha, beta, 0);            mf = (8.0*200.0*200.0*200.0) / mf;         #endif      #endif      mf = (mf*4.0) / K;      for (i=8*NB; i*i < mf; i += NB);      M = N = i;   }/* * preload instructions, and ensure we can allocate the memory */   do   {      mf0 = mmcase(TA, TB, M, N, K, alpha, beta, 0);      if (mf0 <= 0.0)      {         if (K > (Mmax(M,N)<<3)) K >>= 1;         if (M > N) M >>= 1;         else N >>= 1;      }   }   while(mf0 <= 0.0);   mf0 = tloop(TA, TB, M, N, K, alpha, beta, 0, CE0, CEN, incCE, &CE, &mf);/* * If best CacheEdge not 3% better than no cachedge, * its probably clock resolution * Go ahead and accept any CacheEdge that gets the same performance, though, * since it will use less memory. */   if (mf >= mf0)   {      fprintf(stdout, "\nInitial CE=%dKB, mflop=%.2f\n\n", CE, mf);/* *    If CacheEdge not already at extremum, refine it */      if (CE != 0 && CE != CEN && CE != CE0)      {         if (incCE == -2)         {            prevCE = CE / 2;            nextCE = CE * 2;         }         else         {            prevCE = CE - incCE;            nextCE = CE+incCE;         }         if (prevCE < 0) prevCE = 0;         if (nextCE > CEN) nextCE = CEN;         refineCE(TA, TB, M,N,K, alpha, beta, prevCE, CE, nextCE, mf, &CE, &mf);      }      fprintf(stdout, "\nBest CE=%dKB, mflop=%.2f\n", CE, mf);   }   else   {      fprintf(stdout,              "Best CE=%dKB, mflop=%.2f, might as well set to 4MB (%.2f)\n",              CE, mf, mf0);      CE = 4096;   }   if (fpout)   {#ifdef ATL_JITcp      if (CE)      {         FoundCE = CE;         CompCE = 1;         Mjoin(PATL,FindCE_mm)(TA, TB, M, N, K, alpha, NULL, 1, NULL, 1, beta,                               NULL, 1);      }      else CompCE = 0;   #ifdef DCPLX      sp = "ZD";   #else      sp = "CS";   #endif      fprintf(fpout, "#ifndef ATLAS_%sNKB_H\n", sp);      fprintf(fpout, "   #define ATLAS_%sNKB_H\n", sp);      fprintf(fpout, "   #define ATL_%sNKB %d\n", sp, CompCE/KB);      fprintf(fpout, "#endif\n");#else      fprintf(fpout, "#ifndef ATLAS_CACHEEDGE_H\n");      fprintf(fpout, "   #define ATLAS_CACHEEDGE_H\n");      /* if (mf > 1.04*mf0) */         fprintf(fpout, "   #define CacheEdge %d\n", CE*1024);      fprintf(fpout, "#endif\n");#endif      fclose(fpout);   }   exit(0);}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?