📄 sgemm.h
字号:
ENDIF #endif /* B_ELEMS_PER_THREAD >= 3 */ ENDIF#endif /* B_ELEMS_PER_THREAD >= 2 */ ENDIF#endif /* B_ELEMS_PER_THREAD >= 1 */#endif /* B_ELEMS_PER_THREAD >= 5 */#endif /* TRANSB0==0 */ /* Wait until new cache contents ready */ __syncthreads (); /* We don't iterate over jj and ii since this is all done * in parallel by the threads in each CTA. */ ii = tidLo; IF (ii < (parms.m - i)) THEN unsigned int z = llLimit - l; jj = tidHi; IF (z == TILE_DIM) THEN#if (C_ELEMS_PER_THREAD == 1) IF (jj < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ACCUMULATE_DOT_PRODUCT_TILE(0); ENDIF#endif#if (C_ELEMS_PER_THREAD == 2) IF ((jj + COL_INCR) < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ACCUMULATE_2DOT_PRODUCTS_TILE(0,1,BB_COL_OFS*COL_INCR); ELSEIF (jj < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ACCUMULATE_DOT_PRODUCT_TILE(0); ENDIF#endif#if (C_ELEMS_PER_THREAD == 3) IF ((jj + 2*COL_INCR) < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ACCUMULATE_2DOT_PRODUCTS_TILE(0,1,BB_COL_OFS*COL_INCR); lj += 2 * BB_COL_OFS * COL_INCR; ACCUMULATE_DOT_PRODUCT_TILE(2); ELSEIF ((jj + COL_INCR) < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); CCUMULATE_2DOT_PRODUCTS_TILE(0,1,BB_COL_OFS*COL_INCR); ELSEIF (jj < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ACCUMULATE_DOT_PRODUCT_TILE(0); ENDIF#endif#if (C_ELEMS_PER_THREAD >= 4) IF ((jj + 3*COL_INCR) < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ACCUMULATE_2DOT_PRODUCTS_TILE(0,1,BB_COL_OFS*COL_INCR); lj += 2*BB_COL_OFS * COL_INCR; ACCUMULATE_2DOT_PRODUCTS_TILE(2,3,BB_COL_OFS*COL_INCR); ELSEIF ((jj + 2*COL_INCR) < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ACCUMULATE_2DOT_PRODUCTS_TILE(0,1,BB_COL_OFS*COL_INCR); lj += 2*BB_COL_OFS * COL_INCR; ACCUMULATE_DOT_PRODUCT_TILE(2); ELSEIF ((jj + COL_INCR) < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ACCUMULATE_2DOT_PRODUCTS_TILE(0,1,BB_COL_OFS*COL_INCR); ELSEIF (jj < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ACCUMULATE_DOT_PRODUCT_TILE(0); ENDIF#endif#if (C_ELEMS_PER_THREAD >= 5)#error C_ELEMS_PER_THREAD >= 5 no supported#endif ELSE #if (C_ELEMS_PER_THREAD == 1) IF (jj < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(0); ENDIF#endif#if (C_ELEMS_PER_THREAD == 2) IF ((jj + COL_INCR) < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ll = z; ACCUMULATE_2DOT_PRODUCTS_N(0,1,BB_COL_OFS*COL_INCR); ELSEIF (jj < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(0); ENDIF#endif#if (C_ELEMS_PER_THREAD == 3) IF ((jj + 2*COL_INCR) < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(0); jj += COL_INCR; li = IDXAA(0,ii); lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(1); jj += COL_INCR; li = IDXAA(0,ii); lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(2); ELSEIF ((jj + COL_INCR) < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(0); li = IDXAA(0,ii); jj += COL_INCR; lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(1); ELSEIF (jj < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(0); ENDIF#endif #if (C_ELEMS_PER_THREAD == 4) IF ((jj + 3*COL_INCR) < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(0); jj += COL_INCR; li = IDXAA(0,ii); lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(1); jj += COL_INCR; li = IDXAA(0,ii); lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(2); jj += COL_INCR; li = IDXAA(0,ii); lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(3); ELSEIF ((jj + 2*COL_INCR) < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(0); jj += COL_INCR; li = IDXAA(0,ii); lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(1); jj += COL_INCR; li = IDXAA(0,ii); lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(2); ELSEIF ((jj + COL_INCR) < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(0); li = IDXAA(0,ii); jj += COL_INCR; lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(1); ELSEIF (jj < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(0); ENDIF#endif#if (C_ELEMS_PER_THREAD >= 5)#error C_ELEMS_PER_THREAD >= 5 no supported#endif ENDIF ENDIF } /* write out completed tile of matrix C */ if (parms.beta == 0.0f) {#if (C_ELEMS_PER_THREAD >= 5) /* we would need an array dp[] instead of scalar dp0, .. */#error C_ELEMS_PER_THREAD >= 5 no supported#else#if (C_ELEMS_PER_THREAD >= 1) ii = ii_1; jj = jj_2; IF ((ii < parms.m) && (jj < parms.n)) THEN unsigned int addrC = IDXC(ii,jj); parms.C[addrC] = parms.alpha * dp0;#if (C_ELEMS_PER_THREAD >= 2) jj += COL_INCR; IF (jj < parms.n) THEN addrC += COL_INCR * C_COL_OFS; parms.C[addrC] = parms.alpha * dp1;#if (C_ELEMS_PER_THREAD >= 3) jj += COL_INCR; IF (jj < parms.n) THEN addrC += COL_INCR * C_COL_OFS; parms.C[addrC] = parms.alpha * dp2;#if (C_ELEMS_PER_THREAD >= 4) jj += COL_INCR; IF (jj < parms.n) THEN addrC += COL_INCR * C_COL_OFS; parms.C[addrC] = parms.alpha * dp3; ENDIF#endif /* C_ELEMS_PER_THREAD >= 4 */ ENDIF#endif /* C_ELEMS_PER_THREAD >= 3 */ ENDIF#endif /* C_ELEMS_PER_THREAD >= 2 */ ENDIF#endif /* C_ELEMS_PER_THREAD >= 1 */#endif } else {#if (C_ELEMS_PER_THREAD >= 5) /* we would need an array dp[] instead of scalar dp0, .. */#error C_ELEMS_PER_THREAD >= 5 no supported#else#if (C_ELEMS_PER_THREAD >= 1) ii = ii_1; jj = jj_2; IF ((ii < parms.m) && (jj < parms.n)) THEN unsigned int addrC = IDXC(ii,jj); parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp0;#if (C_ELEMS_PER_THREAD >= 2) jj += COL_INCR; IF (jj < parms.n) THEN addrC += COL_INCR * C_COL_OFS; parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp1;#if (C_ELEMS_PER_THREAD >= 3) jj += COL_INCR; IF (jj < parms.n) THEN addrC += COL_INCR * C_COL_OFS; parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp2;#if (C_ELEMS_PER_THREAD >= 4) jj += COL_INCR; IF (jj < parms.n) THEN addrC += COL_INCR * C_COL_OFS; parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp3; ENDIF#endif /* C_ELEMS_PER_THREAD >= 4 */ ENDIF#endif /* C_ELEMS_PER_THREAD >= 3 */ ENDIF#endif /* C_ELEMS_PER_THREAD >= 2 */ ENDIF#endif /* C_ELEMS_PER_THREAD >= 1 */#endif } } }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -