📄 ssymm.h
字号:
i = IMUL(blockIdx.y,TILE_DIM); j = IMUL(blockIdx.x,TILE_DIM);#endif /* set accumulation to 0*/#if (C_ELEMS_PER_THREAD >= 5)#error C_ELEMS_PER_THREAD >= 5 not supported#else#if (C_ELEMS_PER_THREAD >= 1) dp0 = 0.0f;#if (C_ELEMS_PER_THREAD >= 2) dp1 = 0.0f;#if (C_ELEMS_PER_THREAD >= 3) dp2 = 0.0f;#if (C_ELEMS_PER_THREAD >= 4) dp3 = 0.0f;#endif /* C_ELEMS_PER_THREAD >= 4 */#endif /* C_ELEMS_PER_THREAD >= 3 */#endif /* C_ELEMS_PER_THREAD >= 2 */#endif /* C_ELEMS_PER_THREAD >= 1 */#endif /* C_ELEMS_PER_THREAD >= 5 */ for (l = 0; l < parms.k; l += TILE_DIM) { unsigned int llLimit = min ((l + TILE_DIM), parms.k); unsigned int z = llLimit - l; unsigned int offs2; /* Wait before clobbering old cache contents */ __syncthreads (); ii = i + tidLo; IF (ii < parms.m) THEN offs2 = tidHi;#if (LSIDE==1)#if (UPPER==1) for (ll = l+offs2; ll<llLimit; ll += COL_INCR) { unsigned int addr = (ii<=ll)?IDXA(ii,ll):IDXA(ll,ii); AA[IDXAA(offs2,tidLo)] = parms.A[addr]; offs2 += COL_INCR; }#else for (ll = l+offs2; ll<llLimit; ll += COL_INCR) { unsigned int addr = (ii>=ll)?IDXA(ii,ll):IDXA(ll,ii); AA[IDXAA(offs2,tidLo)] = parms.A[addr]; offs2 += COL_INCR; }#endif#else /* LSIDE==1 */ for (ll = l + offs2; ll < llLimit; ll += COL_INCR) { AA[IDXAA(offs2,tidLo)] = parms.A[IDXA(ii,ll)]; offs2 += COL_INCR; }#endif /* LSIDE==1 */ ENDIF ll = l + tidLo; IF (ll < llLimit) THEN unsigned int jjLimit = min (j + TILE_DIM, parms.n); offs2 = tidHi;#if (LSIDE==1) for (jj = j + offs2; jj < jjLimit; jj += COL_INCR) { BB[IDXBB(tidLo, offs2)] = parms.B[IDXB(ll,jj)]; offs2 += COL_INCR; }#else#if (UPPER==1) for (jj = j + offs2; jj < jjLimit; jj += COL_INCR) { unsigned int addr = (ll<=jj)?IDXB(ll,jj):IDXB(jj,ll); BB[IDXBB(tidLo, offs2)] = parms.B[addr]; offs2 += COL_INCR; }#else for (jj = j + offs2; jj < jjLimit; jj += COL_INCR) { unsigned int addr = (jj<=ll)?IDXB(ll,jj):IDXB(jj,ll); BB[IDXBB(tidLo, offs2)] = parms.B[addr]; offs2 += COL_INCR; }#endif#endif /* LSIDE==1 */ ENDIF /* Wait until new cache contents ready */ __syncthreads (); /* We don't iterate over jj and ii since this is al done * in paralel by the threads in each CTA. */ ii = tidLo; IF (ii < (parms.m - i)) THEN jj = tidHi; IF (z == 32) THEN#if (C_ELEMS_PER_THREAD == 1) IF (jj < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ACCUMULATE_DOT_PRODUCT_32(0); ENDIF#endif#if (C_ELEMS_PER_THREAD == 2) IF ((jj + COL_INCR) < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ACCUMULATE_2DOT_PRODUCTS_32(0,1,BB_COL_OFS*COL_INCR); ELSEIF (jj < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ACCUMULATE_DOT_PRODUCT_32(0); ENDIF#endif#if (C_ELEMS_PER_THREAD >= 3)#error C_ELEMS_PER_THREAD >= 3 no supported#endif ELSE #if (C_ELEMS_PER_THREAD == 1) IF (jj < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(0); ENDIF#endif#if (C_ELEMS_PER_THREAD == 2) IF ((jj + COL_INCR) < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ll = z; ACCUMULATE_2DOT_PRODUCTS_N(0,1,BB_COL_OFS*COL_INCR); ELSEIF (jj < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(0); ENDIF#endif#if (C_ELEMS_PER_THREAD >= 3)#error C_ELEMS_PER_THREAD >= 3 no supported#endif ENDIF ENDIF } /* write out completed tile of matrix C */ if (parms.beta == 0.0f) {#if (C_ELEMS_PER_THREAD >= 5) /* we would need an array dp[] instead of scalar dp0, .. */#error C_ELEMS_PER_THREAD >= 5 no supported#else#if (C_ELEMS_PER_THREAD >= 1) ii = i + tidLo; IF (ii < parms.m) THEN jj = j + tidHi; IF (jj < parms.n) THEN unsigned int addrC = IDXC(ii,jj); parms.C[addrC] = parms.alpha * dp0;#if (C_ELEMS_PER_THREAD >= 2) jj += COL_INCR; IF (jj < parms.n) THEN addrC += COL_INCR * C_COL_OFS; parms.C[addrC] = parms.alpha * dp1;#if (C_ELEMS_PER_THREAD >= 3) jj += COL_INCR; IF (jj < parms.n) THEN addrC += COL_INCR * C_COL_OFS; parms.C[addrC] = parms.alpha * dp2;#if (C_ELEMS_PER_THREAD >= 4) jj += COL_INCR; IF (jj < parms.n) THEN addrC += COL_INCR * C_COL_OFS; parms.C[addrC] = parms.alpha * dp3; ENDIF#endif /* C_ELEMS_PER_THREAD >= 4 */ ENDIF#endif /* C_ELEMS_PER_THREAD >= 3 */ ENDIF#endif /* C_ELEMS_PER_THREAD >= 2 */ ENDIF ENDIF#endif /* C_ELEMS_PER_THREAD >= 1 */#endif } else {#if (C_ELEMS_PER_THREAD >= 5) /* we would need an array dp[] instead of scalar dp0, .. */#error C_ELEMS_PER_THREAD >= 5 no supported#else#if (C_ELEMS_PER_THREAD >= 1) ii = i + tidLo; IF (ii < parms.m) THEN jj = j + tidHi; IF (jj < parms.n) THEN unsigned int addrC = IDXC(ii,jj); parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp0;#if (C_ELEMS_PER_THREAD >= 2) jj += COL_INCR; IF (jj < parms.n) THEN addrC += COL_INCR * C_COL_OFS; parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp1;#if (C_ELEMS_PER_THREAD >= 3) jj += COL_INCR; IF (jj < parms.n) THEN addrC += COL_INCR * C_COL_OFS; parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp2;#if (C_ELEMS_PER_THREAD >= 4) jj += COL_INCR; IF (jj < parms.n) THEN addrC += COL_INCR * C_COL_OFS; parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp3; ENDIF#endif /* C_ELEMS_PER_THREAD >= 4 */ ENDIF#endif /* C_ELEMS_PER_THREAD >= 3 */ ENDIF#endif /* C_ELEMS_PER_THREAD >= 2 */ ENDIF ENDIF#endif /* C_ELEMS_PER_THREAD >= 1 */#endif } } }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -