⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 strassen_internal.c

📁 C++编写的高性能矩阵乘法的Strassen算法
💻 C
字号:
/*============================================================================
Internal Strassen (fmm) routine used in the recursion.  It only works on
even-sized matrices (no_recursion() may still divert small ones to the
conventional product).  This is Winograd's variant of Strassen's algorithm:
7 half-size multiplications and 15 half-size additions per level.

Inputs
  c_transa,c_transb : characters specifying the form of A or B.  Only 'T'
                      selects the transposed layout here; any other value is
                      treated as "not transposed" when locating quadrants.
  m,n,k             : matrix dimensions, all even
  alpha,beta        : scalars
  a                 : op(A) is m x k; A itself is stored k x m when
                      c_transa == 'T'
  b                 : op(B) is k x n; B itself is stored n x k when
                      c_transb == 'T'
  lda,ldb,ldc       : matrix leading dimensions
  d_aux             : temporary work space
  i_naux            : amount of temporary work space (presumably counted in
                      doubles, since d_aux is advanced in doubles below)

Work space consumed at this level (before recursing):
  beta == 0 : m/2 * MAX(k/2, n/2) + (k/2) * (n/2)
  beta != 0 : (m/2)*(k/2) + (k/2)*(n/2) + (m/2)*(n/2)

Outputs
  c                 : m x n matrix, C = alpha*op(A)*op(B) + beta*C

NOTE(review): the transpose flags are forwarded unchanged to matrix_prod()
and fmm_internal(); presumably the caller normalizes them to 'N'/'T' --
confirm, since a lower-case 't' would select the untransposed quadrants here.
============================================================================*/
#include "matrix.h"

void strassen_internal(char c_transa, char c_transb, int m, int n, int k,
                       double alpha, double *a, int lda, double *b, int ldb,
                       double beta, double *c, int ldc, double *d_aux,
                       int i_naux)
{
  double *a11, *a12, *a21, *a22,  /* Quadrants of A, B & C               */
         *b11, *b12, *b21, *b22,
         *c11, *c12, *c21, *c22,
         *r1,  *r2,  *r3;         /* Temporary matrices                  */
  double *aux_start = d_aux;      /* Start of the caller's work space    */
  int ldr1, ldr2, ldr3,           /* Leading dimensions of temporaries   */
      m_half, n_half, k_half,     /* Half of matrix dimensions           */
      rows_a, cols_a,             /* Rows/cols of half-sized matrix A    */
      rows_b, cols_b,             /* Rows/cols of half-sized matrix B    */
      naux_left;                  /* Work space left for recursive calls */

  /* Exit this routine quickly if matrix dimension(s) too small */
  if (no_recursion(m, n, k, alpha, beta)) {
    /* Multiply small matrices using conventional matrix multiply */
    matrix_prod(c_transa, c_transb, m, n, k, alpha, a, lda, b, ldb,
                beta, c, ldc);
    return;
  }

  /* Recursion */
  m_half = m/2;
  n_half = n/2;
  k_half = k/2;

  /* Set pointers to the appropriate 4 submatrices of A, B, & C based on
   * the settings of c_transa and c_transb.  When an operand is transposed,
   * its off-diagonal quadrants swap places in storage.
   */
  submatrix(a, lda, 1, 1, &a11);
  if (c_transa == 'T') {
    submatrix(a, lda, 1, m_half+1, &a21);
    submatrix(a, lda, k_half+1, 1, &a12);
    submatrix(a, lda, k_half+1, m_half+1, &a22);
  }
  else {
    submatrix(a, lda, 1, k_half+1, &a12);
    submatrix(a, lda, m_half+1, 1, &a21);
    submatrix(a, lda, m_half+1, k_half+1, &a22);
  }

  submatrix(b, ldb, 1, 1, &b11);
  if (c_transb == 'T') {
    submatrix(b, ldb, 1, k_half+1, &b21);
    submatrix(b, ldb, n_half+1, 1, &b12);
    submatrix(b, ldb, n_half+1, k_half+1, &b22);
  }
  else {
    submatrix(b, ldb, 1, n_half+1, &b12);
    submatrix(b, ldb, k_half+1, 1, &b21);
    submatrix(b, ldb, k_half+1, n_half+1, &b22);
  }

  submatrix(c, ldc, 1, 1, &c11);
  submatrix(c, ldc, 1, n_half+1, &c12);
  submatrix(c, ldc, m_half+1, 1, &c21);
  submatrix(c, ldc, m_half+1, n_half+1, &c22);

  /* The sequence of operations below has been ordered to minimize the
   * amount of temporary storage.  With
   *   P1 = A11*B11, P2 = A12*B21, P3 = S4*B22, P4 = A22*T4,
   *   P5 = S1*T1,   P6 = S2*T2,   P7 = S3*T3
   * (S1 = A21+A22, S2 = S1-A11, S3 = A11-A21, S4 = A12-S2,
   *  T1 = B12-B11, T2 = B22-T1, T3 = B22-B12, T4 = T2-B21)
   * both paths produce
   *   C11 = P1+P2, C12 = P1+P6+P5+P3, C21 = P1+P6+P7-P4, C22 = P1+P6+P7+P5
   * all scaled by alpha, plus beta*C.
   *
   * Transpose cases are handled automatically by the multiply routine
   * (fmm_internal), since we pass the appropriate operation via the
   * character parameter.  The add/subtract routines do not take such a
   * parameter, so temporaries built from A or B keep the stored (possibly
   * transposed) orientation and use rows_a/cols_a and rows_b/cols_b.
   */
  if (beta == 0.0) {
    /* Since beta = 0.0, initial values in C are ignored, and we can use
     * the 4 quadrants of C as temporary storage.  Only 2 additional
     * temporary arrays are needed.
     *
     * R1 holds m_half*MAX(k_half,n_half) doubles: it is used both as a
     * rows_a x cols_a combination of A quadrants (leading dimension ldr1)
     * and as an m_half x n_half product (leading dimension ldr3).
     * R2 holds k_half*n_half doubles, laid out rows_b x cols_b.
     * Afterwards d_aux points to the next unused memory location.
     */
    r1 = d_aux;
    r2 = r1 + m_half*MAX(k_half, n_half);
    d_aux = r2 + k_half*n_half;
    /* Recursive calls only get what is left after R1 and R2 */
    naux_left = i_naux - (int)(d_aux - aux_start);

    if (c_transa == 'T') {
      ldr1 = k_half;   /* R1 is used as a work array in 2 different  */
      ldr3 = m_half;   /* places, so 2 leading dimensions are needed */
      rows_a = k_half;
      cols_a = m_half;
    }
    else {
      ldr1 = ldr3 = m_half;
      rows_a = m_half;
      cols_a = k_half;
    }

    if (c_transb == 'T') {
      ldr2 = n_half;
      rows_b = n_half;
      cols_b = k_half;
    }
    else {
      ldr2 = k_half;
      rows_b = k_half;
      cols_b = n_half;
    }

    /* R1 = alpha * (a11 - a21)                          (alpha*S3) */
    matrix_sub(rows_a, cols_a, alpha, a11, lda, a21, lda, r1, ldr1);
    /* R2 = b22 - b12                                    (T3) */
    matrix_sub(rows_b, cols_b, 1.0, b22, ldb, b12, ldb, r2, ldr2);
    /* c11 = R1 * R2                                     (alpha*P7) */
    fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0, r1, ldr1,
                 r2, ldr2, 0.0, c11, ldc, d_aux, naux_left);
    /* R1 = alpha * (a21 + a22)                          (alpha*S1) */
    matrix_add(rows_a, cols_a, alpha, a21, lda, a22, lda, r1, ldr1);
    /* R2 = b12 - b11                                    (T1) */
    matrix_sub(rows_b, cols_b, 1.0, b12, ldb, b11, ldb, r2, ldr2);
    /* c22 = R1 * R2                                     (alpha*P5) */
    fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0, r1, ldr1,
                 r2, ldr2, 0.0, c22, ldc, d_aux, naux_left);
    /* R1 = R1 - (alpha * a11)                           (alpha*S2) */
    matrix_lin_update(rows_a, cols_a, -alpha, a11, lda, 1.0, r1, ldr1);
    /* R2 = b22 - R2                                     (T2) */
    matrix_sub(rows_b, cols_b, 1.0, b22, ldb, r2, ldr2, r2, ldr2);
    /* c21 = R1 * R2                                     (alpha*P6) */
    fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0, r1, ldr1,
                 r2, ldr2, 0.0, c21, ldc, d_aux, naux_left);
    /* R1 = alpha * a12 - R1                             (alpha*S4) */
    matrix_lin_update(rows_a, cols_a, alpha, a12, lda, -1.0, r1, ldr1);
    /* R2 = b21 - R2                                     (-T4) */
    matrix_sub(rows_b, cols_b, 1.0, b21, ldb, r2, ldr2, r2, ldr2);
    /* c12 = R1 * b22                                    (alpha*P3) */
    fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0, r1, ldr1,
                 b22, ldb, 0.0, c12, ldc, d_aux, naux_left);
    /* c12 = c12 + c22 */
    matrix_add(m_half, n_half, 1.0, c22, ldc, c12, ldc, c12, ldc);
    /* R1 = alpha * a11 * b11                            (alpha*P1) */
    fmm_internal(c_transa, c_transb, m_half, n_half, k_half, alpha, a11, lda,
                 b11, ldb, 0.0, r1, ldr3, d_aux, naux_left);
    /* c21 = c21 + R1 */
    matrix_add(m_half, n_half, 1.0, r1, ldr3, c21, ldc, c21, ldc);
    /* c12 = c12 + c21                                   (final C12) */
    matrix_add(m_half, n_half, 1.0, c21, ldc, c12, ldc, c12, ldc);
    /* c21 = c21 + c11 */
    matrix_add(m_half, n_half, 1.0, c11, ldc, c21, ldc, c21, ldc);
    /* c22 = c22 + c21                                   (final C22) */
    matrix_add(m_half, n_half, 1.0, c21, ldc, c22, ldc, c22, ldc);
    /* c11 = alpha * a12 * b21                           (alpha*P2) */
    fmm_internal(c_transa, c_transb, m_half, n_half, k_half, alpha, a12, lda,
                 b21, ldb, 0.0, c11, ldc, d_aux, naux_left);
    /* c11 = c11 + R1                                    (final C11) */
    matrix_add(m_half, n_half, 1.0, r1, ldr3, c11, ldc, c11, ldc);
    /* R1 = alpha * a22 * R2                             (-alpha*P4) */
    fmm_internal(c_transa, c_transb, m_half, n_half, k_half, alpha, a22, lda,
                 r2, ldr2, 0.0, r1, ldr3, d_aux, naux_left);
    /* c21 = c21 + R1                                    (final C21) */
    matrix_add(m_half, n_half, 1.0, r1, ldr3, c21, ldc, c21, ldc);
  }
  else {
    /* Since beta != 0.0, initial values in C cannot be ignored.  Three
     * additional temporary arrays are needed:
     *   R1 is rows_a x cols_a (m_half x k_half, or k_half x m_half if
     *      c_transa = T),
     *   R2 is rows_b x cols_b (k_half x n_half, or n_half x k_half if
     *      c_transb = T),
     *   R3 is m_half x n_half.
     * Afterwards d_aux points to the next unused memory location.
     */
    r1 = d_aux;
    if (c_transa == 'T') {
      ldr1 = k_half;
      r2 = r1 + k_half*m_half;
      rows_a = k_half;
      cols_a = m_half;
    }
    else {
      ldr1 = m_half;
      r2 = r1 + m_half*k_half;
      rows_a = m_half;
      cols_a = k_half;
    }

    if (c_transb == 'T') {
      ldr2 = n_half;
      r3 = r2 + n_half*k_half;
      rows_b = n_half;
      cols_b = k_half;
    }
    else {
      ldr2 = k_half;
      r3 = r2 + k_half*n_half;
      rows_b = k_half;
      cols_b = n_half;
    }

    ldr3 = m_half;
    d_aux = r3 + m_half*n_half;
    /* Recursive calls only get what is left after R1, R2 and R3 */
    naux_left = i_naux - (int)(d_aux - aux_start);

    /* R1 = alpha * (a21 + a22)                          (alpha*S1) */
    matrix_add(rows_a, cols_a, alpha, a21, lda, a22, lda, r1, ldr1);
    /* R2 = b12 - b11                                    (T1) */
    matrix_sub(rows_b, cols_b, 1.0, b12, ldb, b11, ldb, r2, ldr2);
    /* R3 = R1 * R2                                      (alpha*P5) */
    fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0, r1, ldr1,
                 r2, ldr2, 0.0, r3, ldr3, d_aux, naux_left);
    /* c12 = R3 + beta * c12 */
    matrix_lin_update(m_half, n_half, 1.0, r3, ldr3, beta, c12, ldc);
    /* c22 = R3 + beta * c22 */
    matrix_lin_update(m_half, n_half, 1.0, r3, ldr3, beta, c22, ldc);
    /* R1 = R1 - (alpha * a11)                           (alpha*S2) */
    matrix_lin_update(rows_a, cols_a, -alpha, a11, lda, 1.0, r1, ldr1);
    /* R2 = b22 - R2                                     (T2) */
    matrix_sub(rows_b, cols_b, 1.0, b22, ldb, r2, ldr2, r2, ldr2);
    /* R3 = alpha * a11 * b11                            (alpha*P1) */
    fmm_internal(c_transa, c_transb, m_half, n_half, k_half, alpha, a11, lda,
                 b11, ldb, 0.0, r3, ldr3, d_aux, naux_left);
    /* c11 = R3 + beta * c11 */
    matrix_lin_update(m_half, n_half, 1.0, r3, ldr3, beta, c11, ldc);
    /* R3 = R3 + (R1 * R2)                               (alpha*(P1+P6)) */
    fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0, r1, ldr1,
                 r2, ldr2, 1.0, r3, ldr3, d_aux, naux_left);
    /* c11 = alpha * (a12 * b21) + c11                   (final C11) */
    fmm_internal(c_transa, c_transb, m_half, n_half, k_half, alpha, a12, lda,
                 b21, ldb, 1.0, c11, ldc, d_aux, naux_left);
    /* R1 = alpha * a12 - R1                             (alpha*S4) */
    matrix_lin_update(rows_a, cols_a, alpha, a12, lda, -1.0, r1, ldr1);
    /* R2 = alpha * (b21 - R2)                           (-alpha*T4) */
    matrix_sub(rows_b, cols_b, alpha, b21, ldb, r2, ldr2, r2, ldr2);
    /* c12 = c12 + (R1 * b22) */
    fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0, r1, ldr1,
                 b22, ldb, 1.0, c12, ldc, d_aux, naux_left);
    /* c12 = c12 + R3                                    (final C12) */
    matrix_add(m_half, n_half, 1.0, r3, ldr3, c12, ldc, c12, ldc);
    /* c21 = a22 * R2 + beta * c21 */
    fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0, a22, lda,
                 r2, ldr2, beta, c21, ldc, d_aux, naux_left);
    /* R1 = alpha * (a11 - a21)                          (alpha*S3) */
    matrix_sub(rows_a, cols_a, alpha, a11, lda, a21, lda, r1, ldr1);
    /* R2 = b22 - b12                                    (T3) */
    matrix_sub(rows_b, cols_b, 1.0, b22, ldb, b12, ldb, r2, ldr2);
    /* R3 = R3 + (R1 * R2)                               (alpha*(P1+P6+P7)) */
    fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0, r1, ldr1,
                 r2, ldr2, 1.0, r3, ldr3, d_aux, naux_left);
    /* c21 = c21 + R3                                    (final C21) */
    matrix_add(m_half, n_half, 1.0, r3, ldr3, c21, ldc, c21, ldc);
    /* c22 = c22 + R3                                    (final C22) */
    matrix_add(m_half, n_half, 1.0, r3, ldr3, c22, ldc, c22, ldc);
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -