📄 sgemm.h
字号:
float dp2;#if (C_ELEMS_PER_THREAD >= 4) float dp3;#if (C_ELEMS_PER_THREAD >= 5)#error C_ELEMS_PER_THREAD >= 5 not supported#endif /* (C_ELEMS_PER_THREAD >= 5) */#endif /* (C_ELEMS_PER_THREAD >= 4) */#endif /* (C_ELEMS_PER_THREAD >= 3) */#endif /* (C_ELEMS_PER_THREAD >= 2) */#endif /* (C_ELEMS_PER_THREAD >= 1) */ unsigned tidLo = (tid & (TILE_DIM - 1)); unsigned tidHi = (tid >> TILE_DIM_LOG); #if ((TRANSA==1)||(TRANSB==0)) unsigned idxLoHi = IDXAA(tidLo,tidHi);#endif#if ((TRANSA==0)||(TRANSB==1)) unsigned idxHiLo = IDXAA(tidHi,tidLo);#endif#if (USE_MIXED_STEPPER == 1) for (i = IMUL(blockIdx.y, TILE_DIM); i < parms.m; i += SUP_TILE_DIM) { unsigned ii_1 = i + tidLo;#undef ii_2#define ii_2 (i + tidHi) /* could be induction variable if enough registers */ for (j = IMUL(blockIdx.x, TILE_DIM); j < parms.n; j += SUP_TILE_DIM) {#undef jj_2#if ((TRANSB==0)&&(TRANSA==1)) unsigned jj_2 = j + tidHi;#else#define jj_2 (j + tidHi) /* could be induction variable if enough registers */#endif#else /* USE_MIXED_STEPPER==1 */ { { i = IMUL(blockIdx.y, TILE_DIM); j = IMUL(blockIdx.x, TILE_DIM); unsigned ii_1 = i + tidLo;#undef ii_2#if ((TRANSB==0)&&(TRANSA==1)) unsigned ii_2 = i + tidHi;#else#define ii_2 (i + tidHi) /* could be induction variable if enough registers */#endif#undef jj_2 unsigned jj_2 = j + tidHi;#endif /* USE_MIXED_STEPPER==1 */#undef jj_1#if ((TRANSB==1)&&(TRANSA==0)) unsigned jj_1 = j + tidLo;#else#define jj_1 (j + tidLo) /* could be induction variable if enough registers */#endif /* set accumulation to 0*/#if (C_ELEMS_PER_THREAD >= 5)#error C_ELEMS_PER_THREAD >= 5 not supported#else#if (C_ELEMS_PER_THREAD >= 1) dp0 = 0.0f;#if (C_ELEMS_PER_THREAD >= 2) dp1 = 0.0f;#if (C_ELEMS_PER_THREAD >= 3) dp2 = 0.0f;#if (C_ELEMS_PER_THREAD >= 4) dp3 = 0.0f;#endif /* C_ELEMS_PER_THREAD >= 4 */#endif /* C_ELEMS_PER_THREAD >= 3 */#endif /* C_ELEMS_PER_THREAD >= 2 */#endif /* C_ELEMS_PER_THREAD >= 1 */#endif /* C_ELEMS_PER_THREAD >= 5 */ for (l = 0; l < parms.k; l += TILE_DIM) { unsigned int llLimit = min ((l + TILE_DIM), parms.k);#if ((A_ELEMS_PER_THREAD >= 5)||(B_ELEMS_PER_THREAD >= 5)) unsigned int offs2;#endif#undef ll_1#undef ll_2#if ((TRANSA==1)||(TRANSB==0))#define ll_1 (l + tidLo) /* could be induction variable if enough registers */#endif#if ((TRANSA==0)||(TRANSB==1))#define ll_2 (l + tidHi) /* could be induction variable if enough registers */#endif /* Wait before clobbering old cache contents */ __syncthreads (); #if (TRANSA==0)#if (A_ELEMS_PER_THREAD >= 5) ii = ii_1; IF (ii < parms.m) THEN offs2 = tidHi; for (ll = ll_2; ll < llLimit; ll += COL_INCR) { AA[IDXAA(offs2,tidLo)] = fetchA(IDXA(ii,ll)); offs2 += COL_INCR; } ENDIF#else /* A_ELEMS_PER_THREAD >= 5 */#if (A_ELEMS_PER_THREAD >= 1) ll = ll_2; IF ((ii_1 < parms.m) && (ll < llLimit)) THEN unsigned int idxAA; unsigned int addrA; idxAA = idxHiLo; addrA = IDXA(ii_1,ll); AA[idxAA] = fetchA(addrA);#if (A_ELEMS_PER_THREAD >= 2) ll += COL_INCR; idxAA += COL_INCR; addrA += COL_INCR * A_COL_OFS; IF (ll < llLimit) THEN AA[idxAA] = fetchA(addrA);#if (A_ELEMS_PER_THREAD >= 3) ll += COL_INCR; idxAA += COL_INCR; addrA += COL_INCR * A_COL_OFS; IF (ll < llLimit) THEN AA[idxAA] = fetchA(addrA);#if (A_ELEMS_PER_THREAD >= 4) ll += COL_INCR; idxAA += COL_INCR; addrA += COL_INCR * A_COL_OFS; IF (ll < llLimit) THEN AA[idxAA] = fetchA(addrA); ENDIF#endif /* A_ELEMS_PER_THREAD >= 4 */ ENDIF#endif /* A_ELEMS_PER_THREAD >= 3 */ ENDIF#endif /* A_ELEMS_PER_THREAD >= 2 */ ENDIF#endif /* A_ELEMS_PER_THREAD >= 1 */#endif /* A_ELEMS_PER_THREAD >= 5 */#else /* TRANSA = 0 */#if (A_ELEMS_PER_THREAD >= 5) ll = ll_1; IF (ll < llLimit) THEN unsigned int iiLimit = min (i + TILE_DIM, parms.m); offs2 = tidHi; for (ii = ii_2; ii < iiLimit; ii += COL_INCR) { AA[IDXAA(tidLo,offs2)] = fetchA(IDXA(ll,ii)); offs2 += COL_INCR; } ENDIF#else /* A_ELEMS_PER_THREAD >= 5 */#if (A_ELEMS_PER_THREAD >= 1) ii = ii_2; IF ((ll_1 < llLimit) && (ii < parms.m)) THEN unsigned int idxAA; unsigned int addrA; idxAA = idxLoHi; addrA = IDXA(ll_1,ii); AA[idxAA] = fetchA(addrA);#if (A_ELEMS_PER_THREAD >= 2) ii += COL_INCR; idxAA += COL_INCR * AA_COL_OFS; addrA += COL_INCR * A_COL_OFS; IF (ii < parms.m) THEN AA[idxAA] = fetchA(addrA);#if (A_ELEMS_PER_THREAD >= 3) ii += COL_INCR; idxAA += COL_INCR * AA_COL_OFS; addrA += COL_INCR * A_COL_OFS; IF (ii < parms.m) THEN AA[idxAA] = fetchA(addrA);#if (A_ELEMS_PER_THREAD >= 4) ii += COL_INCR; idxAA += COL_INCR * AA_COL_OFS; addrA += COL_INCR * A_COL_OFS; IF (ii < parms.m) THEN AA[idxAA] = fetchA(addrA); ENDIF#endif /* A_ELEMS_PER_THREAD >= 4 */ ENDIF#endif /* A_ELEMS_PER_THREAD >= 3 */ ENDIF#endif /* A_ELEMS_PER_THREAD >= 2 */ ENDIF#endif /* A_ELEMS_PER_THREAD >= 1 */#endif /* A_ELEMS_PER_THREAD >= 5 */#endif /* TRANSA = 0 */ #if (TRANSB==0)#if (B_ELEMS_PER_THREAD >= 5) ll = ll_1; IF (ll < llLimit) THEN unsigned int jjLimit = min (j + TILE_DIM, parms.n); offs2 = tidHi; for (jj = jj_2; jj < jjLimit; jj += COL_INCR) { BB[IDXBB(tidLo,offs2)] = fetchB(IDXB(ll,jj)); offs2 += COL_INCR; } ENDIF#else /* B_ELEMS_PER_THREAD >= 5 */#if (B_ELEMS_PER_THREAD >= 1) jj = jj_2; IF ((ll_1 < llLimit) && (jj < parms.n)) THEN unsigned int idxBB; unsigned int addrB; idxBB = idxLoHi; addrB = IDXB(ll_1,jj); BB[idxBB] = fetchB(addrB);#if (B_ELEMS_PER_THREAD >= 2) jj += COL_INCR; idxBB += COL_INCR * BB_COL_OFS; addrB += COL_INCR * B_COL_OFS; IF (jj < parms.n) THEN BB[idxBB] = fetchB(addrB);#if (B_ELEMS_PER_THREAD >= 3) jj += COL_INCR; idxBB += COL_INCR * BB_COL_OFS; addrB += COL_INCR * B_COL_OFS; IF (jj < parms.n) THEN BB[idxBB] = fetchB(addrB);#if (B_ELEMS_PER_THREAD >= 4) jj += COL_INCR; idxBB += COL_INCR * BB_COL_OFS; addrB += COL_INCR * B_COL_OFS; IF (jj < parms.n) THEN BB[idxBB] = fetchB(addrB); ENDIF#endif /* B_ELEMS_PER_THREAD >= 4 */ ENDIF #endif /* B_ELEMS_PER_THREAD >= 3 */ ENDIF#endif /* B_ELEMS_PER_THREAD >= 2 */ ENDIF#endif /* B_ELEMS_PER_THREAD >= 1 */#endif /* B_ELEMS_PER_THREAD >= 5 */#else /* TRANSB==0 */#if (B_ELEMS_PER_THREAD >= 5) jj = jj_1; IF (jj < parms.n) THEN offs2 = tidHi; for (ll = ll_2; ll < llLimit; ll += COL_INCR) { BB[IDXBB(offs2,tidLo)] = fetchB(IDXB(jj,ll)); offs2 += COL_INCR; } ENDIF#else /* B_ELEMS_PER_THREAD >= 5 */#if (B_ELEMS_PER_THREAD >= 1) ll = ll_2; IF ((jj_1 < parms.n) && (ll < llLimit)) THEN unsigned int idxBB; unsigned int addrB; idxBB = idxHiLo; addrB = IDXB(jj_1,ll); BB[idxBB] = fetchB(addrB);#if (B_ELEMS_PER_THREAD >= 2) ll += COL_INCR; idxBB += COL_INCR; addrB += COL_INCR * B_COL_OFS; IF (ll < llLimit) THEN BB[idxBB] = fetchB(addrB);#if (B_ELEMS_PER_THREAD >= 3) ll += COL_INCR; idxBB += COL_INCR; addrB += COL_INCR * B_COL_OFS; IF (ll < llLimit) THEN BB[idxBB] = fetchB(addrB);#if (B_ELEMS_PER_THREAD >= 4) ll += COL_INCR; idxBB += COL_INCR; addrB += COL_INCR * B_COL_OFS; IF (ll < llLimit) THEN BB[idxBB] = fetchB(addrB); ENDIF#endif /* B_ELEMS_PER_THREAD >= 4 */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -