📄 ssyrk.h
字号:
#if (TRANSA==0)#if (A_ELEMS_PER_THREAD >= 3) ii = i + tidLo; IF (ii < parms.n) THEN offs2 = tidHi; for (ll = l + offs2; ll < llLimit; ll += COL_INCR) { AA[IDXAA(offs2,tidLo)] = parms.A[IDXA(ii,ll)]); offs2 += COL_INCR; } ENDIF#else /* A_ELEMS_PER_THREAD >= 3 */#if (A_ELEMS_PER_THREAD >= 1) ii = i + tidLo; IF (ii < parms.n) THEN ll = l + tidHi; IF (ll < llLimit) THEN unsigned int idxAA; unsigned int addrA; idxAA = IDXAA(tidHi,tidLo); addrA = IDXA(ii,ll); AA[idxAA] = parms.A[addrA];#if (A_ELEMS_PER_THREAD >= 2) ll += COL_INCR; IF (ll < llLimit) THEN idxAA += COL_INCR; addrA += COL_INCR * A_COL_OFS; AA[idxAA] = parms.A[addrA]; ENDIF#endif /* A_ELEMS_PER_THREAD >= 2 */ ENDIF ENDIF#endif /* A_ELEMS_PER_THREAD >= 1 */#endif /* A_ELEMS_PER_THREAD >= 3 */#else /* TRANSA = 0 */#if (A_ELEMS_PER_THREAD >= 3) ll = l + tidLo; IF (ll < llLimit) THEN unsigned int iiLimit = min (i + TILE_DIM, parms.n); offs2 = tidHi; for (ii = i + offs2; ii < iiLimit; ii += COL_INCR) { AA[IDXAA(tidLo, offs2)] = parms.A[IDXA(ll,ii)]; offs2 += COL_INCR; } ENDIF#else /* A_ELEMS_PER_THREAD >= 3 */#if (A_ELEMS_PER_THREAD >= 1) ll = l + tidLo; IF (ll < llLimit) THEN ii = i + tidHi; IF (ii < parms.n) THEN unsigned int idxAA; unsigned int addrA; idxAA = IDXAA(tidLo, tidHi); addrA = IDXA(ll, ii); AA[idxAA] = parms.A[addrA];#if (A_ELEMS_PER_THREAD >= 2) ii += COL_INCR; IF (ii < parms.n) THEN idxAA += COL_INCR * AA_COL_OFS; addrA += COL_INCR * A_COL_OFS; AA[idxAA] = parms.A[addrA]; ENDIF#endif /* A_ELEMS_PER_THREAD >= 2 */ ENDIF ENDIF#endif /* A_ELEMS_PER_THREAD >= 1 */#endif /* A_ELEMS_PER_THREAD >= 3 */#endif /* TRANSA = 0 */ #if (TRANSB==0)#if (B_ELEMS_PER_THREAD >= 3) ll = l + tidLo; IF (ll < llLimit) THEN unsigned int jjLimit = min (j + TILE_DIM, parms.n); offs2 = tidHi; for (jj = j + offs2; jj < jjLimit; jj += COL_INCR) { BB[IDXBB(tidLo, offs2)] = parms.B[IDXB(ll,jj)]; offs2 += COL_INCR; } ENDIF#else /* B_ELEMS_PER_THREAD >= 3 */#if (B_ELEMS_PER_THREAD >= 1) ll = l + tidLo; IF (ll < llLimit) THEN jj = j + tidHi; IF (jj < parms.n) THEN unsigned int idxBB; unsigned int addrB; idxBB = IDXBB(tidLo,tidHi); addrB = IDXB(ll,jj); BB[idxBB] = parms.B[addrB];#if (B_ELEMS_PER_THREAD >= 2) jj += COL_INCR; IF (jj < parms.n) THEN idxBB += COL_INCR * BB_COL_OFS; addrB += COL_INCR * B_COL_OFS; BB[idxBB] = parms.B[addrB]; ENDIF#endif /* B_ELEMS_PER_THREAD >= 2 */ ENDIF ENDIF#endif /* B_ELEMS_PER_THREAD >= 1 */#endif /* B_ELEMS_PER_THREAD >= 3 */#else /* TRANSB==0 */#if (B_ELEMS_PER_THREAD >= 3) jj = j + tidLo; IF (jj < parms.n) THEN offs2 = tidHi; for (ll = l + offs2; ll < llLimit; ll += COL_INCR) { BB[IDXBB(offs2, tidLo)] = parms.B[IDXB(jj,ll)]; offs2 += COL_INCR; } ENDIF#else /* B_ELEMS_PER_THREAD >= 3 */#if (B_ELEMS_PER_THREAD >= 1) jj = j + tidLo; IF (jj < parms.n) THEN ll = l + tidHi; IF (ll < llLimit) THEN unsigned int idxBB; unsigned int addrB; idxBB = IDXBB(tidHi,tidLo); addrB = IDXB(jj,ll); BB[idxBB] = parms.B[addrB];#if (B_ELEMS_PER_THREAD >= 2) ll += COL_INCR; IF (ll < llLimit) THEN idxBB += COL_INCR; addrB += COL_INCR * B_COL_OFS; BB[idxBB] = parms.B[addrB]; ENDIF#endif /* B_ELEMS_PER_THREAD >= 2 */ ENDIF ENDIF#endif /* B_ELEMS_PER_THREAD >= 1 */#endif /* B_ELEMS_PER_THREAD >= 3 */#endif /* TRANSB0==0 */ /* Wait until all elements of the A-tile and the B-tile have * been read, before any thread starts with the computation of * dot products */ __syncthreads (); /* For each of the result tile elements it needs to compute, a * thread computes the partial dot product by combining the * appropriate row (physically: column, due to transposition of * tile) of A-tile with the appropriate column of the B-tile. * In this case, each thread updates two dot products. Inline * checks prevent computation for result tile elements that do * not correspond to elements inside the result matrix. */ ii = tidLo; IF (ii < (parms.n - i)) THEN unsigned int z = llLimit - l; jj = tidHi; IF (z == 32) THEN#if (C_ELEMS_PER_THREAD == 1) IF (jj < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ACCUMULATE_DOT_PRODUCT_32(0); ENDIF#endif#if (C_ELEMS_PER_THREAD == 2) IF ((jj + COL_INCR) < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ACCUMULATE_2DOT_PRODUCTS_32(0,1,BB_COL_OFS*COL_INCR); ELSEIF (jj < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ACCUMULATE_DOT_PRODUCT_32(0); ENDIF#endif#if (C_ELEMS_PER_THREAD >= 3)#error C_ELEMS_PER_THREAD >= 3 no supported#endif ELSE#if (C_ELEMS_PER_THREAD == 1) IF (jj < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(0); ENDIF#endif#if (C_ELEMS_PER_THREAD == 2) IF ((jj + COL_INCR) < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ll = z; ACCUMULATE_2DOT_PRODUCTS_N(0,1,BB_COL_OFS*COL_INCR); ELSEIF (jj < (parms.n - j)) THEN unsigned int li = IDXAA(0,ii); unsigned int lj = IDXBB(0,jj); ll = z; ACCUMULATE_DOT_PRODUCT_N(0); ENDIF#endif#if (C_ELEMS_PER_THREAD >= 3)#error C_ELEMS_PER_THREAD >= 3 not supported#endif } } } /* At this point each thread has computed the dot product(s) that * represent each element of the result matrix tile (i.e. C-tile) * it is responsible for. If beta is zero, don't read the C-tile, * otherwise read the C-tile to scale it by beta. */ if (parms.beta == 0.0f) { ii = i + tidLo; jj = j + tidHi; IF ((ii < parms.n) && (jj < parms.n)) THEN unsigned int addrC = IDXC(ii,jj);#if (C_ELEMS_PER_THREAD >= 1)#if (UPPER==1) if (ii <= jj) {#else if (ii >= jj) {#endif parms.C[addrC] = parms.alpha * dp0; }#if (C_ELEMS_PER_THREAD >= 2) jj += COL_INCR; IF (jj < parms.n) THEN#if (UPPER==1) if (ii <= jj) {#else if (ii >= jj) {#endif addrC += COL_INCR * C_COL_OFS; parms.C[addrC] = parms.alpha * dp1; }#if (C_ELEMS_PER_THREAD >= 3)#error C_ELEMS_PER_THREAD >= 3 not supported#endif /* C_ELEMS_PER_THREAD >= 3 */ ENDIF#endif /* C_ELEMS_PER_THREAD >= 2 */ ENDIF#endif /* C_ELEMS_PER_THREAD >= 1 */ } else { ii = i + tidLo; jj = j + tidHi; IF ((ii < parms.n) && (jj < parms.n)) THEN unsigned int addrC = IDXC(ii,jj);#if (C_ELEMS_PER_THREAD >= 1)#if (UPPER==1) if (ii <= jj) {#else if (ii >= jj) {#endif parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp0; }#if (C_ELEMS_PER_THREAD >= 2) jj += COL_INCR; IF (jj < parms.n) THEN#if (UPPER==1) if (ii <= jj) {#else if (ii >= jj) {#endif addrC += COL_INCR * C_COL_OFS; parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp1; }#if (C_ELEMS_PER_THREAD >= 3)#error C_ELEMS_PER_THREAD >= 3 not supported#endif /* C_ELEMS_PER_THREAD >= 3 */ ENDIF#endif /* C_ELEMS_PER_THREAD >= 2 */ ENDIF#endif /* C_ELEMS_PER_THREAD >= 1 */ } } }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -