⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ssyrk.h

📁 NVIDIA提供的CUDA的BLAS库源码
💻 H
📖 第 1 页 / 共 2 页
字号:
#if (TRANSA==0)#if (A_ELEMS_PER_THREAD >= 3)                ii = i + tidLo;                IF (ii < parms.n) THEN                    offs2 = tidHi;                    for (ll = l + offs2; ll < llLimit; ll += COL_INCR) {                        AA[IDXAA(offs2,tidLo)] = parms.A[IDXA(ii,ll)]);                        offs2 += COL_INCR;                    }                ENDIF#else  /* A_ELEMS_PER_THREAD >= 3 */#if (A_ELEMS_PER_THREAD >= 1)                ii = i + tidLo;                IF (ii < parms.n) THEN                    ll = l + tidHi;                    IF (ll < llLimit) THEN                        unsigned int idxAA;                        unsigned int addrA;                        idxAA = IDXAA(tidHi,tidLo);                        addrA = IDXA(ii,ll);                        AA[idxAA] = parms.A[addrA];#if (A_ELEMS_PER_THREAD >= 2)                        ll += COL_INCR;                        IF (ll < llLimit) THEN                            idxAA += COL_INCR;                            addrA += COL_INCR * A_COL_OFS;                            AA[idxAA] = parms.A[addrA];                        ENDIF#endif  /* A_ELEMS_PER_THREAD >= 2 */                    ENDIF                ENDIF#endif  /* A_ELEMS_PER_THREAD >= 1 */#endif  /* A_ELEMS_PER_THREAD >= 3 */#else   /* TRANSA = 0 */#if (A_ELEMS_PER_THREAD >= 3)                ll = l + tidLo;                 IF (ll < llLimit) THEN                    unsigned int iiLimit = min (i + TILE_DIM, parms.n);                    offs2 = tidHi;                    for (ii = i + offs2; ii < iiLimit; ii += COL_INCR) {                        AA[IDXAA(tidLo, offs2)] = parms.A[IDXA(ll,ii)];                        offs2 += COL_INCR;                    }                ENDIF#else  /* A_ELEMS_PER_THREAD >= 3 */#if (A_ELEMS_PER_THREAD >= 1)                ll = l + tidLo;                IF (ll < llLimit) THEN                    ii = i + tidHi;                    IF (ii < parms.n) THEN                        unsigned int idxAA;               
         unsigned int addrA;                        idxAA = IDXAA(tidLo, tidHi);                        addrA = IDXA(ll, ii);                        AA[idxAA] = parms.A[addrA];#if (A_ELEMS_PER_THREAD >= 2)                        ii += COL_INCR;                        IF (ii < parms.n) THEN                            idxAA += COL_INCR * AA_COL_OFS;                            addrA += COL_INCR * A_COL_OFS;                            AA[idxAA] = parms.A[addrA];                        ENDIF#endif  /* A_ELEMS_PER_THREAD >= 2 */                    ENDIF                ENDIF#endif  /* A_ELEMS_PER_THREAD >= 1 */#endif  /* A_ELEMS_PER_THREAD >= 3 */#endif  /* TRANSA = 0 */                #if (TRANSB==0)#if (B_ELEMS_PER_THREAD >= 3)                ll = l + tidLo;                IF (ll < llLimit) THEN                    unsigned int jjLimit = min (j + TILE_DIM, parms.n);                    offs2 = tidHi;                    for (jj = j + offs2; jj < jjLimit; jj += COL_INCR) {                        BB[IDXBB(tidLo, offs2)] = parms.B[IDXB(ll,jj)];                        offs2 += COL_INCR;                    }                ENDIF#else /* B_ELEMS_PER_THREAD >= 3 */#if (B_ELEMS_PER_THREAD >= 1)                ll = l + tidLo;                IF (ll < llLimit) THEN                    jj = j + tidHi;                    IF (jj < parms.n) THEN                        unsigned int idxBB;                        unsigned int addrB;                        idxBB = IDXBB(tidLo,tidHi);                        addrB = IDXB(ll,jj);                        BB[idxBB] = parms.B[addrB];#if (B_ELEMS_PER_THREAD >= 2)                        jj += COL_INCR;                        IF (jj < parms.n) THEN                            idxBB += COL_INCR * BB_COL_OFS;                            addrB += COL_INCR * B_COL_OFS;                            BB[idxBB] = parms.B[addrB];                        ENDIF#endif  /* B_ELEMS_PER_THREAD >= 2 */                    ENDIF                ENDIF#endif  /* 
B_ELEMS_PER_THREAD >= 1 */#endif  /* B_ELEMS_PER_THREAD >= 3 */#else   /* TRANSB==0 */#if (B_ELEMS_PER_THREAD >= 3)                jj = j + tidLo;                IF (jj < parms.n) THEN                    offs2 = tidHi;                    for (ll = l + offs2; ll < llLimit; ll += COL_INCR) {                        BB[IDXBB(offs2, tidLo)] = parms.B[IDXB(jj,ll)];                        offs2 += COL_INCR;                    }                ENDIF#else  /* B_ELEMS_PER_THREAD >= 3 */#if (B_ELEMS_PER_THREAD >= 1)                jj = j + tidLo;                IF (jj < parms.n) THEN                    ll = l + tidHi;                    IF (ll < llLimit) THEN                        unsigned int idxBB;                        unsigned int addrB;                        idxBB = IDXBB(tidHi,tidLo);                        addrB = IDXB(jj,ll);                        BB[idxBB] = parms.B[addrB];#if (B_ELEMS_PER_THREAD >= 2)                        ll += COL_INCR;                        IF (ll < llLimit) THEN                            idxBB += COL_INCR;                            addrB += COL_INCR * B_COL_OFS;                            BB[idxBB] = parms.B[addrB];                        ENDIF#endif  /* B_ELEMS_PER_THREAD >= 2 */                    ENDIF                ENDIF#endif  /* B_ELEMS_PER_THREAD >= 1 */#endif  /* B_ELEMS_PER_THREAD >= 3 */#endif  /* TRANSB0==0 */                                /* Wait until all elements of the A-tile and the B-tile have                 * been read, before any thread starts with the computation of                 * dot products                 */                __syncthreads ();                                /* For each of the result tile elements it needs to compute, a                 * thread computes the partial dot product by combining the                 * appropriate row (physically: column, due to transposition of                 * tile) of A-tile with the appropriate column of the B-tile.                 
* In this case, each thread updates two dot products. Inline                 * checks prevent computation for result tile elements that do                 * not correspond to elements inside the result matrix.                 */                ii = tidLo;                IF (ii < (parms.n - i)) THEN                    unsigned int z = llLimit - l;                    jj = tidHi;                    IF (z == 32) THEN#if (C_ELEMS_PER_THREAD == 1)                        IF (jj < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ACCUMULATE_DOT_PRODUCT_32(0);                        ENDIF#endif#if (C_ELEMS_PER_THREAD == 2)                        IF ((jj + COL_INCR) < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ACCUMULATE_2DOT_PRODUCTS_32(0,1,BB_COL_OFS*COL_INCR);                        ELSEIF (jj < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ACCUMULATE_DOT_PRODUCT_32(0);                        ENDIF#endif#if (C_ELEMS_PER_THREAD >= 3)#error C_ELEMS_PER_THREAD >= 3 no supported#endif                    ELSE#if (C_ELEMS_PER_THREAD == 1)                        IF (jj < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(0);                        ENDIF#endif#if (C_ELEMS_PER_THREAD == 2)                        IF ((jj + COL_INCR) < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ll = z;                            
ACCUMULATE_2DOT_PRODUCTS_N(0,1,BB_COL_OFS*COL_INCR);                        ELSEIF (jj < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(0);                        ENDIF#endif#if (C_ELEMS_PER_THREAD >= 3)#error C_ELEMS_PER_THREAD >= 3 not supported#endif                    }                }            }            /* At this point each thread has computed the dot product(s) that             * represent each element of the result matrix tile (i.e. C-tile)              * it is responsible for. If beta is zero, don't read the C-tile,             * otherwise read the C-tile to scale it by beta.             */            if (parms.beta == 0.0f) {                ii = i + tidLo;                jj = j + tidHi;                IF ((ii < parms.n) && (jj < parms.n)) THEN                unsigned int addrC = IDXC(ii,jj);#if (C_ELEMS_PER_THREAD >= 1)#if (UPPER==1)                    if (ii <= jj) {#else                    if (ii >= jj) {#endif                        parms.C[addrC] = parms.alpha * dp0;                    }#if (C_ELEMS_PER_THREAD >= 2)                    jj += COL_INCR;                    IF (jj < parms.n) THEN#if (UPPER==1)                        if (ii <= jj) {#else                        if (ii >= jj) {#endif                            addrC += COL_INCR * C_COL_OFS;                            parms.C[addrC] = parms.alpha * dp1;                        }#if (C_ELEMS_PER_THREAD >= 3)#error C_ELEMS_PER_THREAD >= 3 not supported#endif /* C_ELEMS_PER_THREAD >= 3 */                    ENDIF#endif /* C_ELEMS_PER_THREAD >= 2 */                ENDIF#endif /* C_ELEMS_PER_THREAD >= 1 */            } else {                ii = i + tidLo;                jj = j + tidHi;                IF ((ii < parms.n) && (jj < parms.n)) THEN                unsigned int addrC = IDXC(ii,jj);#if (C_ELEMS_PER_THREAD >= 
1)#if (UPPER==1)                    if (ii <= jj) {#else                    if (ii >= jj) {#endif                        parms.C[addrC] = parms.beta * parms.C[addrC] +                                          parms.alpha * dp0;                    }#if (C_ELEMS_PER_THREAD >= 2)                    jj += COL_INCR;                    IF (jj < parms.n) THEN#if (UPPER==1)                        if (ii <= jj) {#else                        if (ii >= jj) {#endif                            addrC += COL_INCR * C_COL_OFS;                            parms.C[addrC] = parms.beta * parms.C[addrC] +                                              parms.alpha * dp1;                        }#if (C_ELEMS_PER_THREAD >= 3)#error C_ELEMS_PER_THREAD >= 3 not supported#endif /* C_ELEMS_PER_THREAD >= 3 */                    ENDIF#endif /* C_ELEMS_PER_THREAD >= 2 */                ENDIF#endif /* C_ELEMS_PER_THREAD >= 1 */            }        }    }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -