📄 fmm.c
字号:
/***********************************************************************
 * COPYRIGHT 1996 United States Government. All rights reserved.
 *
 * THIS PROGRAM WAS IN PART PREPARED AS AN ACCOUNT OF WORK SPONSORED BY
 * AN AGENCY OF THE UNITED STATES GOVERNMENT. NEITHER THE UNITED STATES
 * GOVERNMENT NOR ANY AGENCY THEREOF, NOR ANY OF ITS CONTRACTORS OR
 * EMPLOYEES, MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR ASSUMES ANY
 * LEGAL LIABILITY OR RESPONSIBILITY FOR THE ACCURACY, COMPLETENESS, OR
 * USEFULNESS OF ANY INFORMATION, APPARATUS, PRODUCT, OR PROCESS
 * DISCLOSED, OR REPRESENTS THAT ITS USE WOULD NOT INFRINGE PRIVATELY
 * OWNED RIGHTS.
 *
 * THE GOVERNMENT IS GRANTED FOR ITSELF AND OTHERS ACTING ON ITS BEHALF A
 * PAID-UP, NONEXCLUSIVE, IRREVOCABLE WORLDWIDE LICENSE IN THIS COMPUTER
 * SOFTWARE TO REPRODUCE, PREPARE DERIVATIVE WORKS, AND PERFORM PUBLICLY
 * AND DISPLAY PUBLICLY.
 *
 * You may copy and distribute verbatim copies of the program's source
 * code as you receive it, in any medium, provided that you conspicuously
 * and appropriately publish on each copy the above disclaimer of
 * warranty and notice of Government license.
 *
 * You may modify your copy or copies of the program or any portion of
 * it, and copy and distribute such modifications providing you cause the
 * modified files to carry prominent notices stating that you changed the
 * files and the date of any change.
 *
 * Please feel free to send questions, comments, and problem reports
 * to prism@super.org.
 ************************************************************************/

/*
 * MODIFICATION NOTICE (required by the license terms above):
 * This file has been changed from the original distribution.
 *   - Argument validation in fmm() now runs before the small-matrix fast
 *     path, so both paths reject the same invalid arguments.
 *   - The error-message buffer was enlarged so no message can overflow it.
 *   - Argument normalisation/validation moved into static helpers.
 *   - Comment typos and the C11 update formula in comments corrected.
 * Date of change: TODO(review) record the date this change is applied.
 */

/*============================================================================
Fast matrix multiply. Main driver program. C interface.

All arrays are column oriented storage (Fortran style).
The extra memory needed by the routine is provided by the user, or allocated
here if d_aux is NULL or i_naux is smaller than required.

Inputs
  c_transa_in,c_transb_in : char specifying form of A or B (transpose or not);
                            'T'/'t' or 'N'/'n'
  m,n,k       : matrix dimensions
  alpha,beta  : scalars
  A           : op(A) is m x k (A is stored k x m when transposed)
  B           : op(B) is k x n (B is stored n x k when transposed)
  lda,ldb,ldc : matrix leading dimensions
  d_aux       : temporary work space (may be NULL)
  i_naux      : amount of temporary work space
Outputs
  C : m x n matrix, C = alpha*op(A)*op(B)+beta*C
============================================================================*/
#include "matrix.h"
#include <stdio.h>  /* sprintf */
#if STRAS_TIME_PARTS
#include "matrix_test.h"
#endif

/*
 * Largest message built below: the negative-dimension message is 93 fixed
 * characters plus three %d fields of up to 11 characters each ("-2147483648"),
 * i.e. 104 characters + NUL = 105 bytes.  The original 100-byte buffer could
 * therefore overflow; 128 leaves headroom for every message in this file.
 */
#define FMM_ERR_MSG_LEN 128

/* Map a lower-case 't' / 'n' transpose flag to upper case; any other
 * character is returned unchanged so that validation can reject it. */
static char fmm_upcase_trans(char c)
{
    if (c == 't')
        return 'T';
    if (c == 'n')
        return 'N';
    return c;
}

/*
 * Check the dimensions, transpose flags and leading dimensions of an fmm()
 * call, reporting every problem found through generror().
 * c_transa / c_transb must already have been passed through
 * fmm_upcase_trans().
 * NOTE(review): this helper does not stop after an error; it relies on
 * generror()'s own behaviour (presumably it aborts) -- confirm in matrix.h.
 */
static void fmm_check_args(char c_transa, char c_transb, int m, int n, int k,
                           int lda, int ldb, int ldc)
{
    char err_msg[FMM_ERR_MSG_LEN];

    /* Check matrix dimensions */
    if (m < 0 || n < 0 || k < 0) {
        sprintf(err_msg, "One or more of the matrix dimensions is less than zero: m = %d n = %d k = %d\n", m, n, k);
        generror(err_msg);
    }

    /* Only 'T' and 'N' are accepted after case normalisation */
    if ((c_transa != 'T') && (c_transa != 'N')) {
        sprintf(err_msg, "Transpose option on A (c_transa = %c) is not a T or N", c_transa);
        generror(err_msg);
    }
    if ((c_transb != 'T') && (c_transb != 'N')) {
        sprintf(err_msg, "Transpose option on B (c_transb = %c) is not a T or N", c_transb);
        generror(err_msg);
    }

    /* A is stored k x m when transposed, m x k otherwise */
    if (c_transa == 'T') {
        if (lda < k) {
            sprintf(err_msg, "Leading dimension of A (lda = %d) < rows given (k = %d)", lda, k);
            generror(err_msg);
        }
    } else if (lda < m) {
        sprintf(err_msg, "Leading dimension of A (lda = %d) < rows given (m = %d)", lda, m);
        generror(err_msg);
    }

    /* B is stored n x k when transposed, k x n otherwise */
    if (c_transb == 'T') {
        if (ldb < n) {
            sprintf(err_msg, "Leading dimension of B (ldb = %d) < rows given (n = %d)", ldb, n);
            generror(err_msg);
        }
    } else if (ldb < k) {
        sprintf(err_msg, "Leading dimension of B (ldb = %d) < rows given (k = %d)", ldb, k);
        generror(err_msg);
    }

    if (ldc < m) {
        sprintf(err_msg, "Leading dimension of C (ldc = %d) < rows given (m = %d)", ldc, m);
        generror(err_msg);
    }
}

/*
 * Compute C = alpha*op(A)*op(B) + beta*C (see the file header for the
 * argument descriptions).
 *
 * All arguments are validated first.  Problems whose dimensions satisfy
 * no_recursion() are then handed to the conventional matrix_prod().
 * Otherwise the largest even-sized sub-problem is solved with
 * fmm_internal(), and any odd leftover row/column is handled by
 * fixup_internal().
 */
void fmm(char c_transa_in, char c_transb_in, int m, int n, int k,
         double alpha, double *a, int lda, double *b, int ldb,
         double beta, double *c, int ldc, double *d_aux, int i_naux)
{
    int m_mod, n_mod, k_mod;  /* Even-sized portion of matrices */
    int r1, s1, t1;           /* Dimension of remaining portion */
    int free_flag = 0;        /* Set when the work space is allocated here */
    int i_mod;                /* Peeling modulus */
    int l_i_naux;             /* Work space size required, from tmp_fmm() */
    int ld_aux;               /* 4th output of matrix_alloc(), not used here;
                                 NOTE(review): presumably its leading
                                 dimension -- confirm */
    double *new_d_aux;        /* Work space actually passed to fmm_internal */
    char c_transa, c_transb;  /* Upper-cased transpose flags */

#if STRAS_LIB /* Needed for library version */
    cutoff = STRAS_CUTOFF;
    in_cutoff_m = -1;
#endif

    /* Take care of lower case character input for c_transa/c_transb */
    c_transa = fmm_upcase_trans(c_transa_in);
    c_transb = fmm_upcase_trans(c_transb_in);

    /* Validate before either code path.  The original code skipped the
     * transpose and leading-dimension checks on the small-matrix path.
     * Reference BLAS DGEMM likewise checks all arguments before any
     * quick return. */
    fmm_check_args(c_transa, c_transb, m, n, k, lda, ldb, ldc);

    /* Exit this routine quickly if matrix dimension(s) too small */
    if (no_recursion(m, n, k, alpha, beta)) {
        /* Multiply small matrices using conventional matrix multiply */
        matrix_prod(c_transa, c_transb, m, n, k, alpha, a, lda, b, ldb,
                    beta, c, ldc);
        return;
    }

    /* Check if user provided enough work space; if not, allocate
     * required amount now. */
    l_i_naux = tmp_fmm(m, n, k, alpha, beta);
    if (d_aux == NULL || i_naux < l_i_naux) {
        free_flag = 1;
        i_naux = l_i_naux;
        matrix_alloc(i_naux, 1, &new_d_aux, &ld_aux);
    } else {
        new_d_aux = d_aux;
    }

    /* The default mode is to have i_mod = 2 so you only peel one
     * row/col.  Perform Strassen recursion on remaining (even-sized)
     * matrix.  Do cleanup of left-over rows/cols. */
    i_mod = 2;
    r1 = m % i_mod;
    m_mod = m - r1;
    s1 = n % i_mod;
    n_mod = n - s1;
    t1 = k % i_mod;
    k_mod = k - t1;

    /* Matrices A, B & C have the following decomposition:
     *
     *  |       |   |     |       |   |     |       |   |
     *  |  A11  |A12|     |  B11  |B12|     |  C11  |C12|
     *  |       |   |  X  |       |   |  =  |       |   |
     *  |_______|___|     |_______|___|     |_______|___|
     *  |  A21  |A22|     |  B21  |B22|     |  C21  |C22|
     *
     * where A11(m_mod x k_mod), B11(k_mod x n_mod), and C11(m_mod x n_mod)
     * are even-sized matrices, and, if they exist, which depends on the
     * values of m, n, & k, the remaining matrices have the following
     * dimensions:
     *
     *   A12(m_mod x t1), A21(r1 x k_mod), A22(r1 x t1)
     *   B12(k_mod x s1), B21(t1 x n_mod), B22(t1 x s1)
     *   C12(m_mod x s1), C21(r1 x n_mod), C22(r1 x s1)
     */

    /* Solve even-sized problem first:
     *
     *   C11 = alpha * (A11 * B11) + beta * C11
     */
    fmm_internal(c_transa, c_transb, m_mod, n_mod, k_mod, alpha, a, lda,
                 b, ldb, beta, c, ldc, new_d_aux, i_naux);

    /* Free workspace, if allocated in this routine.  fixup_internal()
     * below does not take the work space, so it is safe to release here. */
    if (free_flag)
        free_matrix(new_d_aux);

    /* Take care of remaining row(s)/col(s), if any */
    if ((m_mod != m) || (n_mod != n) || (k_mod != k))
        fixup_internal(c_transa, c_transb, m, n, k, m_mod, n_mod, k_mod,
                       r1, s1, t1, alpha, a, lda, b, ldb, beta, c, ldc);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -