📄 sgemm.h
字号:
/*
 * Copyright 1993-2008 NVIDIA Corporation. All rights reserved.
 *
 * NOTICE TO USER:
 *
 * This source code is subject to NVIDIA ownership rights under U.S. and
 * international Copyright laws.
 *
 * This software and the information contained herein is being provided
 * under the terms and conditions of a Source Code License Agreement.
 *
 * NVIDIA MAKES NO REPRESENTATION ABOUT THE SUITABILITY OF THIS SOURCE
 * CODE FOR ANY PURPOSE. IT IS PROVIDED "AS IS" WITHOUT EXPRESS OR
 * IMPLIED WARRANTY OF ANY KIND. NVIDIA DISCLAIMS ALL WARRANTIES WITH
 * REGARD TO THIS SOURCE CODE, INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY, NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
 * IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL,
 * OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
 * OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
 * OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE
 * OR PERFORMANCE OF THIS SOURCE CODE.
 *
 * U.S. Government End Users. This source code is a "commercial item" as
 * that term is defined at 48 C.F.R. 2.101 (OCT 1995), consisting of
 * "commercial computer software" and "commercial computer software
 * documentation" as such terms are used in 48 C.F.R. 12.212 (SEPT 1995)
 * and is provided to the U.S. Government only as a commercial end item.
 * Consistent with 48 C.F.R.12.212 and 48 C.F.R. 227.7202-1 through
 * 227.7202-4 (JUNE 1995), all U.S. Government End Users acquire the
 * source code with only those rights set forth herein.
 */

/* Index functions for accessing source/result matrices in GMEM, and cached
 * tiles in GRF. All matrices use column-major ordering.
*/

/*
 * IMUL(x,y): integer multiply used by the IDXA/IDXB/IDXC index macros.
 * When FAST_IMUL is 1 it expands to __umul24(), which only uses the low
 * 24 bits of each operand; otherwise it is an ordinary multiplication.
 */
#if (FAST_IMUL==1)
#undef IMUL
#define IMUL(x,y) __umul24(x,y)
#else
#undef IMUL
#define IMUL(x,y) ((x)*(y))
#endif

/*
 * Column-major element indices: (leading dimension * col) + row. The leading
 * dimensions are read from parms.lda / parms.ldb / parms.ldc.
 */
#define IDXA(row,col) (IMUL(parms.lda,col)+(row)) /* index into matrix A */
#define IDXB(row,col) (IMUL(parms.ldb,col)+(row)) /* index into matrix B */
#define IDXC(row,col) (IMUL(parms.ldc,col)+(row)) /* index into matrix C */

/*
 * Indices into the cached tiles. The column stride is TILE_DIM+1, and
 * __umul24() is used unconditionally here (FAST_IMUL is not consulted).
 * NOTE(review): the extra +1 looks like padding to avoid bank conflicts;
 * confirm against the place where AA and BB are declared.
 */
#define IDXAA(row,col) (__umul24(TILE_DIM+1,col)+(row)) /* index GRF A-tile */
#define IDXBB(row,col) (__umul24(TILE_DIM+1,col)+(row)) /* index GRF B-tile */

/*
 * Distance between two horizontally adjacent elements (same row, next
 * column), derived from the index macros above.
 */
#define AA_COL_OFS (IDXAA(0,1)-IDXAA(0,0))
#define BB_COL_OFS (IDXBB(0,1)-IDXBB(0,0))
#define A_COL_OFS (IDXA(0,1)-IDXA(0,0))
#define B_COL_OFS (IDXB(0,1)-IDXB(0,0))
#define C_COL_OFS (IDXC(0,1)-IDXC(0,0))

/*
 * fetchA(i) / fetchB(i): read element i of A or B. With USE_TEX==1 the read
 * goes through tex1Dfetch() on texA/texB at offset parms.texAOfs+i or
 * parms.texBOfs+i (cast to int); otherwise it is a direct array access
 * through parms.A / parms.B.
 */
#if (USE_TEX==1)
#undef fetchA
#undef fetchB
#define fetchA(i) tex1Dfetch(texA,(int)(parms.texAOfs+(i)))
#define fetchB(i) tex1Dfetch(texB,(int)(parms.texBOfs+(i)))
#else
#undef fetchA
#undef fetchB
#define fetchA(i) parms.A[i]
#define fetchB(i) parms.B[i]
#endif

/*
 * ACCUMULATE_DOT_PRODUCT_TILE(num)
 *   Fully unrolled dot product of TILE_DIM consecutive elements:
 *   dp<num> += AA[li+k] * BB[lj+k] for k = 0 .. TILE_DIM-1.
 *   li, lj and ll are not modified.
 *
 * ACCUMULATE_2DOT_PRODUCTS_TILE(num1,num2,ljOfs)
 *   Same, but feeds two accumulators from the same AA elements:
 *   dp<num1> uses BB[lj+k], dp<num2> uses BB[lj+ljOfs+k]. The two updates
 *   are interleaved element by element.
 *
 * Both macros expect AA, BB, li, lj and the dp<N> accumulators to be in
 * scope at the point of use. Only TILE_DIM values 16 and 32 are supported.
 */
#if (TILE_DIM==32)
#define ACCUMULATE_DOT_PRODUCT_TILE(num) \
do { \
    dp##num += (AA[li+ 0] * BB[lj+ 0]); \
    dp##num += (AA[li+ 1] * BB[lj+ 1]); \
    dp##num += (AA[li+ 2] * BB[lj+ 2]); \
    dp##num += (AA[li+ 3] * BB[lj+ 3]); \
    dp##num += (AA[li+ 4] * BB[lj+ 4]); \
    dp##num += (AA[li+ 5] * BB[lj+ 5]); \
    dp##num += (AA[li+ 6] * BB[lj+ 6]); \
    dp##num += (AA[li+ 7] * BB[lj+ 7]); \
    dp##num += (AA[li+ 8] * BB[lj+ 8]); \
    dp##num += (AA[li+ 9] * BB[lj+ 9]); \
    dp##num += (AA[li+10] * BB[lj+10]); \
    dp##num += (AA[li+11] * BB[lj+11]); \
    dp##num += (AA[li+12] * BB[lj+12]); \
    dp##num += (AA[li+13] * BB[lj+13]); \
    dp##num += (AA[li+14] * BB[lj+14]); \
    dp##num += (AA[li+15] * BB[lj+15]); \
    dp##num += (AA[li+16] * BB[lj+16]); \
    dp##num += (AA[li+17] * BB[lj+17]); \
    dp##num += (AA[li+18] * BB[lj+18]); \
    dp##num += (AA[li+19] * BB[lj+19]); \
    dp##num += (AA[li+20] * BB[lj+20]); \
    dp##num += (AA[li+21] * BB[lj+21]); \
    dp##num += (AA[li+22] * BB[lj+22]); \
    dp##num += (AA[li+23] * BB[lj+23]); \
    dp##num += (AA[li+24] * BB[lj+24]); \
    dp##num += (AA[li+25] * BB[lj+25]); \
    dp##num += (AA[li+26] * BB[lj+26]); \
    dp##num += (AA[li+27] * BB[lj+27]); \
    dp##num += (AA[li+28] * BB[lj+28]); \
    dp##num += (AA[li+29] * BB[lj+29]); \
    dp##num += (AA[li+30] * BB[lj+30]); \
    dp##num += (AA[li+31] * BB[lj+31]); \
} while (0)
#define ACCUMULATE_2DOT_PRODUCTS_TILE(num1,num2,ljOfs) \
do { \
    dp##num1 += (AA[li+ 0] * BB[lj+ 0]); \
    dp##num2 += (AA[li+ 0] * BB[lj+(ljOfs)+ 0]); \
    dp##num1 += (AA[li+ 1] * BB[lj+ 1]); \
    dp##num2 += (AA[li+ 1] * BB[lj+(ljOfs)+ 1]); \
    dp##num1 += (AA[li+ 2] * BB[lj+ 2]); \
    dp##num2 += (AA[li+ 2] * BB[lj+(ljOfs)+ 2]); \
    dp##num1 += (AA[li+ 3] * BB[lj+ 3]); \
    dp##num2 += (AA[li+ 3] * BB[lj+(ljOfs)+ 3]); \
    dp##num1 += (AA[li+ 4] * BB[lj+ 4]); \
    dp##num2 += (AA[li+ 4] * BB[lj+(ljOfs)+ 4]); \
    dp##num1 += (AA[li+ 5] * BB[lj+ 5]); \
    dp##num2 += (AA[li+ 5] * BB[lj+(ljOfs)+ 5]); \
    dp##num1 += (AA[li+ 6] * BB[lj+ 6]); \
    dp##num2 += (AA[li+ 6] * BB[lj+(ljOfs)+ 6]); \
    dp##num1 += (AA[li+ 7] * BB[lj+ 7]); \
    dp##num2 += (AA[li+ 7] * BB[lj+(ljOfs)+ 7]); \
    dp##num1 += (AA[li+ 8] * BB[lj+ 8]); \
    dp##num2 += (AA[li+ 8] * BB[lj+(ljOfs)+ 8]); \
    dp##num1 += (AA[li+ 9] * BB[lj+ 9]); \
    dp##num2 += (AA[li+ 9] * BB[lj+(ljOfs)+ 9]); \
    dp##num1 += (AA[li+10] * BB[lj+10]); \
    dp##num2 += (AA[li+10] * BB[lj+(ljOfs)+10]); \
    dp##num1 += (AA[li+11] * BB[lj+11]); \
    dp##num2 += (AA[li+11] * BB[lj+(ljOfs)+11]); \
    dp##num1 += (AA[li+12] * BB[lj+12]); \
    dp##num2 += (AA[li+12] * BB[lj+(ljOfs)+12]); \
    dp##num1 += (AA[li+13] * BB[lj+13]); \
    dp##num2 += (AA[li+13] * BB[lj+(ljOfs)+13]); \
    dp##num1 += (AA[li+14] * BB[lj+14]); \
    dp##num2 += (AA[li+14] * BB[lj+(ljOfs)+14]); \
    dp##num1 += (AA[li+15] * BB[lj+15]); \
    dp##num2 += (AA[li+15] * BB[lj+(ljOfs)+15]); \
    dp##num1 += (AA[li+16] * BB[lj+16]); \
    dp##num2 += (AA[li+16] * BB[lj+(ljOfs)+16]); \
    dp##num1 += (AA[li+17] * BB[lj+17]); \
    dp##num2 += (AA[li+17] * BB[lj+(ljOfs)+17]); \
    dp##num1 += (AA[li+18] * BB[lj+18]); \
    dp##num2 += (AA[li+18] * BB[lj+(ljOfs)+18]); \
    dp##num1 += (AA[li+19] * BB[lj+19]); \
    dp##num2 += (AA[li+19] * BB[lj+(ljOfs)+19]); \
    dp##num1 += (AA[li+20] * BB[lj+20]); \
    dp##num2 += (AA[li+20] * BB[lj+(ljOfs)+20]); \
    dp##num1 += (AA[li+21] * BB[lj+21]); \
    dp##num2 += (AA[li+21] * BB[lj+(ljOfs)+21]); \
    dp##num1 += (AA[li+22] * BB[lj+22]); \
    dp##num2 += (AA[li+22] * BB[lj+(ljOfs)+22]); \
    dp##num1 += (AA[li+23] * BB[lj+23]); \
    dp##num2 += (AA[li+23] * BB[lj+(ljOfs)+23]); \
    dp##num1 += (AA[li+24] * BB[lj+24]); \
    dp##num2 += (AA[li+24] * BB[lj+(ljOfs)+24]); \
    dp##num1 += (AA[li+25] * BB[lj+25]); \
    dp##num2 += (AA[li+25] * BB[lj+(ljOfs)+25]); \
    dp##num1 += (AA[li+26] * BB[lj+26]); \
    dp##num2 += (AA[li+26] * BB[lj+(ljOfs)+26]); \
    dp##num1 += (AA[li+27] * BB[lj+27]); \
    dp##num2 += (AA[li+27] * BB[lj+(ljOfs)+27]); \
    dp##num1 += (AA[li+28] * BB[lj+28]); \
    dp##num2 += (AA[li+28] * BB[lj+(ljOfs)+28]); \
    dp##num1 += (AA[li+29] * BB[lj+29]); \
    dp##num2 += (AA[li+29] * BB[lj+(ljOfs)+29]); \
    dp##num1 += (AA[li+30] * BB[lj+30]); \
    dp##num2 += (AA[li+30] * BB[lj+(ljOfs)+30]); \
    dp##num1 += (AA[li+31] * BB[lj+31]); \
    dp##num2 += (AA[li+31] * BB[lj+(ljOfs)+31]); \
} while (0)
#elif (TILE_DIM==16)
#define ACCUMULATE_DOT_PRODUCT_TILE(num) \
do { \
    dp##num += (AA[li+ 0] * BB[lj+ 0]); \
    dp##num += (AA[li+ 1] * BB[lj+ 1]); \
    dp##num += (AA[li+ 2] * BB[lj+ 2]); \
    dp##num += (AA[li+ 3] * BB[lj+ 3]); \
    dp##num += (AA[li+ 4] * BB[lj+ 4]); \
    dp##num += (AA[li+ 5] * BB[lj+ 5]); \
    dp##num += (AA[li+ 6] * BB[lj+ 6]); \
    dp##num += (AA[li+ 7] * BB[lj+ 7]); \
    dp##num += (AA[li+ 8] * BB[lj+ 8]); \
    dp##num += (AA[li+ 9] * BB[lj+ 9]); \
    dp##num += (AA[li+10] * BB[lj+10]); \
    dp##num += (AA[li+11] * BB[lj+11]); \
    dp##num += (AA[li+12] * BB[lj+12]); \
    dp##num += (AA[li+13] * BB[lj+13]); \
    dp##num += (AA[li+14] * BB[lj+14]); \
    dp##num += (AA[li+15] * BB[lj+15]); \
} while (0)
#define ACCUMULATE_2DOT_PRODUCTS_TILE(num1,num2,ljOfs) \
do { \
    dp##num1 += (AA[li+ 0] * BB[lj+ 0]); \
    dp##num2 += (AA[li+ 0] * BB[lj+(ljOfs)+ 0]); \
    dp##num1 += (AA[li+ 1] * BB[lj+ 1]); \
    dp##num2 += (AA[li+ 1] * BB[lj+(ljOfs)+ 1]); \
    dp##num1 += (AA[li+ 2] * BB[lj+ 2]); \
    dp##num2 += (AA[li+ 2] * BB[lj+(ljOfs)+ 2]); \
    dp##num1 += (AA[li+ 3] * BB[lj+ 3]); \
    dp##num2 += (AA[li+ 3] * BB[lj+(ljOfs)+ 3]); \
    dp##num1 += (AA[li+ 4] * BB[lj+ 4]); \
    dp##num2 += (AA[li+ 4] * BB[lj+(ljOfs)+ 4]); \
    dp##num1 += (AA[li+ 5] * BB[lj+ 5]); \
    dp##num2 += (AA[li+ 5] * BB[lj+(ljOfs)+ 5]); \
    dp##num1 += (AA[li+ 6] * BB[lj+ 6]); \
    dp##num2 += (AA[li+ 6] * BB[lj+(ljOfs)+ 6]); \
    dp##num1 += (AA[li+ 7] * BB[lj+ 7]); \
    dp##num2 += (AA[li+ 7] * BB[lj+(ljOfs)+ 7]); \
    dp##num1 += (AA[li+ 8] * BB[lj+ 8]); \
    dp##num2 += (AA[li+ 8] * BB[lj+(ljOfs)+ 8]); \
    dp##num1 += (AA[li+ 9] * BB[lj+ 9]); \
    dp##num2 += (AA[li+ 9] * BB[lj+(ljOfs)+ 9]); \
    dp##num1 += (AA[li+10] * BB[lj+10]); \
    dp##num2 += (AA[li+10] * BB[lj+(ljOfs)+10]); \
    dp##num1 += (AA[li+11] * BB[lj+11]); \
    dp##num2 += (AA[li+11] * BB[lj+(ljOfs)+11]); \
    dp##num1 += (AA[li+12] * BB[lj+12]); \
    dp##num2 += (AA[li+12] * BB[lj+(ljOfs)+12]); \
    dp##num1 += (AA[li+13] * BB[lj+13]); \
    dp##num2 += (AA[li+13] * BB[lj+(ljOfs)+13]); \
    dp##num1 += (AA[li+14] * BB[lj+14]); \
    dp##num2 += (AA[li+14] * BB[lj+(ljOfs)+14]); \
    dp##num1 += (AA[li+15] * BB[lj+15]); \
    dp##num2 += (AA[li+15] * BB[lj+(ljOfs)+15]); \
} while (0)
#else
#error TILE_DIM must be 16 or 32
#endif

/*
 * ACCUMULATE_DOT_PRODUCT_N(num)
 *   Variable-length dot product: runs ll iterations of
 *   dp<num> += AA[li] * BB[lj], advancing li and lj by one each time and
 *   counting ll down to zero. On exit ll is 0 and li/lj have been advanced
 *   by the original value of ll.
 *
 * ACCUMULATE_2DOT_PRODUCTS_N(num1,num2,ljOfs)
 *   Same loop, but updates two accumulators per iteration; dp<num2> reads
 *   BB at an additional offset of ljOfs.
 *
 * The count is tested before the first iteration. ll is an unsigned int, so
 * a body-first loop entered with ll == 0 would wrap the counter around and
 * read far beyond the tiles; with the test up front a zero count is a no-op.
 */
#define ACCUMULATE_DOT_PRODUCT_N(num) \
do { \
    while (ll) { \
        dp##num += (AA[li] * BB[lj]); \
        li++; \
        lj++; \
        ll--; \
    } \
} while (0)
#define ACCUMULATE_2DOT_PRODUCTS_N(num1,num2,ljOfs) \
do { \
    while (ll) { \
        dp##num1 += (AA[li+ 0] * BB[lj+ 0]); \
        dp##num2 += (AA[li+ 0] * BB[lj+(ljOfs)+ 0]); \
        li++; \
        lj++; \
        ll--; \
    } \
} while (0)

