emit_mm.c
来自「基于Blas CLapck的.用过的人知道是干啥的」· C语言 代码 · 共 1,670 行 · 第 1/5 页
C
1,670 行
/* * Automatically Tuned Linear Algebra Software v3.8.0 * (C) Copyright 1997 R. Clint Whaley * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: * 1. Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions, and the following disclaimer in the * documentation and/or other materials provided with the distribution. * 3. The name of the ATLAS group or the names of its contributers may * not be used to endorse or promote products derived from this * software without specific written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. * */#include <stdio.h>#include <stdlib.h>#include <string.h>#include <assert.h>#include "atlas_prefetch.h"static int LD_AT_BOTTOM=0; /* load $C$ after K-loop, rather than before? */static int TEMP_TYPE=0; /* type of temp regs: 0-TYPE, 1-double, 2-long double */typedef struct CleanNode CLEANNODE;struct CleanNode{ CLEANNODE *next; int *NBs; int imult, icase, fixed, nb, ncomps; char rout[256], CC[256], CCFLAGS[512];};#if defined(ATL_SSE1) #define ICC_IS_RETARDED#endif#ifndef MAX_CASG_KU #define MAX_CASG_KU 2#endif#define Mmin(x, y) ( (x) > (y) ? (y) : (x) )#define Mmax(x, y) ( (x) > (y) ? (x) : (y) )#define ATL_Mlcm(x_, y_, ans_) \{ \ if ((x_) >= (y_)) \ { \ (ans_) = (x_); \ while( ((ans_)/(y_))*(y_) != (ans_) ) (ans_) += (x_); \ } \ else \ { \ (ans_) = (y_); \ while( ((ans_)/(x_))*(x_) != (ans_) ) (ans_) += (y_); \ } \}#define Mlowcase(C) ( ((C) > 64 && (C) < 91) ? (C) | 32 : (C) )char PRE='d';char *TYPE="double";int MUL=1;enum CW {CleanM=0, CleanN=1, CleanK=2, CleanNot=3};enum ATLAS_LOOP_ORDER {AtlasIJK=0, AtlasJIK=1};enum ATLAS_TRANS {AtlasNoTrans=111, AtlasTrans=112, AtlasConjTrans=113};typedef void (*KLOOPFUNC)(FILE *, char*, enum ATLAS_LOOP_ORDER, enum ATLAS_TRANS, enum ATLAS_TRANS, int, int, int, int, int, int, int, char*, char*, char*, char*, char*, int, int, int, int, int, int, int, int, int, char*, char*);#define SAFE_ALPHA -3void PrintC99Defines(FILE *fpout, char *spc){ fprintf(fpout, "%s#ifndef ATL_RESTRICT\n", spc); fprintf(fpout, "%s#if defined(__STDC_VERSION__) && (__STDC_VERSION__/100 >= 1999)\n", spc); fprintf(fpout, "%s #define ATL_RESTRICT restrict\n", spc); fprintf(fpout, "%s#else\n%s #define ATL_RESTRICT\n%s#endif\n", spc, spc, spc); fprintf(fpout, "%s#endif\n", spc);}int GetPower2(int n){ int pwr2, i; if (n == 1) return(0); for (pwr2=0, i=1; i < n; i <<= 1, pwr2++); if (i != n) pwr2 = 0; return(pwr2);}#define ShiftThresh 2char *GetDiv(int N, char *inc){ static char ln[256]; int pwr2 = GetPower2(N); if (N == 1) sprintf(ln, "%s", inc); else if (pwr2) sprintf(ln, "((%s) >> %d)", inc, pwr2); else sprintf(ln, "((%s) / %d)", inc, N); return(ln);}char *GetInc(int N, char *inc){ static char ln0[256]; char ln[256]; char *p=ln; int i, n=N, iPLUS=0; if (n == 0) { ln[0] = '0'; ln[1] = '\0'; } while(n > 1) { for (i=0; n >= (1<<i); i++); if ( (1 << i) > n) i--; if (iPLUS++) *p++ = '+'; sprintf(p, "((%s) << %d)", inc, i); p += strlen(p); n -= (1 << i); } if (n == 1) { if (iPLUS++) *p++ = '+'; sprintf(p, "%s", inc); } if (iPLUS > ShiftThresh) sprintf(ln0, "(%d*(%s))", N, inc); else if (iPLUS) sprintf(ln0, "(%s)", ln); else sprintf(ln0, "%s", ln); return(ln0);}void emit_uhead(FILE *fp, char pre, enum CW which, int mb, int nb, int kb, int lda, int ldb, int ldc, int beta)/* * if which != CleanNot, ldc is not used, lda is imult ldb is fixed, * and ldc is NBs[j] */{ char cbet; char cwh[3] = {'M', 'N', 'K'}; int i; if (beta == 1) cbet = '1'; else if (beta == 0) cbet = '0'; else cbet = 'X'; if (which == CleanNot) { fprintf(fp, "#define ATL_USERMM ATL_%cJIK%dx%dx%dTN%dx%dx%d_a1_b%c\n", pre, mb, nb, kb, lda, ldb, ldc, cbet); fprintf(fp, "#define ATL_USERMM_b1 ATL_%cJIK%dx%dx%dTN%dx%dx%d_a1_b1\n", pre, mb, nb, kb, lda, ldb, ldc); fprintf(fp, "#define ATL_USERMM_b0 ATL_%cJIK%dx%dx%dTN%dx%dx%d_a1_b0\n", pre, mb, nb, kb, lda, ldb, ldc); fprintf(fp, "#define ATL_USERMM_bX ATL_%cJIK%dx%dx%dTN%dx%dx%d_a1_bX\n", pre, mb, nb, kb, lda, ldb, ldc); } else { fprintf(fp, "#define ATL_USERMM ATL_%cup%cBmm%d_%d_%d_b%c\n", pre, cwh[which], ldc, lda, ldb, cbet); fprintf(fp, "#define ATL_USERMM_b1 ATL_%cup%cBmm%d_%d_%d_b1\n", pre, cwh[which], ldc, lda, ldb); fprintf(fp, "#define ATL_USERMM_b0 ATL_%cup%cBmm%d_%d_%d_b0\n", pre, cwh[which], ldc, lda, ldb); fprintf(fp, "#define ATL_USERMM_bX ATL_%cup%cBmm%d_%d_%d_bX\n", pre, cwh[which], ldc, lda, ldb); } fprintf(fp, "#define BETA%c\n", cbet); if (pre == 's') fprintf(fp, "#define SREAL\n"); else if (pre == 'd') fprintf(fp, "#define DREAL\n"); else if (pre == 'c') fprintf(fp, "#define SCPLX\n"); else if (pre == 'z') fprintf(fp, "#define DCPLX\n"); fprintf(fp, "\n#define MB %d\n#define NB %d\n#define KB %d\n", mb, nb, kb); fprintf(fp, "\n#define MBMB %d\n#define NBNB %d\n#define KBKB %d\n", mb*mb, nb*nb, kb*kb); for (i=2; i <= 8; i++) fprintf(fp, "\n#define MB%d %d\n#define NB%d %d\n#define KB%d %d\n\n", i, i*mb, i, i*nb, i, i*kb);}void emit_head(int NC, FILE *fpout, char pre, int nb, int muladd, int lat, int mu, int nu, int ku){ int i, pow2nb; char nam[128]; char upr; fprintf(fpout, "#ifndef %cMM_H\n", toupper(pre)); fprintf(fpout, " #define %cMM_H\n\n", toupper(pre)); if (muladd) fprintf(fpout, " #define ATL_mmMULADD\n"); else fprintf(fpout, " #define ATL_mmNOMULADD\n"); fprintf(fpout, " #define ATL_mmLAT %d\n", lat); fprintf(fpout, " #define ATL_mmMU %d\n", mu); fprintf(fpout, " #define ATL_mmNU %d\n", nu); fprintf(fpout, " #define ATL_mmKU %d\n", ku); fprintf(fpout, " #define MB %d\n", nb); fprintf(fpout, " #define NB %d\n", nb); fprintf(fpout, " #define KB %d\n", nb); fprintf(fpout, " #define NBNB %d\n", nb*nb); fprintf(fpout, " #define MBNB %d\n", nb*nb); fprintf(fpout, " #define MBKB %d\n", nb*nb); fprintf(fpout, " #define NBKB %d\n", nb*nb); fprintf(fpout, " #define NB2 %d\n", 2*nb); fprintf(fpout, " #define NBNB2 %d\n\n", 2*nb*nb); for (i=1,pow2nb=0; i < nb; i <<= 1, pow2nb++); if (i == nb) { fprintf(fpout, " #define ATL_MulByNB(N_) ((N_) << %d)\n", pow2nb); fprintf(fpout, " #define ATL_DivByNB(N_) ((N_) >> %d)\n", pow2nb); fprintf(fpout, " #define ATL_MulByNBNB(N_) ((N_) << %d)\n", 2*pow2nb); } else { fprintf(fpout, " #define ATL_MulByNB(N_) ((N_) * %d)\n", nb); fprintf(fpout, " #define ATL_DivByNB(N_) ((N_) / %d)\n", nb); fprintf(fpout, " #define ATL_MulByNBNB(N_) ((N_) * %d)\n", nb*nb); } if (!NC) { sprintf(nam, "ATL_%cJIK%dx%dx%dTN%dx%dx%d", pre, nb, nb, nb, nb, nb, 0); if (pre == 'd' || pre == 's') { fprintf(fpout, " #define NBmm %s_a1_b1\n", nam); fprintf(fpout, " #define NBmm_b1 %s_a1_b1\n", nam); fprintf(fpout, " #define NBmm_b0 %s_a1_b0\n", nam); fprintf(fpout, " #define NBmm_bX %s_a1_bX\n", nam); } else { fprintf(fpout, "void %s_a1_b0(const int M, const int N, const int K, const TYPE alpha, const TYPE *A, const int lda, const TYPE *B, const int ldb, const TYPE beta, TYPE *C, const int ldc);\n", nam); fprintf(fpout, "void %s_a1_b1(const int M, const int N, const int K, const TYPE alpha, const TYPE *A, const int lda, const TYPE *B, const int ldb, const TYPE beta, TYPE *C, const int ldc);\n", nam); fprintf(fpout, "void %s_a1_bX(const int M, const int N, const int K, const TYPE alpha, const TYPE *A, const int lda, const TYPE *B, const int ldb, const TYPE beta, TYPE *C, const int ldc);\n\n", nam); fprintf(fpout, " #define NBmm_b1(m_, n_, k_, al_, A_, lda_, B_, ldb_, be_, C_, ldc_) \\\n"); fprintf(fpout, "{ \\\n"); fprintf(fpout, " %s_a1_bX(m_, n_, k_, al_, (A_), lda_, (B_), ldb_, ATL_rnone, C_, ldc_); \\\n", nam); fprintf(fpout, " %s_a1_b1(m_, n_, k_, al_, (A_), lda_, (B_)+NBNB, ldb_, ATL_rone, (C_)+1, ldc_); \\\n", nam); fprintf(fpout, " %s_a1_bX(m_, n_, k_, al_, (A_)+NBNB, lda_, (B_)+NBNB, ldb_, ATL_rnone, C_, ldc_); \\\n", nam); fprintf(fpout, " %s_a1_b1(m_, n_, k_, al_, (A_)+NBNB, lda_, (B_), ldb_, ATL_rone, (C_)+1, ldc_); \\\n", nam); fprintf(fpout, " }\n"); fprintf(fpout, " #define NBmm_b0(m_, n_, k_, al_, A_, lda_, B_, ldb_, be_, C_, ldc_) \\\n"); fprintf(fpout, "{ \\\n"); fprintf(fpout, " %s_a1_b0(m_, n_, k_, al_, (A_), lda_, (B_), ldb_, ATL_rzero, C_, ldc_); \\\n", nam); fprintf(fpout, " %s_a1_b0(m_, n_, k_, al_, (A_), lda_, (B_)+NBNB, ldb_, ATL_rzero, (C_)+1, ldc_); \\\n", nam); fprintf(fpout, " %s_a1_bX(m_, n_, k_, al_, (A_)+NBNB, lda_, (B_)+NBNB, ldb_, ATL_rnone, C_, ldc_); \\\n", nam); fprintf(fpout, " %s_a1_b1(m_, n_, k_, al_, (A_)+NBNB, lda_, (B_), ldb_, ATL_rone, (C_)+1, ldc_); \\\n", nam); fprintf(fpout, " }\n"); fprintf(fpout, " #define NBmm_bX(m_, n_, k_, al_, A_, lda_, B_, ldb_, be_, C_, ldc_) \\\n"); fprintf(fpout, "{ \\\n"); fprintf(fpout, " %s_a1_bX(m_, n_, k_, al_, (A_), lda_, (B_), ldb_, -(be_), C_, ldc_); \\\n", nam); fprintf(fpout, " %s_a1_bX(m_, n_, k_, al_, (A_), lda_, (B_)+NBNB, ldb_, be_, (C_)+1, ldc_); \\\n", nam); fprintf(fpout, " %s_a1_bX(m_, n_, k_, al_, (A_)+NBNB, lda_, (B_)+NBNB, ldb_, ATL_rnone, C_, ldc_); \\\n", nam); fprintf(fpout, " %s_a1_b1(m_, n_, k_, al_, (A_)+NBNB, lda_, (B_), ldb_, ATL_rone, (C_)+1, ldc_); \\\n", nam); fprintf(fpout, " }\n"); if (pre == 'z') upr = 'd'; else upr = 's'; sprintf(nam, "ATL_%cJIK%dx%dx%dTN%dx%dx%d", upr, nb, nb, nb, nb, nb,0); fprintf(fpout, " #define rNBmm_b1 %s_a1_b1\n", nam); fprintf(fpout, " #define rNBmm_b0 %s_a1_b0\n", nam); fprintf(fpout, " #define rNBmm_bX %s_a1_bX\n", nam); } } fprintf(fpout, "\n#endif\n");}int GetGoodLat(int muladd, int mu, int nu, int ku, int lat){ int mul=mu*nu*ku, slat, blat; if (muladd) return(lat); for(slat=lat; mul % slat; slat--); for(blat=lat; mul % blat && blat < mul; blat++); if (blat-lat > lat-slat || mul%blat) return(slat); else return(blat);}void opfetch(FILE *fpout, /* where to print */ char *spc, /* indentation string */ int ifetch, /* number of elts to fetch from memory into regs */ char *rA, /* name for register holding elt of inner matrix */ char *rB, /* name for register holding elt of outer matrix */ char *pA, /* name for pointer to elt of inner matrix */ char *pB, /* name for pointer to elt of outer matrix */ int mu, /* register blocking for inner matrix */ int nu, /* register blocking for outer matrix */ int offA, /* offset to first elt of this block of inner matrix */ int offB, /* offset to first elt of this block of outer matrix */ int lda, /* row stride; if 0, row stride is arbitrary */ int ldb, /* row stride; if 0, row stride is arbitrary */ int mulA, /* col stride; 1: real, 2: cplx */ int mulB, /* col stride; 1: real, 2: cplx */ int rowA, /* if 0 : fetch within column, else fetch within row */ int rowB, /* if 0 : fetch within column, else fetch within row */ int *ia0, /* elt of inner matrix to be fetched */ int *ib0) /* elt of inner matrix to be fetched *//* * This routine is used to generate memory-to-register fetches * A is inner matrix, fetched first; B is outer matrix, fetched last * Assumes each matrix has either a pointer for each column of accessed, labeled * <pA><col#>, or the number of rows is a constant (ldx), and only one pointer, * <pA>0. */{ int ia = *ia0, ib = *ib0, nf = 0; if (ia >= mu && ib >= nu) return; if (ia == 0 && ib == 0) /* initial fetch, always get 2 */ { assert(ifetch >= 2);
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?