fmm_internal.c
来自「C++编写的高性能矩阵乘法的Strassen算法」· C语言 代码 · 共 81 行
C
81 行
/*============================================================================
 * fmm_internal
 *
 * Internal Strassen (fmm) routine used in the recursion.  This one works
 * on arbitrary size matrices: odd dimensions are reduced to the next lower
 * even value, the even-sized problem is handed to strassen_internal(), and
 * any leftover row/column is then handled by fixup_internal().
 *
 * Inputs
 *   c_transa, c_transb : characters specifying form of A or B
 *                        (transpose or not)
 *   m, n, k            : matrix dimensions
 *   alpha, beta        : scalars
 *   a                  : m x k matrix A
 *   b                  : k x n matrix B
 *   lda, ldb, ldc      : matrix leading dimensions
 *   d_aux              : temporary work space
 *   i_naux             : amount of temporary work space
 *
 * Outputs
 *   c                  : m x n matrix, C = alpha*A*B + beta*C
 *============================================================================*/
#include "matrix.h"
#if STRAS_TIME_PARTS
#include "matrix_test.h"
#endif

void fmm_internal(char c_transa, char c_transb, int m, int n, int k,
                  double alpha, double *a, int lda, double *b, int ldb,
                  double beta, double *c, int ldc, double *d_aux, int i_naux)
{
    int extra_m, extra_n, extra_k;  /* m%2, n%2, k%2: leftover row/column */
    int m_even, n_even, k_even;     /* even-sized portion of each dimension */

    /* Leave quickly when no_recursion() says not to recurse: multiply the
     * matrices with the conventional matrix product instead. */
    if (no_recursion(m, n, k, alpha, beta)) {
        matrix_prod(c_transa, c_transb, m, n, k, alpha, a, lda, b, ldb,
                    beta, c, ldc);
        return;
    }

    /* Continue with the Strassen recursion.
     *
     * For every odd dimension, peel off the remainder so that the recursion
     * works on even dimensions; the contribution of the last row/column is
     * accumulated afterwards. */
    extra_m = m % 2;
    extra_n = n % 2;
    extra_k = k % 2;

    m_even = m - extra_m;
    n_even = n - extra_n;
    k_even = k - extra_k;

    /* Matrices A, B & C have the following decomposition:
     *
     *  |       |   |   |       |   |   |       |   |
     *  |  A11  |a12|   |  B11  |b12|   |  C11  |c12|
     *  |       |   | X |       |   | = |       |   |
     *  |_______|___|   |_______|___|   |_______|___|
     *  |  a21  |a22|   |  b21  |b22|   |  c21  |c22|
     *
     * where A11 (m_even x k_even), B11 (k_even x n_even) and
     * C11 (m_even x n_even) are even-sized matrices and, if they exist
     * (which depends on the values of m, n & k), a21, b21 & c21 are row
     * vectors, a12, b12 & c12 are column vectors, and a22, b22 & c22 are
     * single elements.
     */

    /* Solve the even-sized problem first:
     *
     *     C11 = alpha * (A11 * B11) + beta * C11
     */
    strassen_internal(c_transa, c_transb, m_even, n_even, k_even, alpha,
                      a, lda, b, ldb, beta, c, ldc, d_aux, i_naux);

    /* Take care of the remaining row/column, if any.  A dimension differs
     * from its even part exactly when its remainder is non-zero. */
    if (extra_m != 0 || extra_n != 0 || extra_k != 0) {
        fixup_internal(c_transa, c_transb, m, n, k, m_even, n_even, k_even,
                       extra_m, extra_n, extra_k, alpha, a, lda, b, ldb,
                       beta, c, ldc);
    }
}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?