/*
 * Conditional-execution macros. With FULL_TILES_ONLY==1 the condition given
 * to IF() is discarded, so the THEN block always runs, and every ELSE /
 * ELSEIF branch is turned into an "if (0)" block that never runs. Otherwise
 * they expand to an ordinary if / else if / else chain.
 */
#undef IF
#undef THEN
#undef ENDIF
#undef ELSE
#undef ELSEIF
#if FULL_TILES_ONLY==1
#define IF(x)
#define THEN {
#define ENDIF }
#define ELSE } if (0) {
#define ELSEIF(x) } if (0)
#else
#define IF(x) if (x)
#define THEN {
#define ENDIF }
#define ELSE } else {
#define ELSEIF(x) } else if (x)
#endif

    /*
     * Per-thread working variables; only tid is initialised here (from
     * threadIdx.x). One dp<N> accumulator is declared per C element handled
     * by a thread, selected by C_ELEMS_PER_THREAD; each #if level is nested
     * inside the previous one.
     * NOTE(review): presumably this file is #included inside a kernel body,
     * since these are plain local declarations; confirm at the include site.
     */
    unsigned int i, j, l, ii, jj, ll, tid = threadIdx.x;
#if (C_ELEMS_PER_THREAD >= 1)
    float dp0;
#if (C_ELEMS_PER_THREAD >= 2)
    float dp1;
#if (C_ELEMS_PER_THREAD >= 3)
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -