⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 sgemm.h

📁 NVIDIA提供的CUDA的BLAS库源码
💻 H
📖 第 1 页 / 共 3 页
字号:
                        ENDIF                                    #endif  /* B_ELEMS_PER_THREAD >= 3 */                    ENDIF#endif  /* B_ELEMS_PER_THREAD >= 2 */                ENDIF#endif  /* B_ELEMS_PER_THREAD >= 1 */#endif  /* B_ELEMS_PER_THREAD >= 5 */#endif  /* TRANSB0==0 */                                /* Wait until new cache contents ready */                __syncthreads ();                                /* We don't iterate over jj and ii since this is all done                 * in parallel by the threads in each CTA.                 */                ii = tidLo;                IF (ii < (parms.m - i)) THEN                    unsigned int z = llLimit - l;                    jj = tidHi;                    IF (z == TILE_DIM) THEN#if (C_ELEMS_PER_THREAD == 1)                        IF (jj < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ACCUMULATE_DOT_PRODUCT_TILE(0);                        ENDIF#endif#if (C_ELEMS_PER_THREAD == 2)                        IF ((jj + COL_INCR) < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ACCUMULATE_2DOT_PRODUCTS_TILE(0,1,BB_COL_OFS*COL_INCR);                        ELSEIF (jj < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ACCUMULATE_DOT_PRODUCT_TILE(0);                        ENDIF#endif#if (C_ELEMS_PER_THREAD == 3)                        IF ((jj + 2*COL_INCR) < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ACCUMULATE_2DOT_PRODUCTS_TILE(0,1,BB_COL_OFS*COL_INCR);                            lj += 2 * BB_COL_OFS * COL_INCR;                            
