⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 sgemm.h

📁 NVIDIA提供的CUDA的BLAS库源码
💻 H
📖 第 1 页 / 共 3 页
字号:
    float dp2;#if (C_ELEMS_PER_THREAD >= 4)    float dp3;#if (C_ELEMS_PER_THREAD >= 5)#error C_ELEMS_PER_THREAD >= 5 not supported#endif /* (C_ELEMS_PER_THREAD >= 5) */#endif /* (C_ELEMS_PER_THREAD >= 4) */#endif /* (C_ELEMS_PER_THREAD >= 3) */#endif /* (C_ELEMS_PER_THREAD >= 2) */#endif /* (C_ELEMS_PER_THREAD >= 1) */    unsigned tidLo = (tid & (TILE_DIM - 1));    unsigned tidHi = (tid >> TILE_DIM_LOG);  #if ((TRANSA==1)||(TRANSB==0))    unsigned idxLoHi = IDXAA(tidLo,tidHi);#endif#if ((TRANSA==0)||(TRANSB==1))    unsigned idxHiLo = IDXAA(tidHi,tidLo);#endif#if (USE_MIXED_STEPPER == 1)    for (i = IMUL(blockIdx.y, TILE_DIM); i < parms.m; i += SUP_TILE_DIM) {        unsigned ii_1 = i + tidLo;#undef  ii_2#define ii_2 (i + tidHi)  /* could be induction variable if enough registers */        for (j = IMUL(blockIdx.x, TILE_DIM); j < parms.n; j += SUP_TILE_DIM) {#undef  jj_2#if ((TRANSB==0)&&(TRANSA==1))            unsigned jj_2 = j + tidHi;#else#define jj_2 (j + tidHi)  /* could be induction variable if enough registers */#endif#else /* USE_MIXED_STEPPER==1 */    {        {            i = IMUL(blockIdx.y, TILE_DIM);            j = IMUL(blockIdx.x, TILE_DIM);            unsigned ii_1 = i + tidLo;#undef  ii_2#if ((TRANSB==0)&&(TRANSA==1))            unsigned ii_2 = i + tidHi;#else#define ii_2 (i + tidHi)  /* could be induction variable if enough registers */#endif#undef  jj_2            unsigned jj_2 = j + tidHi;#endif /* USE_MIXED_STEPPER==1 */#undef  jj_1#if ((TRANSB==1)&&(TRANSA==0))            unsigned jj_1 = j + tidLo;#else#define jj_1 (j + tidLo)  /* could be induction variable if enough registers */#endif            /* set accumulation to 0*/#if (C_ELEMS_PER_THREAD >= 5)#error C_ELEMS_PER_THREAD >= 5 not supported#else#if (C_ELEMS_PER_THREAD >= 1)            dp0 = 0.0f;#if (C_ELEMS_PER_THREAD >= 2)            dp1 = 0.0f;#if (C_ELEMS_PER_THREAD >= 3)            dp2 = 0.0f;#if (C_ELEMS_PER_THREAD >= 4)            dp3 = 0.0f;#endif  /* C_ELEMS_PER_THREAD >= 4 
*/#endif  /* C_ELEMS_PER_THREAD >= 3 */#endif  /* C_ELEMS_PER_THREAD >= 2 */#endif  /* C_ELEMS_PER_THREAD >= 1 */#endif  /* C_ELEMS_PER_THREAD >= 5 */            for (l = 0; l < parms.k; l += TILE_DIM) {                unsigned int llLimit = min ((l + TILE_DIM), parms.k);#if ((A_ELEMS_PER_THREAD >= 5)||(B_ELEMS_PER_THREAD >= 5))                unsigned int offs2;#endif#undef ll_1#undef ll_2#if ((TRANSA==1)||(TRANSB==0))#define ll_1 (l + tidLo) /* could be induction variable if enough registers */#endif#if ((TRANSA==0)||(TRANSB==1))#define ll_2 (l + tidHi) /* could be induction variable if enough registers */#endif                              /* Wait before clobbering old cache contents */                __syncthreads ();                #if (TRANSA==0)#if (A_ELEMS_PER_THREAD >= 5)                ii = ii_1;                IF (ii < parms.m) THEN                    offs2 = tidHi;                    for (ll = ll_2; ll < llLimit; ll += COL_INCR) {                        AA[IDXAA(offs2,tidLo)] = fetchA(IDXA(ii,ll));                        offs2 += COL_INCR;                    }                ENDIF#else  /* A_ELEMS_PER_THREAD >= 5 */#if (A_ELEMS_PER_THREAD >= 1)                ll = ll_2;                IF ((ii_1 < parms.m) && (ll < llLimit)) THEN                    unsigned int idxAA;                    unsigned int addrA;                    idxAA = idxHiLo;                    addrA = IDXA(ii_1,ll);                    AA[idxAA] = fetchA(addrA);#if (A_ELEMS_PER_THREAD >= 2)                    ll += COL_INCR;                    idxAA += COL_INCR;                    addrA += COL_INCR * A_COL_OFS;                    IF (ll < llLimit) THEN                        AA[idxAA] = fetchA(addrA);#if (A_ELEMS_PER_THREAD >= 3)                        ll += COL_INCR;                        idxAA += COL_INCR;                        addrA += COL_INCR * A_COL_OFS;                        IF (ll < llLimit) THEN                            AA[idxAA] = fetchA(addrA);#if (A_ELEMS_PER_THREAD >= 4) 
                                                          ll += COL_INCR;                            idxAA += COL_INCR;                            addrA += COL_INCR * A_COL_OFS;                            IF (ll < llLimit) THEN                                AA[idxAA] = fetchA(addrA);                            ENDIF#endif  /* A_ELEMS_PER_THREAD >= 4 */                        ENDIF#endif  /* A_ELEMS_PER_THREAD >= 3 */                    ENDIF#endif  /* A_ELEMS_PER_THREAD >= 2 */                ENDIF#endif  /* A_ELEMS_PER_THREAD >= 1 */#endif  /* A_ELEMS_PER_THREAD >= 5 */#else   /* TRANSA = 0 */#if (A_ELEMS_PER_THREAD >= 5)                ll = ll_1;                IF (ll < llLimit) THEN                    unsigned int iiLimit = min (i + TILE_DIM, parms.m);                    offs2 = tidHi;                    for (ii = ii_2; ii < iiLimit; ii += COL_INCR) {                        AA[IDXAA(tidLo,offs2)] = fetchA(IDXA(ll,ii));                        offs2 += COL_INCR;                    }                ENDIF#else  /* A_ELEMS_PER_THREAD >= 5 */#if (A_ELEMS_PER_THREAD >= 1)                ii = ii_2;                IF ((ll_1 < llLimit) && (ii < parms.m)) THEN                    unsigned int idxAA;                    unsigned int addrA;                    idxAA = idxLoHi;                    addrA = IDXA(ll_1,ii);                    AA[idxAA] = fetchA(addrA);#if (A_ELEMS_PER_THREAD >= 2)                    ii += COL_INCR;                    idxAA += COL_INCR * AA_COL_OFS;                    addrA += COL_INCR * A_COL_OFS;                    IF (ii < parms.m) THEN                        AA[idxAA] = fetchA(addrA);#if (A_ELEMS_PER_THREAD >= 3)                        ii += COL_INCR;                        idxAA += COL_INCR * AA_COL_OFS;                        addrA += COL_INCR * A_COL_OFS;                                             IF (ii < parms.m) THEN                            AA[idxAA] = fetchA(addrA);#if (A_ELEMS_PER_THREAD >= 4)                                           
                ii += COL_INCR;                            idxAA += COL_INCR * AA_COL_OFS;                            addrA += COL_INCR * A_COL_OFS;                            IF (ii < parms.m) THEN                                AA[idxAA] = fetchA(addrA);                            ENDIF#endif  /* A_ELEMS_PER_THREAD >= 4 */                        ENDIF#endif  /* A_ELEMS_PER_THREAD >= 3 */                    ENDIF#endif  /* A_ELEMS_PER_THREAD >= 2 */                ENDIF#endif  /* A_ELEMS_PER_THREAD >= 1 */#endif  /* A_ELEMS_PER_THREAD >= 5 */#endif  /* TRANSA = 0 */                #if (TRANSB==0)#if (B_ELEMS_PER_THREAD >= 5)                ll = ll_1;                IF (ll < llLimit) THEN                    unsigned int jjLimit = min (j + TILE_DIM, parms.n);                    offs2 = tidHi;                    for (jj = jj_2; jj < jjLimit; jj += COL_INCR) {                        BB[IDXBB(tidLo,offs2)] = fetchB(IDXB(ll,jj));                        offs2 += COL_INCR;                    }                ENDIF#else /* B_ELEMS_PER_THREAD >= 5 */#if (B_ELEMS_PER_THREAD >= 1)                jj = jj_2;                IF ((ll_1 < llLimit) && (jj < parms.n)) THEN                    unsigned int idxBB;                    unsigned int addrB;                    idxBB = idxLoHi;                    addrB = IDXB(ll_1,jj);                    BB[idxBB] = fetchB(addrB);#if (B_ELEMS_PER_THREAD >= 2)                    jj += COL_INCR;                    idxBB += COL_INCR * BB_COL_OFS;                    addrB += COL_INCR * B_COL_OFS;                    IF (jj < parms.n) THEN                        BB[idxBB] = fetchB(addrB);#if (B_ELEMS_PER_THREAD >= 3)                        jj += COL_INCR;                        idxBB += COL_INCR * BB_COL_OFS;                        addrB += COL_INCR * B_COL_OFS;                         IF (jj < parms.n) THEN                            BB[idxBB] = fetchB(addrB);#if (B_ELEMS_PER_THREAD >= 4)                            jj += COL_INCR;                   
         idxBB += COL_INCR * BB_COL_OFS;                            addrB += COL_INCR * B_COL_OFS;                             IF (jj < parms.n) THEN                                BB[idxBB] = fetchB(addrB);                            ENDIF#endif  /* B_ELEMS_PER_THREAD >= 4 */                        ENDIF                                    #endif  /* B_ELEMS_PER_THREAD >= 3 */                    ENDIF#endif  /* B_ELEMS_PER_THREAD >= 2 */                ENDIF#endif  /* B_ELEMS_PER_THREAD >= 1 */#endif  /* B_ELEMS_PER_THREAD >= 5 */#else   /* TRANSB==0 */#if (B_ELEMS_PER_THREAD >= 5)                jj = jj_1;                IF (jj < parms.n) THEN                    offs2 = tidHi;                    for (ll = ll_2; ll < llLimit; ll += COL_INCR) {                        BB[IDXBB(offs2,tidLo)] = fetchB(IDXB(jj,ll));                        offs2 += COL_INCR;                    }                ENDIF#else  /* B_ELEMS_PER_THREAD >= 5 */#if (B_ELEMS_PER_THREAD >= 1)                ll = ll_2;                IF ((jj_1 < parms.n) && (ll < llLimit)) THEN                    unsigned int idxBB;                    unsigned int addrB;                    idxBB = idxHiLo;                    addrB = IDXB(jj_1,ll);                    BB[idxBB] = fetchB(addrB);#if (B_ELEMS_PER_THREAD >= 2)                    ll += COL_INCR;                    idxBB += COL_INCR;                    addrB += COL_INCR * B_COL_OFS;                    IF (ll < llLimit) THEN                        BB[idxBB] = fetchB(addrB);#if (B_ELEMS_PER_THREAD >= 3)                        ll += COL_INCR;                        idxBB += COL_INCR;                        addrB += COL_INCR * B_COL_OFS;                        IF (ll < llLimit) THEN                            BB[idxBB] = fetchB(addrB);#if (B_ELEMS_PER_THREAD >= 4)                            ll += COL_INCR;                            idxBB += COL_INCR;                            addrB += COL_INCR * B_COL_OFS;                            IF (ll < llLimit) THEN       
                         BB[idxBB] = fetchB(addrB);                            ENDIF#endif  /* B_ELEMS_PER_THREAD >= 4 */

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -