atl_gemvn_16x4_1.c

来自「基于 BLAS/CLAPACK 的，用过的人知道是干啥的」· C语言 代码 · 共 595 行 · 第 1/2 页

C
595
字号
      y[8] = z0;      y[9] = z1;      y[10] = z2;      y[11] = z3;      y[12] = z4;      y[13] = z5;      y[14] = z6;      y[15] = z7;      if (M-M16) gemvMlt8(M-M16, N, A0+16, lda, x, beta, y+16);   }   else if (N) gemvMlt8(M, N, A, lda, x, beta, y);}static void gemv32x4(const int M, const int N, const TYPE *A, const int lda,                     const TYPE *X, const SCALAR beta, TYPE *Y){   #ifdef BETA1      int j;   #endif   const int incA = lda<<2;   if (N >= 4)   {      if (M >= 32)      {         #ifdef BETA1            for (j=(N>>2); j; j--, A += incA, X += 4)               gemvN32x4(M, 4, A, lda, X, ATL_rone, Y);            if ( (j = N-((N>>2)<<2)) ) gemvNle4(M, j, A, lda, X, ATL_rone, Y);         #else            gemvN32x4(M, 4, A, lda, X, beta, Y);            if (N != 4)               Mjoin(PATL,gemvN_a1_x1_b1_y1)                  (M, N-4, ATL_rone, A+incA, lda, X+4, 1, ATL_rone, Y, 1);         #endif      }      else gemvMlt8(M, N, A, lda, X, beta, Y);   }   else if (M) gemvNle4(M, N, A, lda, X, beta, Y);}static void gemvMlt8(const int M, const int N, const TYPE *A, const int lda,                     const TYPE *X, const SCALAR beta, TYPE *Y){   int i;   register TYPE y0;   for (i=M; i; i--)   {      #ifdef BETA0         y0 = Mjoin(PATL,dot)(N, A, lda, X, 1);      #else         Yget(y0, *Y, beta);         y0 += Mjoin(PATL,dot)(N, A, lda, X, 1);      #endif      *Y++ = y0;      A++;   }}static void gemvNle4(const int M, const int N, const TYPE *A, const int lda,                     const TYPE *X, const SCALAR beta, TYPE *Y){   int i;   const TYPE *A0 = A, *A1 = A+lda, *A2 = A1+lda, *A3 = A2+lda;   register TYPE x0, x1, x2, x3;   #ifdef BETAX      const register TYPE bet=beta;   #endif   switch(N)   {   case 1:      #if defined(BETA0)         Mjoin(PATL,cpsc)(M, *X, A, 1, Y, 1);      #elif defined(BETAX)         Mjoin(PATL,axpby)(M, *X, A, 1, beta, Y, 1);      #else         Mjoin(PATL,axpy)(M, *X, A, 1, Y, 1);      #endif      break;   case 2:      x0 = *X; 
x1 = X[1];      for (i=0; i != M; i++)      #ifdef BETA0         Y[i] = A0[i] * x0 + A1[i] * x1;      #elif defined(BETAX)         Y[i] = Y[i]*bet + A0[i] * x0 + A1[i] * x1;      #else         Y[i] += A0[i] * x0 + A1[i] * x1;      #endif      break;   case 3:      x0 = *X; x1 = X[1]; x2 = X[2];      for (i=0; i != M; i++)      #ifdef BETA0         Y[i] = A0[i] * x0 + A1[i] * x1 + A2[i] * x2;      #elif defined(BETAX)         Y[i] = Y[i]*bet + A0[i] * x0 + A1[i] * x1 + A2[i] * x2;      #else         Y[i] += A0[i] * x0 + A1[i] * x1 + A2[i] * x2;      #endif      break;   case 4:      if (M >= 32) gemv32x4(M, 4, A, lda, X, beta, Y);      else      {         x0 = *X; x1 = X[1]; x2 = X[2]; x3 = X[3];         for (i=0; i != M; i++)         #ifdef BETA0            Y[i] = A0[i] * x0 + A1[i] * x1 + A2[i] * x2 + A3[i] * x3;         #elif defined(BETAX)            Y[i] = Y[i]*bet + A0[i] * x0 + A1[i] * x1 + A2[i] * x2 + A3[i] * x3;         #else            Y[i] += A0[i] * x0 + A1[i] * x1 + A2[i] * x2 + A3[i] * x3;         #endif      }      break;   default:      ATL_assert(!N);   }}static void gemv16x4(const int M, const int N, const TYPE *A, const int lda,                     const TYPE *X, const SCALAR beta, TYPE *Y)/* * 16x4 with feeble prefetch */{   int j;   const int M16 = (M>>4)<<4, N4 = (N>>2)<<2, nr = N-N4+4;   const int incA = lda << 2, incAm = 16 - (N4-4)*lda, incAm0 = 16 - N*lda;   const TYPE *stX = X + N4, *x;   const TYPE *A0 = A, *A1 = A + lda, *A2 = A1 + lda, *A3 = A2 + lda;   TYPE *stY = Y + M16;   register TYPE x0, x1, x2, x3;   register TYPE y0, y1, y2, y3, y4, y5, y6, y7;   register TYPE y8, y9, y10, y11, y12, y13, y14, y15;   register TYPE p0, p1, p2, p3;   if (N > 4)   {      if (M16)      {         do         {            #ifdef BETA0               y0 = y1 = y2 = y3 = y4 = y5 = y6 = y7 =               y8 = y9 = y10 = y11 = y12 = y13 = y14 = y15 = ATL_rzero;            #elif defined BETAX               x0 = beta;               y0 = *Y; y1 = Y[1]; y2 = 
Y[2];  y3 = Y[3];               y8 = Y[8]; y9 = Y[9]; y10 = Y[10];  y11 = Y[11];               y4 = Y[4]; y5 = Y[5]; y6 = Y[6];  y7 = Y[7];               y12 = Y[12]; y13 = Y[13]; y14 = Y[14];  y15 = Y[15];               y0 *= x0; y1 *= x0; y2 *= x0; y3 *= x0;               y8 *= x0; y9 *= x0; y10 *= x0; y11 *= x0;               y4 *= x0; y5 *= x0; y6 *= x0; y7 *= x0;               y12 *= x0; y13 *= x0; y14 *= x0; y15 *= x0;            #else               y0 = *Y; y1 = Y[1]; y2 = Y[2];  y3 = Y[3];               y8 = Y[8]; y9 = Y[9]; y10 = Y[10];  y11 = Y[11];               y4 = Y[4]; y5 = Y[5]; y6 = Y[6];  y7 = Y[7];               y12 = Y[12]; y13 = Y[13]; y14 = Y[14];  y15 = Y[15];            #endif            p0 = *A0; p1 = A1[1];            p2 = A2[2]; p3 = A3[3];            x0 = *X; x1 = X[1]; x2 = X[2]; x3 = X[3];            x = X + 4;            if (N4 != 4)            {               do               {                  y0  += x0 * p0; p0 = A0[incA];                  y1  += x1 * p1; p1 = A1[incA+1];                  y2  += x2 * p2; p2 = A2[incA+2];                  y3  += x3 * p3; p3 = A3[incA+3];                  y8  += x0 * A0[8];                  y9  += x1 * A1[9];                  y10 += x2 * A2[10];                  y11 += x3 * A3[11];                  y4  += x0 * A0[4];                  y5  += x1 * A1[5];                  y6  += x2 * A2[6];                  y7  += x3 * A3[7];                  y12 += x0 * A0[12];                  y13 += x1 * A1[13];                  y14 += x2 * A2[14];                  y15 += x3 * A3[15];                  y0  += x1 * *A1;                  y1  += x0 * A0[1];                  y2  += x0 * A0[2];                  y3  += x0 * A0[3];                  y8  += x1 * A1[8];                  y9  += x0 * A0[9];                  y10 += x0 * A0[10];                  y11 += x0 * A0[11];                  y4  += x1 * A1[4];                  y5  += x0 * A0[5];                  y6  += x0 * A0[6];                  y7  += x0 * A0[7];          
        y12 += x1 * A1[12];                  y13 += x0 * A0[13];                  y14 += x0 * A0[14];                  y15 += x0 * A0[15]; x0 = *x;                  y0  += x2 * *A2;                  y1  += x2 * A2[1]; A0 += incA;                  y2  += x1 * A1[2];                  y3  += x1 * A1[3];                  y8  += x2 * A2[8];                  y9  += x2 * A2[9];                  y10 += x1 * A1[10];                  y11 += x1 * A1[11];                  y4  += x2 * A2[4];                  y5  += x2 * A2[5];                  y6  += x1 * A1[6];                  y7  += x1 * A1[7];                  y12 += x2 * A2[12];                  y13 += x2 * A2[13];                  y14 += x1 * A1[14];                  y15 += x1 * A1[15]; x1 = x[1];                  y0  += x3 * *A3;                  y1  += x3 * A3[1];                  y2  += x3 * A3[2]; A1 += incA;                  y3  += x2 * A2[3];                  y8  += x3 * A3[8];                  y9  += x3 * A3[9];                  y10 += x3 * A3[10];                  y11 += x2 * A2[11];                  y4  += x3 * A3[4];                  y5  += x3 * A3[5];                  y6  += x3 * A3[6];                  y7  += x2 * A2[7];                  y12 += x3 * A3[12];                  y13 += x3 * A3[13];                  y14 += x3 * A3[14]; x3 = x[3]; A3 += incA;                  y15 += x2 * A2[15]; x2 = x[2]; x += 4; A2 += incA;               }               while (x != stX);            }            x -= 4;            for (j=0; j != nr; j++, A0 += lda)            {               x0 = x[j];               y0  += x0 * *A0;               y1  += x0 * A0[1];               y2  += x0 * A0[2];               y3  += x0 * A0[3];               y8  += x0 * A0[8];               y9  += x0 * A0[9];               y10 += x0 * A0[10];               y11 += x0 * A0[11];               y4  += x0 * A0[4];               y5  += x0 * A0[5];               y6  += x0 * A0[6];               y7  += x0 * A0[7];               y12 += x0 * A0[12];          
     y13 += x0 * A0[13];               y14 += x0 * A0[14];               y15 += x0 * A0[15];            }            A0 += incAm0;            *Y   = y0;            Y[ 1] = y1 ;            Y[ 2] = y2 ;            Y[ 3] = y3 ;            A1 += incAm;            Y[ 8] = y8 ;            Y[ 9] = y9 ;            Y[10] = y10;            Y[11] = y11;            A2 += incAm;            Y[ 4] = y4 ;            Y[ 5] = y5 ;            Y[ 6] = y6 ;            Y[ 7] = y7 ;            A3 += incAm;            Y[12] = y12;            Y[13] = y13;            Y[14] = y14;            Y[15] = y15;            Y += 16;         }         while (Y != stY);      }      if (M-M16) gemvMlt8(M-M16, N, A0, lda, X, beta, Y);   }   else if (M) gemvNle4(M, N, A, lda, X, beta, Y);}void Mjoin(Mjoin(Mjoin(Mjoin(Mjoin(PATL,gemvN),NM),_x1),BNM),_y1)   (const int M, const int N, const SCALAR alpha, const TYPE *A, const int lda,    const TYPE *X, const int incX, const SCALAR beta, TYPE *Y, const int incY){   gemv16x4(M, N, A, lda, X, beta, Y);}

⌨️ 快捷键说明

复制代码：Ctrl + C
搜索代码：Ctrl + F
全屏模式：F11
增大字号：Ctrl + =
减小字号：Ctrl + -
显示快捷键：?