fmm_internal.c

来自「C++编写的高性能矩阵乘法的Stranssen算法」· C语言 代码 · 共 81 行

C
81
字号
/*============================================================================
Internal Strassen (fmm) routine used in the recursion.  This one works on
arbitrary size matrices: the even-sized leading portion of the problem is
handed to strassen_internal(), and any leftover odd row/column is accounted
for afterwards by fixup_internal().

Inputs
  transa,transb : characters specifying form of A or B (transpose or not)
  m,n,k         : matrix dimensions
  alpha,beta    : scalars
  A             : m x k matrix
  B             : k x n matrix
  ldb,lda,ldc   : matrix leading dimensions
  d_aux         : temporary work space
  i_naux        : amount of temporary work space

Outputs
  C             : m x n matrix, C = alpha*A*B+beta*C
============================================================================*/
#include "matrix.h"
#if STRAS_TIME_PARTS
#include "matrix_test.h"
#endif

void fmm_internal(char c_transa, char c_transb, int m, int n, int k,
                  double alpha, double *a, int lda, double *b, int ldb,
                  double beta, double *c, int ldc, double *d_aux, int i_naux)
{
  int r1, s1, t1;             /* extra row/column count (m%2, n%2, k%2)  */
  int m_even, n_even, k_even; /* even-sized portion of each dimension    */

  /* Exit this routine quickly if matrix dimension(s) too small:
   * multiply small matrices using the conventional matrix multiply. */
  if (no_recursion(m, n, k, alpha, beta)) {
    matrix_prod(c_transa, c_transb, m, n, k, alpha, a, lda, b, ldb,
                beta, c, ldc);
    return;
  }

  /* Continue with Strassen recursion.  If a dimension is odd, drop its last
   * row/column to get an even dimension, recurse on the even part, and
   * accumulate the contribution of the dropped row/column at the end. */
  r1 = m % 2;
  s1 = n % 2;
  t1 = k % 2;
  m_even = m - r1;
  n_even = n - s1;
  k_even = k - t1;

  /* Matrices A, B & C have the following decomposition:
   *
   *      |       |   |   |       |   |   |       |   |
   *      |  A11  |a12|   |  B11  |b12|   |  C11  |c12|
   *      |       |   | X |       |   | = |       |   |
   *      |_______|___|   |_______|___|   |_______|___|
   *      |  a21  |a22|   |  b21  |b22|   |  c21  |c22|
   *
   * where A11 (m_even x k_even), B11 (k_even x n_even) and
   * C11 (m_even x n_even) are even-sized matrices.  Whether the remaining
   * pieces exist depends on the parity of m, n & k: a21, b21 & c21 are row
   * vectors, a12, b12 & c12 are column vectors, and a22, b22 & c22 are
   * single elements.
   */

  /* Solve the even-sized problem first:
   *
   *              C11 = alpha * (A11 * B11) + beta * C11
   */
  strassen_internal(c_transa, c_transb, m_even, n_even, k_even,
                    alpha, a, lda, b, ldb, beta, c, ldc,
                    d_aux, i_naux);

  /* Take care of the remaining row/column, if any dimension was odd
   * (r1/s1/t1 are non-zero exactly when m/n/k differ from their even part). */
  if (r1 || s1 || t1) {
    fixup_internal(c_transa, c_transb, m, n, k, m_even, n_even, k_even,
                   r1, s1, t1, alpha, a, lda, b, ldb, beta, c, ldc);
  }
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?