ACCUMULATE_DOT_PRODUCT_TILE(2);                        ELSEIF ((jj + COL_INCR) < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            CCUMULATE_2DOT_PRODUCTS_TILE(0,1,BB_COL_OFS*COL_INCR);                        ELSEIF (jj < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ACCUMULATE_DOT_PRODUCT_TILE(0);                        ENDIF#endif#if (C_ELEMS_PER_THREAD >= 4)                        IF ((jj + 3*COL_INCR) < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ACCUMULATE_2DOT_PRODUCTS_TILE(0,1,BB_COL_OFS*COL_INCR);                            lj += 2*BB_COL_OFS * COL_INCR;                            ACCUMULATE_2DOT_PRODUCTS_TILE(2,3,BB_COL_OFS*COL_INCR);                        ELSEIF ((jj + 2*COL_INCR) < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ACCUMULATE_2DOT_PRODUCTS_TILE(0,1,BB_COL_OFS*COL_INCR);                            lj += 2*BB_COL_OFS * COL_INCR;                            ACCUMULATE_DOT_PRODUCT_TILE(2);                        ELSEIF ((jj + COL_INCR) < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ACCUMULATE_2DOT_PRODUCTS_TILE(0,1,BB_COL_OFS*COL_INCR);                        ELSEIF (jj < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ACCUMULATE_DOT_PRODUCT_TILE(0);                        ENDIF#endif#if (C_ELEMS_PER_THREAD >= 5)#error C_ELEMS_PER_THREAD >= 5 no supported#endif 
                   ELSE                        #if (C_ELEMS_PER_THREAD == 1)                        IF (jj < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(0);                        ENDIF#endif#if (C_ELEMS_PER_THREAD == 2)                        IF ((jj + COL_INCR) < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_2DOT_PRODUCTS_N(0,1,BB_COL_OFS*COL_INCR);                        ELSEIF (jj < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(0);                        ENDIF#endif#if (C_ELEMS_PER_THREAD == 3)                        IF ((jj + 2*COL_INCR) < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(0);                            jj += COL_INCR;                            li = IDXAA(0,ii);                            lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(1);                            jj += COL_INCR;                            li = IDXAA(0,ii);                            lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(2);                        ELSEIF ((jj + COL_INCR) < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ll = z;                            
ACCUMULATE_DOT_PRODUCT_N(0);                            li = IDXAA(0,ii);                            jj += COL_INCR;                            lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(1);                        ELSEIF (jj < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(0);                        ENDIF#endif                        #if (C_ELEMS_PER_THREAD == 4)                        IF ((jj + 3*COL_INCR) < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(0);                            jj += COL_INCR;                            li = IDXAA(0,ii);                            lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(1);                            jj += COL_INCR;                            li = IDXAA(0,ii);                            lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(2);                            jj += COL_INCR;                            li = IDXAA(0,ii);                            lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(3);                        ELSEIF ((jj + 2*COL_INCR) < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(0);                            jj += COL_INCR;                            li = IDXAA(0,ii);                            lj = IDXBB(0,jj);                            ll = z;     
                       ACCUMULATE_DOT_PRODUCT_N(1);                            jj += COL_INCR;                            li = IDXAA(0,ii);                            lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(2);                        ELSEIF ((jj + COL_INCR) < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(0);                            li = IDXAA(0,ii);                            jj += COL_INCR;                            lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(1);                        ELSEIF (jj < (parms.n - j)) THEN                            unsigned int li = IDXAA(0,ii);                            unsigned int lj = IDXBB(0,jj);                            ll = z;                            ACCUMULATE_DOT_PRODUCT_N(0);                        ENDIF#endif#if (C_ELEMS_PER_THREAD >= 5)#error C_ELEMS_PER_THREAD >= 5 no supported#endif                    ENDIF                ENDIF            }            /* write out completed tile of matrix C */            if (parms.beta == 0.0f) {#if (C_ELEMS_PER_THREAD >= 5)                /* we would need an array dp[] instead of scalar dp0, .. 
*/#error C_ELEMS_PER_THREAD >= 5 no supported#else#if (C_ELEMS_PER_THREAD >= 1)                ii = ii_1;                jj = jj_2;                IF ((ii < parms.m) && (jj < parms.n)) THEN                    unsigned int addrC = IDXC(ii,jj);                    parms.C[addrC] = parms.alpha * dp0;#if (C_ELEMS_PER_THREAD >= 2)                    jj += COL_INCR;                    IF (jj < parms.n) THEN                        addrC += COL_INCR * C_COL_OFS;                        parms.C[addrC] = parms.alpha * dp1;#if (C_ELEMS_PER_THREAD >= 3)                        jj += COL_INCR;                        IF (jj < parms.n) THEN                            addrC += COL_INCR * C_COL_OFS;                            parms.C[addrC] = parms.alpha * dp2;#if (C_ELEMS_PER_THREAD >= 4)                            jj += COL_INCR;                            IF (jj < parms.n) THEN                                addrC += COL_INCR * C_COL_OFS;                                parms.C[addrC] = parms.alpha * dp3;                            ENDIF#endif  /* C_ELEMS_PER_THREAD >= 4 */                        ENDIF#endif  /* C_ELEMS_PER_THREAD >= 3 */                    ENDIF#endif  /* C_ELEMS_PER_THREAD >= 2 */                ENDIF#endif  /* C_ELEMS_PER_THREAD >= 1 */#endif            } else {#if (C_ELEMS_PER_THREAD >= 5)                /* we would need an array dp[] instead of scalar dp0, .. 
*/#error C_ELEMS_PER_THREAD >= 5 no supported#else#if (C_ELEMS_PER_THREAD >= 1)                ii = ii_1;                jj = jj_2;                IF ((ii < parms.m) && (jj < parms.n)) THEN                    unsigned int addrC = IDXC(ii,jj);                    parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp0;#if (C_ELEMS_PER_THREAD >= 2)                    jj += COL_INCR;                    IF (jj < parms.n) THEN                        addrC += COL_INCR * C_COL_OFS;                        parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp1;#if (C_ELEMS_PER_THREAD >= 3)                        jj += COL_INCR;                        IF (jj < parms.n) THEN                            addrC += COL_INCR * C_COL_OFS;                            parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp2;#if (C_ELEMS_PER_THREAD >= 4)                            jj += COL_INCR;                            IF (jj < parms.n) THEN                                addrC += COL_INCR * C_COL_OFS;                                parms.C[addrC] = parms.beta * parms.C[addrC] + parms.alpha * dp3;                            ENDIF#endif  /* C_ELEMS_PER_THREAD >= 4 */                        ENDIF#endif  /* C_ELEMS_PER_THREAD >= 3 */                    ENDIF#endif  /* C_ELEMS_PER_THREAD >= 2 */                ENDIF#endif  /* C_ELEMS_PER_THREAD >= 1 */#endif            }        }    }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -