emit_mm.c

来自「基于Blas CLapck的.用过的人知道是干啥的」· C语言 代码 · 共 1,670 行 · 第 1/5 页

C
1,670
字号
/* *             Automatically Tuned Linear Algebra Software v3.8.0 *                    (C) Copyright 1997 R. Clint Whaley * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: *   1. Redistributions of source code must retain the above copyright *      notice, this list of conditions and the following disclaimer. *   2. Redistributions in binary form must reproduce the above copyright *      notice, this list of conditions, and the following disclaimer in the *      documentation and/or other materials provided with the distribution. *   3. The name of the ATLAS group or the names of its contributers may *      not be used to endorse or promote products derived from this *      software without specific written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. * */#include <stdio.h>#include <stdlib.h>#include <string.h>#include <assert.h>#include "atlas_prefetch.h"static int LD_AT_BOTTOM=0;  /* load $C$ after K-loop, rather than before? */static int TEMP_TYPE=0; /* type of temp regs: 0-TYPE, 1-double, 2-long double */typedef struct CleanNode CLEANNODE;struct CleanNode{   CLEANNODE *next;   int *NBs;   int imult, icase, fixed, nb, ncomps;   char rout[256], CC[256], CCFLAGS[512];};#if defined(ATL_SSE1)   #define ICC_IS_RETARDED#endif#ifndef MAX_CASG_KU   #define MAX_CASG_KU 2#endif#define Mmin(x, y) ( (x) > (y) ? (y) : (x) )#define Mmax(x, y) ( (x) > (y) ? (x) : (y) )#define ATL_Mlcm(x_, y_, ans_) \{ \   if ((x_) >= (y_)) \   { \      (ans_) = (x_); \      while( ((ans_)/(y_))*(y_) != (ans_) ) (ans_) += (x_); \   } \   else \   { \      (ans_) = (y_); \      while( ((ans_)/(x_))*(x_) != (ans_) ) (ans_) += (y_); \   } \}#define Mlowcase(C) ( ((C) > 64 && (C) < 91) ? (C) | 32 : (C) )char PRE='d';char *TYPE="double";int MUL=1;enum CW {CleanM=0, CleanN=1, CleanK=2, CleanNot=3};enum ATLAS_LOOP_ORDER {AtlasIJK=0, AtlasJIK=1};enum ATLAS_TRANS {AtlasNoTrans=111, AtlasTrans=112, AtlasConjTrans=113};typedef void (*KLOOPFUNC)(FILE *, char*, enum ATLAS_LOOP_ORDER,                          enum ATLAS_TRANS, enum ATLAS_TRANS, int, int, int,                          int, int, int, int, char*, char*, char*, char*, char*,                          int, int, int, int, int, int, int, int, int,                          char*, char*);#define SAFE_ALPHA -3void PrintC99Defines(FILE *fpout, char *spc){   fprintf(fpout, "%s#ifndef ATL_RESTRICT\n", spc);   fprintf(fpout,      "%s#if defined(__STDC_VERSION__) && (__STDC_VERSION__/100 >= 1999)\n",           spc);   fprintf(fpout, "%s   #define ATL_RESTRICT restrict\n", spc);   fprintf(fpout, "%s#else\n%s   #define ATL_RESTRICT\n%s#endif\n",           spc, spc, spc);   fprintf(fpout, "%s#endif\n", spc);}int GetPower2(int n){   int pwr2, i;   if (n == 1) return(0);   for (pwr2=0, i=1; i < n; i <<= 1, pwr2++);   if (i != n) pwr2 = 0;   return(pwr2);}#define ShiftThresh 2char *GetDiv(int N, char *inc){   static char ln[256];   int pwr2 = GetPower2(N);   if (N == 1) sprintf(ln, "%s", inc);   else if (pwr2) sprintf(ln, "((%s) >> %d)", inc, pwr2);   else sprintf(ln, "((%s) / %d)", inc, N);   return(ln);}char *GetInc(int N, char *inc){   static char ln0[256];   char ln[256];   char *p=ln;   int i, n=N, iPLUS=0;   if (n == 0)   {      ln[0] = '0';      ln[1] = '\0';   }   while(n > 1)   {      for (i=0; n >= (1<<i); i++);      if ( (1 << i) > n) i--;      if (iPLUS++) *p++ = '+';      sprintf(p, "((%s) << %d)", inc, i);      p += strlen(p);      n -= (1 << i);   }   if (n == 1)   {      if (iPLUS++) *p++ = '+';      sprintf(p, "%s", inc);   }   if (iPLUS > ShiftThresh) sprintf(ln0, "(%d*(%s))", N, inc);   else if (iPLUS) sprintf(ln0, "(%s)", ln);   else sprintf(ln0, "%s", ln);   return(ln0);}void emit_uhead(FILE *fp, char pre, enum CW which, int mb, int nb, int kb,                int lda, int ldb, int ldc, int beta)/* * if which != CleanNot, ldc is not used, lda is imult ldb is fixed, * and ldc is NBs[j] */{   char cbet;   char cwh[3] = {'M', 'N', 'K'};   int i;   if (beta == 1) cbet = '1';   else if (beta == 0) cbet = '0';   else cbet = 'X';   if (which == CleanNot)   {      fprintf(fp, "#define ATL_USERMM ATL_%cJIK%dx%dx%dTN%dx%dx%d_a1_b%c\n",              pre, mb, nb, kb, lda, ldb, ldc, cbet);      fprintf(fp, "#define ATL_USERMM_b1 ATL_%cJIK%dx%dx%dTN%dx%dx%d_a1_b1\n",              pre, mb, nb, kb, lda, ldb, ldc);      fprintf(fp, "#define ATL_USERMM_b0 ATL_%cJIK%dx%dx%dTN%dx%dx%d_a1_b0\n",              pre, mb, nb, kb, lda, ldb, ldc);      fprintf(fp, "#define ATL_USERMM_bX ATL_%cJIK%dx%dx%dTN%dx%dx%d_a1_bX\n",              pre, mb, nb, kb, lda, ldb, ldc);   }   else   {      fprintf(fp, "#define ATL_USERMM ATL_%cup%cBmm%d_%d_%d_b%c\n",              pre, cwh[which], ldc, lda, ldb, cbet);      fprintf(fp, "#define ATL_USERMM_b1 ATL_%cup%cBmm%d_%d_%d_b1\n",              pre, cwh[which], ldc, lda, ldb);      fprintf(fp, "#define ATL_USERMM_b0 ATL_%cup%cBmm%d_%d_%d_b0\n",              pre, cwh[which], ldc, lda, ldb);      fprintf(fp, "#define ATL_USERMM_bX ATL_%cup%cBmm%d_%d_%d_bX\n",              pre, cwh[which], ldc, lda, ldb);   }   fprintf(fp, "#define BETA%c\n", cbet);   if (pre == 's') fprintf(fp, "#define SREAL\n");   else if (pre == 'd') fprintf(fp, "#define DREAL\n");   else if (pre == 'c') fprintf(fp, "#define SCPLX\n");   else if (pre == 'z') fprintf(fp, "#define DCPLX\n");   fprintf(fp, "\n#define MB %d\n#define NB %d\n#define KB %d\n", mb, nb, kb);   fprintf(fp, "\n#define MBMB %d\n#define NBNB %d\n#define KBKB %d\n",           mb*mb, nb*nb, kb*kb);   for (i=2; i <= 8; i++)      fprintf(fp, "\n#define MB%d %d\n#define NB%d %d\n#define KB%d %d\n\n",              i, i*mb, i, i*nb, i, i*kb);}void emit_head(int NC, FILE *fpout, char pre, int nb, int muladd, int lat,               int mu, int nu, int ku){   int i, pow2nb;   char nam[128];   char upr;   fprintf(fpout, "#ifndef %cMM_H\n", toupper(pre));   fprintf(fpout, "   #define %cMM_H\n\n", toupper(pre));   if (muladd) fprintf(fpout, "   #define ATL_mmMULADD\n");   else fprintf(fpout, "   #define ATL_mmNOMULADD\n");   fprintf(fpout, "   #define ATL_mmLAT %d\n", lat);   fprintf(fpout, "   #define ATL_mmMU  %d\n", mu);   fprintf(fpout, "   #define ATL_mmNU  %d\n", nu);   fprintf(fpout, "   #define ATL_mmKU  %d\n", ku);   fprintf(fpout, "   #define MB %d\n", nb);   fprintf(fpout, "   #define NB %d\n", nb);   fprintf(fpout, "   #define KB %d\n", nb);   fprintf(fpout, "   #define NBNB %d\n", nb*nb);   fprintf(fpout, "   #define MBNB %d\n", nb*nb);   fprintf(fpout, "   #define MBKB %d\n", nb*nb);   fprintf(fpout, "   #define NBKB %d\n", nb*nb);   fprintf(fpout, "   #define NB2 %d\n", 2*nb);   fprintf(fpout, "   #define NBNB2 %d\n\n", 2*nb*nb);   for (i=1,pow2nb=0; i < nb; i <<= 1, pow2nb++);   if (i == nb)   {      fprintf(fpout, "   #define ATL_MulByNB(N_) ((N_) << %d)\n", pow2nb);      fprintf(fpout, "   #define ATL_DivByNB(N_) ((N_) >> %d)\n", pow2nb);      fprintf(fpout, "   #define ATL_MulByNBNB(N_) ((N_) << %d)\n", 2*pow2nb);   }   else   {      fprintf(fpout, "   #define ATL_MulByNB(N_) ((N_) * %d)\n", nb);      fprintf(fpout, "   #define ATL_DivByNB(N_) ((N_) / %d)\n", nb);      fprintf(fpout, "   #define ATL_MulByNBNB(N_) ((N_) * %d)\n", nb*nb);   }   if (!NC)   {      sprintf(nam, "ATL_%cJIK%dx%dx%dTN%dx%dx%d", pre, nb, nb, nb, nb, nb, 0);      if (pre == 'd' || pre == 's')      {         fprintf(fpout, "   #define NBmm %s_a1_b1\n", nam);         fprintf(fpout, "   #define NBmm_b1 %s_a1_b1\n", nam);         fprintf(fpout, "   #define NBmm_b0 %s_a1_b0\n", nam);         fprintf(fpout, "   #define NBmm_bX %s_a1_bX\n", nam);      }      else      {         fprintf(fpout, "void %s_a1_b0(const int M, const int N, const int K, const TYPE alpha, const TYPE *A, const int lda, const TYPE *B, const int ldb, const TYPE beta, TYPE *C, const int ldc);\n", nam);         fprintf(fpout, "void %s_a1_b1(const int M, const int N, const int K, const TYPE alpha, const TYPE *A, const int lda, const TYPE *B, const int ldb, const TYPE beta, TYPE *C, const int ldc);\n", nam);         fprintf(fpout, "void %s_a1_bX(const int M, const int N, const int K, const TYPE alpha, const TYPE *A, const int lda, const TYPE *B, const int ldb, const TYPE beta, TYPE *C, const int ldc);\n\n", nam);         fprintf(fpout, "   #define NBmm_b1(m_, n_, k_, al_, A_, lda_, B_, ldb_, be_, C_, ldc_) \\\n");         fprintf(fpout, "{ \\\n");            fprintf(fpout, "   %s_a1_bX(m_, n_, k_, al_, (A_), lda_, (B_), ldb_, ATL_rnone, C_, ldc_); \\\n", nam);            fprintf(fpout, "   %s_a1_b1(m_, n_, k_, al_, (A_), lda_, (B_)+NBNB, ldb_, ATL_rone, (C_)+1, ldc_); \\\n", nam);            fprintf(fpout, "   %s_a1_bX(m_, n_, k_, al_, (A_)+NBNB, lda_, (B_)+NBNB, ldb_, ATL_rnone, C_, ldc_); \\\n", nam);            fprintf(fpout, "   %s_a1_b1(m_, n_, k_, al_, (A_)+NBNB, lda_, (B_), ldb_, ATL_rone, (C_)+1, ldc_); \\\n", nam);         fprintf(fpout, "   }\n");         fprintf(fpout, "   #define NBmm_b0(m_, n_, k_, al_, A_, lda_, B_, ldb_, be_, C_, ldc_) \\\n");         fprintf(fpout, "{ \\\n");         fprintf(fpout, "   %s_a1_b0(m_, n_, k_, al_, (A_), lda_, (B_), ldb_, ATL_rzero, C_, ldc_); \\\n", nam);         fprintf(fpout, "   %s_a1_b0(m_, n_, k_, al_, (A_), lda_, (B_)+NBNB, ldb_, ATL_rzero, (C_)+1, ldc_); \\\n", nam);         fprintf(fpout, "   %s_a1_bX(m_, n_, k_, al_, (A_)+NBNB, lda_, (B_)+NBNB, ldb_, ATL_rnone, C_, ldc_); \\\n", nam);         fprintf(fpout, "   %s_a1_b1(m_, n_, k_, al_, (A_)+NBNB, lda_, (B_), ldb_, ATL_rone, (C_)+1, ldc_); \\\n", nam);         fprintf(fpout, "   }\n");         fprintf(fpout, "   #define NBmm_bX(m_, n_, k_, al_, A_, lda_, B_, ldb_, be_, C_, ldc_) \\\n");         fprintf(fpout, "{ \\\n");         fprintf(fpout, "   %s_a1_bX(m_, n_, k_, al_, (A_), lda_, (B_), ldb_, -(be_), C_, ldc_); \\\n", nam);         fprintf(fpout, "   %s_a1_bX(m_, n_, k_, al_, (A_), lda_, (B_)+NBNB, ldb_, be_, (C_)+1, ldc_); \\\n", nam);         fprintf(fpout, "   %s_a1_bX(m_, n_, k_, al_, (A_)+NBNB, lda_, (B_)+NBNB, ldb_, ATL_rnone, C_, ldc_); \\\n", nam);         fprintf(fpout, "   %s_a1_b1(m_, n_, k_, al_, (A_)+NBNB, lda_, (B_), ldb_, ATL_rone, (C_)+1, ldc_); \\\n", nam);         fprintf(fpout, "   }\n");         if (pre == 'z') upr = 'd';         else upr = 's';         sprintf(nam, "ATL_%cJIK%dx%dx%dTN%dx%dx%d", upr, nb, nb, nb, nb, nb,0);         fprintf(fpout, "   #define rNBmm_b1 %s_a1_b1\n", nam);         fprintf(fpout, "   #define rNBmm_b0 %s_a1_b0\n", nam);         fprintf(fpout, "   #define rNBmm_bX %s_a1_bX\n", nam);      }   }   fprintf(fpout, "\n#endif\n");}int GetGoodLat(int muladd, int mu, int nu, int ku, int lat){   int mul=mu*nu*ku, slat, blat;   if (muladd) return(lat);   for(slat=lat; mul % slat; slat--);   for(blat=lat; mul % blat && blat < mul; blat++);   if (blat-lat > lat-slat || mul%blat) return(slat);   else return(blat);}void opfetch(FILE *fpout,          /* where to print */             char *spc,            /* indentation string */             int ifetch, /* number of elts to fetch from memory into regs */             char *rA,  /* name for register holding elt of inner matrix */             char *rB,  /* name for register holding elt of outer matrix */             char *pA,  /* name for pointer to elt of inner matrix */             char *pB,  /* name for pointer to elt of outer matrix */             int mu,    /* register blocking for inner matrix */             int nu,    /* register blocking for outer matrix */             int offA,  /* offset to first elt of this block of inner matrix */             int offB,  /* offset to first elt of this block of outer matrix */             int lda,   /* row stride; if 0, row stride is arbitrary */             int ldb,   /* row stride; if 0, row stride is arbitrary */             int mulA,  /* col stride; 1: real, 2: cplx */             int mulB,  /* col stride; 1: real, 2: cplx */             int rowA,  /* if 0 : fetch within column, else fetch within row */             int rowB,  /* if 0 : fetch within column, else fetch within row */             int *ia0,  /* elt of inner matrix to be fetched */             int *ib0)  /* elt of inner matrix to be fetched *//* * This routine is used to generate memory-to-register fetches * A is inner matrix, fetched first;  B is outer matrix, fetched last * Assumes each matrix has either a pointer for each column of accessed, labeled * <pA><col#>, or the number of rows is a constant (ldx), and only one pointer, * <pA>0. */{   int ia = *ia0, ib = *ib0, nf = 0;   if (ia >= mu && ib >= nu) return;   if (ia == 0 && ib == 0) /* initial fetch, always get 2 */   {      assert(ifetch >= 2);

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?