⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ssymm.h

📁 NVIDIA提供的CUDA的BLAS库源码
💻 H
📖 第 1 页 / 共 2 页
字号:
            i = IMUL(blockIdx.y,TILE_DIM);            j = IMUL(blockIdx.x,TILE_DIM);#endif            /* set accumulation to 0*/#if (C_ELEMS_PER_THREAD >= 5)#error C_ELEMS_PER_THREAD >= 5 not supported#else#if (C_ELEMS_PER_THREAD >= 1)            dp0 = 0.0f;#if (C_ELEMS_PER_THREAD >= 2)            dp1 = 0.0f;#if (C_ELEMS_PER_THREAD >= 3)            dp2 = 0.0f;#if (C_ELEMS_PER_THREAD >= 4)            dp3 = 0.0f;#endif  /* C_ELEMS_PER_THREAD >= 4 */#endif  /* C_ELEMS_PER_THREAD >= 3 */#endif  /* C_ELEMS_PER_THREAD >= 2 */#endif  /* C_ELEMS_PER_THREAD >= 1 */#endif  /* C_ELEMS_PER_THREAD >= 5 */            for (l = 0; l < parms.k; l += TILE_DIM) {                unsigned int llLimit = min ((l + TILE_DIM), parms.k);                unsigned int z = llLimit - l;                unsigned int offs2;                                /* Wait before clobbering old cache contents */                __syncthreads ();                                ii = i + tidLo;                IF (ii < parms.m) THEN                    offs2 = tidHi;#if (LSIDE==1)#if (UPPER==1)                    for (ll = l+offs2; ll<llLimit; ll += COL_INCR) {                        unsigned int addr = (ii<=ll)?IDXA(ii,ll):IDXA(ll,ii);                        AA[IDXAA(offs2,tidLo)] = parms.A[addr];                        offs2 += COL_INCR;                    }#else                    for (ll = l+offs2; ll<llLimit; ll += COL_INCR) {                        unsigned int addr = (ii>=ll)?IDXA(ii,ll):IDXA(ll,ii);                        AA[IDXAA(offs2,tidLo)] = parms.A[addr];                        offs2 += COL_INCR;                    }#endif#else /* LSIDE==1 */                    for (ll = l + offs2; ll < llLimit; ll += COL_INCR) {                        AA[IDXAA(offs2,tidLo)] = parms.A[IDXA(ii,ll)];                        offs2 += COL_INCR;                    }#endif /* LSIDE==1 */                ENDIF                ll = l + tidLo;                IF (ll < llLimit) THEN                    unsigned int jjLimit = min (j 
+ TILE_DIM, parms.n);                    offs2 = tidHi;#if (LSIDE==1)                    for (jj = j + offs2; jj < jjLimit; jj += COL_INCR) {                        BB[IDXBB(tidLo, offs2)] = parms.B[IDXB(ll,jj)];                        offs2 += COL_INCR;                    }#else#if (UPPER==1)                    for (jj = j + offs2; jj < jjLimit; jj += COL_INCR) {                        unsigned int addr = (ll<=jj)?IDXB(ll,jj):IDXB(jj,ll);                        BB[IDXBB(tidLo, offs2)] = parms.B[addr];                        offs2 += COL_INCR;                    }#else                    for (jj = j + offs2; jj < jjLimit; jj += COL_INCR) {                        unsigned int addr = (jj<=ll)?IDXB(ll,jj):IDXB(jj,ll);                        BB[IDXBB(tidLo, offs2)] = parms.B[addr];                        offs2 += COL_INCR;                    }#endif#endif /* LSIDE==1 */                ENDIF                /* Wait until new cache contents ready */                __syncthreads ();                                /* We don't iterate over jj and ii since this is al done                 * in paralel by the threads in each CTA.                 
*/                ii = tidLo;                IF (ii < (parms.m - i)) THEN                    jj = tidHi;                    IF (z == 32) THEN#if (C_ELEMS_PER_THREAD == 1)                        IF (jj < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ACCUMULATE_DOT_PRODUCT_32(0);                        ENDIF#endif#if (C_ELEMS_PER_THREAD == 2)                        IF ((jj + COL_INCR) < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ACCUMULATE_2DOT_PRODUCTS_32(0,1,BB_COL_OFS*COL_INCR);                        ELSEIF (jj < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ACCUMULATE_DOT_PRODUCT_32(0);                        ENDIF#endif#if (C_ELEMS_PER_THREAD >= 3)#error C_ELEMS_PER_THREAD >= 3 no supported#endif                    ELSE                        #if (C_ELEMS_PER_THREAD == 1)                        IF (jj < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(0);                        ENDIF#endif#if (C_ELEMS_PER_THREAD == 2)                        IF ((jj + COL_INCR) < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_2DOT_PRODUCTS_N(0,1,BB_COL_OFS*COL_INCR);                        ELSEIF (jj < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ll = z;                            
ACCUMULATE_DOT_PRODUCT_N(0);                        ENDIF#endif#if (C_ELEMS_PER_THREAD >= 3)#error C_ELEMS_PER_THREAD >= 3 no supported#endif                    ENDIF                ENDIF            }            /* write out completed tile of matrix C */            if (parms.beta == 0.0f) {#if (C_ELEMS_PER_THREAD >= 5)                /* we would need an array dp[] instead of scalar dp0, .. */#error C_ELEMS_PER_THREAD >= 5 no supported#else#if (C_ELEMS_PER_THREAD >= 1)                ii = i + tidLo;                IF (ii < parms.m) THEN                    jj = j + tidHi;                    IF (jj < parms.n) THEN                        unsigned int addrC = IDXC(ii,jj);                        parms.C[addrC] = parms.alpha * dp0;#if (C_ELEMS_PER_THREAD >= 2)                        jj += COL_INCR;                        IF (jj < parms.n) THEN                            addrC += COL_INCR * C_COL_OFS;                            parms.C[addrC] = parms.alpha * dp1;#if (C_ELEMS_PER_THREAD >= 3)                            jj += COL_INCR;                            IF (jj < parms.n) THEN                                addrC += COL_INCR * C_COL_OFS;                                parms.C[addrC] = parms.alpha * dp2;#if (C_ELEMS_PER_THREAD >= 4)                                jj += COL_INCR;                                IF (jj < parms.n) THEN                                    addrC += COL_INCR * C_COL_OFS;                                    parms.C[addrC] = parms.alpha * dp3;                                ENDIF#endif  /* C_ELEMS_PER_THREAD >= 4 */                            ENDIF#endif  /* C_ELEMS_PER_THREAD >= 3 */                        ENDIF#endif  /* C_ELEMS_PER_THREAD >= 2 */                    ENDIF                ENDIF#endif  /* C_ELEMS_PER_THREAD >= 1 */#endif            } else {#if (C_ELEMS_PER_THREAD >= 5)                /* we would need an array dp[] instead of scalar dp0, .. 
*/#error C_ELEMS_PER_THREAD >= 5 no supported#else#if (C_ELEMS_PER_THREAD >= 1)                ii = i + tidLo;                IF (ii < parms.m) THEN                    jj = j + tidHi;                    IF (jj < parms.n) THEN                        unsigned int addrC = IDXC(ii,jj);                        parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp0;#if (C_ELEMS_PER_THREAD >= 2)                        jj += COL_INCR;                        IF (jj < parms.n) THEN                            addrC += COL_INCR * C_COL_OFS;                            parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp1;#if (C_ELEMS_PER_THREAD >= 3)                            jj += COL_INCR;                            IF (jj < parms.n) THEN                                addrC += COL_INCR * C_COL_OFS;                                parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp2;#if (C_ELEMS_PER_THREAD >= 4)                                jj += COL_INCR;                                IF (jj < parms.n) THEN                                    addrC += COL_INCR * C_COL_OFS;                                    parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp3;                                ENDIF#endif  /* C_ELEMS_PER_THREAD >= 4 */                            ENDIF#endif  /* C_ELEMS_PER_THREAD >= 3 */                        ENDIF#endif  /* C_ELEMS_PER_THREAD >= 2 */                    ENDIF                ENDIF#endif  /* C_ELEMS_PER_THREAD >= 1 */#endif            }        }    }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -