📄 strassen_internal.c
字号:
/*============================================================================
 Internal Strassen (fmm) routine used in the recursion. It only works on
 even-sized matrices.

 Performs one level of Winograd's variant of Strassen's algorithm, handing
 each of the seven half-size products to fmm_internal, or multiplies the
 matrices conventionally with matrix_prod when no_recursion() reports the
 problem is too small to split.

 Inputs
   c_transa,c_transb : characters specifying form of A or B. 'T', 't', 'C'
                       and 'c' select the transpose ('C' equals 'T' for
                       real data); any other character means no transpose.
   m,n,k             : matrix dimensions, all even
   alpha,beta        : scalars
   A                 : op(A) is m x k (A is stored k x m when transposed)
   B                 : op(B) is k x n (B is stored n x k when transposed)
   lda,ldb,ldc       : matrix leading dimensions
   d_aux             : temporary work space
   i_naux            : amount of temporary work space
 Outputs
   C                 : m x n matrix, C = alpha*op(A)*op(B) + beta*C
============================================================================*/
#include "matrix.h"

/* Returns nonzero when a BLAS-style operation character requests a
 * transpose. Conjugate transpose ('C'/'c') is the same as 'T' for doubles. */
static int strassen_is_transposed(char c_trans)
{
    return c_trans == 'T' || c_trans == 't' ||
           c_trans == 'C' || c_trans == 'c';
}

void strassen_internal(char c_transa, char c_transb, int m, int n, int k,
                       double alpha, double *a, int lda, double *b, int ldb,
                       double beta, double *c, int ldc, double *d_aux,
                       int i_naux)
{
    double *a11, *a12, *a21, *a22,   /* Quadrants of A, B & C */
           *b11, *b12, *b21, *b22,
           *c11, *c12, *c21, *c22,
           *r1, *r2, *r3;            /* Temporary matrices */
    int ldr1, ldr2, ldr3,            /* Leading dimensions of temporaries */
        m_half, n_half, k_half,      /* Half of matrix dimensions */
        rows_a, cols_a,              /* Rows/cols of half-sized matrix A */
        rows_b, cols_b;              /* Rows/cols of half-sized matrix B */
    int trans_a, trans_b;            /* Nonzero when op(X) = X^T */

    /* Decide the transpose form once and forward a canonical character, so
     * the quadrant layout chosen here always agrees with what matrix_prod
     * and the recursive fmm_internal calls are told (a lowercase 't' would
     * otherwise be partitioned as 'N' here but possibly multiplied as 'T'). */
    trans_a = strassen_is_transposed(c_transa);
    trans_b = strassen_is_transposed(c_transb);
    c_transa = trans_a ? 'T' : 'N';
    c_transb = trans_b ? 'T' : 'N';

    /* Exit this routine quickly if matrix dimension(s) too small */
    if (no_recursion(m, n, k, alpha, beta)) {
        /* Multiply small matrices using conventional matrix multiply */
        matrix_prod(c_transa, c_transb, m, n, k, alpha, a, lda, b, ldb,
                    beta, c, ldc);
        return;
    }

    /* Recursion */
    m_half = m / 2;
    n_half = n / 2;
    k_half = k / 2;

    /* Set pointers to the 4 quadrants of op(A), op(B) and C. When an operand
     * is transposed, quadrant (i,j) of op(X) is stored at quadrant (j,i) of
     * X, so the off-diagonal offsets are swapped. */
    submatrix(a, lda, 1, 1, &a11);
    if (trans_a) {
        submatrix(a, lda, 1, m_half + 1, &a21);
        submatrix(a, lda, k_half + 1, 1, &a12);
        submatrix(a, lda, k_half + 1, m_half + 1, &a22);
    } else {
        submatrix(a, lda, 1, k_half + 1, &a12);
        submatrix(a, lda, m_half + 1, 1, &a21);
        submatrix(a, lda, m_half + 1, k_half + 1, &a22);
    }

    submatrix(b, ldb, 1, 1, &b11);
    if (trans_b) {
        submatrix(b, ldb, 1, k_half + 1, &b21);
        submatrix(b, ldb, n_half + 1, 1, &b12);
        submatrix(b, ldb, n_half + 1, k_half + 1, &b22);
    } else {
        submatrix(b, ldb, 1, n_half + 1, &b12);
        submatrix(b, ldb, k_half + 1, 1, &b21);
        submatrix(b, ldb, k_half + 1, n_half + 1, &b22);
    }

    submatrix(c, ldc, 1, 1, &c11);
    submatrix(c, ldc, 1, n_half + 1, &c12);
    submatrix(c, ldc, m_half + 1, 1, &c21);
    submatrix(c, ldc, m_half + 1, n_half + 1, &c22);

    /* This is Winograd's variant of Strassen's algorithm. The sequence of
     * operations has been altered to reduce the amount of temporary storage.
     * In terms of quadrants of op(A) and op(B) the intermediate terms are
     *   S1 = A21 + A22   T1 = B12 - B11   M1 = A11*B11   M5 = S1*T1
     *   S2 = S1 - A11    T2 = B22 - T1    M2 = A12*B21   M6 = S2*T2
     *   S3 = A11 - A21   T3 = B22 - B12   M3 = S4*B22    M7 = S3*T3
     *   S4 = A12 - S2    T4 = B21 - T2    M4 = A22*T4
     * and
     *   C11 = M1 + M2             C12 = M1 + M6 + M5 + M3
     *   C21 = M1 + M6 + M7 + M4   C22 = M1 + M6 + M7 + M5
     * (each scaled by alpha, plus beta times the old C quadrant).
     *
     * Transpose cases are handled automatically by the multiply routine
     * (fmm), since we pass the appropriate operation via the character
     * parameter. Add/subtract routines do not use such a parameter, so the
     * correct row/column orientation has to be preserved in the call. */
    /* NOTE(review): i_naux is forwarded unchanged to every fmm_internal call
     * even though d_aux is advanced past the temporaries below; confirm that
     * fmm_internal does not treat it as the remaining capacity of d_aux. */

    if (beta == 0.0) {
        /* Since beta = 0.0, initial values in C are ignored, and we can use
         * the 4 quadrants of C as temporary storage. Only 2 additional
         * temporary arrays are needed.
         *
         * Partition the work space d_aux into R1 and R2:
         *   R1 holds m_half*MAX(k_half,n_half) doubles. It is used both as a
         *      rows_a x cols_a copy of an A quadrant (leading dimension ldr1)
         *      and as an m_half x n_half product (leading dimension ldr3).
         *   R2 is k_half x n_half, or n_half x k_half if c_transb = T.
         * Afterwards, d_aux points to the next unused memory location. */
        r1 = d_aux;
        r2 = r1 + m_half * MAX(k_half, n_half);
        d_aux = r2 + k_half * n_half;

        if (trans_a) {
            ldr1 = k_half;   /* R1 is used as a work array in 2 different */
            ldr3 = m_half;   /* places, so 2 leading dimensions are needed */
            rows_a = k_half;
            cols_a = m_half;
        } else {
            ldr1 = ldr3 = m_half;
            rows_a = m_half;
            cols_a = k_half;
        }
        if (trans_b) {
            ldr2 = n_half;
            rows_b = n_half;
            cols_b = k_half;
        } else {
            ldr2 = k_half;
            rows_b = k_half;
            cols_b = n_half;
        }

        /* R1 = alpha * (a11 - a21)                       = alpha*S3 */
        matrix_sub(rows_a, cols_a, alpha, a11, lda, a21, lda, r1, ldr1);
        /* R2 = b22 - b12                                 = T3 */
        matrix_sub(rows_b, cols_b, 1.0, b22, ldb, b12, ldb, r2, ldr2);
        /* c11 = R1 * R2                                  = alpha*M7 */
        fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0,
                     r1, ldr1, r2, ldr2, 0.0, c11, ldc, d_aux, i_naux);

        /* R1 = alpha * (a21 + a22)                       = alpha*S1 */
        matrix_add(rows_a, cols_a, alpha, a21, lda, a22, lda, r1, ldr1);
        /* R2 = b12 - b11                                 = T1 */
        matrix_sub(rows_b, cols_b, 1.0, b12, ldb, b11, ldb, r2, ldr2);
        /* c22 = R1 * R2                                  = alpha*M5 */
        fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0,
                     r1, ldr1, r2, ldr2, 0.0, c22, ldc, d_aux, i_naux);

        /* R1 = R1 - (alpha * a11)                        = alpha*S2 */
        matrix_lin_update(rows_a, cols_a, -alpha, a11, lda, 1.0, r1, ldr1);
        /* R2 = b22 - R2                                  = T2 */
        matrix_sub(rows_b, cols_b, 1.0, b22, ldb, r2, ldr2, r2, ldr2);
        /* c21 = R1 * R2                                  = alpha*M6 */
        fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0,
                     r1, ldr1, r2, ldr2, 0.0, c21, ldc, d_aux, i_naux);

        /* R1 = alpha * a12 - R1                          = alpha*S4 */
        matrix_lin_update(rows_a, cols_a, alpha, a12, lda, -1.0, r1, ldr1);
        /* R2 = b21 - R2                                  = T4 */
        matrix_sub(rows_b, cols_b, 1.0, b21, ldb, r2, ldr2, r2, ldr2);
        /* c12 = R1 * b22                                 = alpha*M3 */
        fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0,
                     r1, ldr1, b22, ldb, 0.0, c12, ldc, d_aux, i_naux);
        /* c12 = c12 + c22                                = alpha*(M3+M5) */
        matrix_add(m_half, n_half, 1.0, c22, ldc, c12, ldc, c12, ldc);

        /* R1 = alpha * a11 * b11                         = alpha*M1 */
        fmm_internal(c_transa, c_transb, m_half, n_half, k_half, alpha,
                     a11, lda, b11, ldb, 0.0, r1, ldr3, d_aux, i_naux);
        /* c21 = c21 + R1                                 = alpha*(M1+M6) */
        matrix_add(m_half, n_half, 1.0, r1, ldr3, c21, ldc, c21, ldc);
        /* c12 = c12 + c21                                -> final C12 */
        matrix_add(m_half, n_half, 1.0, c21, ldc, c12, ldc, c12, ldc);
        /* c21 = c21 + c11                          = alpha*(M1+M6+M7) */
        matrix_add(m_half, n_half, 1.0, c11, ldc, c21, ldc, c21, ldc);
        /* c22 = c22 + c21                                -> final C22 */
        matrix_add(m_half, n_half, 1.0, c21, ldc, c22, ldc, c22, ldc);

        /* c11 = alpha * a12 * b21                        = alpha*M2 */
        fmm_internal(c_transa, c_transb, m_half, n_half, k_half, alpha,
                     a12, lda, b21, ldb, 0.0, c11, ldc, d_aux, i_naux);
        /* c11 = c11 + R1                                 -> final C11 */
        matrix_add(m_half, n_half, 1.0, r1, ldr3, c11, ldc, c11, ldc);

        /* R1 = alpha * a22 * R2                          = alpha*M4 */
        fmm_internal(c_transa, c_transb, m_half, n_half, k_half, alpha,
                     a22, lda, r2, ldr2, 0.0, r1, ldr3, d_aux, i_naux);
        /* c21 = c21 + R1                                 -> final C21 */
        matrix_add(m_half, n_half, 1.0, r1, ldr3, c21, ldc, c21, ldc);
    } else {
        /* Since beta != 0.0, initial values in C cannot be ignored. Three
         * additional temporary arrays are needed.
         *
         * Partition the work space d_aux into R1, R2 and R3:
         *   R1 is m_half x k_half, or k_half x m_half if c_transa = T,
         *   R2 is k_half x n_half, or n_half x k_half if c_transb = T,
         *   R3 is m_half x n_half.
         * Afterwards, d_aux points to the next unused memory location. */
        r1 = d_aux;
        if (trans_a) {
            ldr1 = k_half;
            r2 = r1 + k_half * m_half;
            rows_a = k_half;
            cols_a = m_half;
        } else {
            ldr1 = m_half;
            r2 = r1 + m_half * k_half;
            rows_a = m_half;
            cols_a = k_half;
        }
        if (trans_b) {
            ldr2 = n_half;
            r3 = r2 + n_half * k_half;
            rows_b = n_half;
            cols_b = k_half;
        } else {
            ldr2 = k_half;
            r3 = r2 + k_half * n_half;
            rows_b = k_half;
            cols_b = n_half;
        }
        ldr3 = m_half;
        d_aux = r3 + m_half * n_half;

        /* R1 = alpha * (a21 + a22)                       = alpha*S1 */
        matrix_add(rows_a, cols_a, alpha, a21, lda, a22, lda, r1, ldr1);
        /* R2 = b12 - b11                                 = T1 */
        matrix_sub(rows_b, cols_b, 1.0, b12, ldb, b11, ldb, r2, ldr2);
        /* R3 = R1 * R2                                   = alpha*M5 */
        fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0,
                     r1, ldr1, r2, ldr2, 0.0, r3, ldr3, d_aux, i_naux);
        /* c12 = R3 + beta * c12 */
        matrix_lin_update(m_half, n_half, 1.0, r3, ldr3, beta, c12, ldc);
        /* c22 = R3 + beta * c22 */
        matrix_lin_update(m_half, n_half, 1.0, r3, ldr3, beta, c22, ldc);

        /* R1 = R1 - (alpha * a11)                        = alpha*S2 */
        matrix_lin_update(rows_a, cols_a, -alpha, a11, lda, 1.0, r1, ldr1);
        /* R2 = b22 - R2                                  = T2 */
        matrix_sub(rows_b, cols_b, 1.0, b22, ldb, r2, ldr2, r2, ldr2);

        /* R3 = alpha * a11 * b11                         = alpha*M1 */
        fmm_internal(c_transa, c_transb, m_half, n_half, k_half, alpha,
                     a11, lda, b11, ldb, 0.0, r3, ldr3, d_aux, i_naux);
        /* c11 = R3 + beta * c11 */
        matrix_lin_update(m_half, n_half, 1.0, r3, ldr3, beta, c11, ldc);
        /* R3 = R3 + (R1 * R2)                            = alpha*(M1+M6) */
        fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0,
                     r1, ldr1, r2, ldr2, 1.0, r3, ldr3, d_aux, i_naux);

        /* c11 = alpha * (a12 * b21) + c11                -> final C11 */
        fmm_internal(c_transa, c_transb, m_half, n_half, k_half, alpha,
                     a12, lda, b21, ldb, 1.0, c11, ldc, d_aux, i_naux);

        /* R1 = alpha * a12 - R1                          = alpha*S4 */
        matrix_lin_update(rows_a, cols_a, alpha, a12, lda, -1.0, r1, ldr1);
        /* R2 = alpha * (b21 - R2)                        = alpha*T4 */
        matrix_sub(rows_b, cols_b, alpha, b21, ldb, r2, ldr2, r2, ldr2);
        /* c12 = c12 + (R1 * b22)                         adds alpha*M3 */
        fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0,
                     r1, ldr1, b22, ldb, 1.0, c12, ldc, d_aux, i_naux);
        /* c12 = c12 + R3                                 -> final C12 */
        matrix_add(m_half, n_half, 1.0, r3, ldr3, c12, ldc, c12, ldc);

        /* c21 = a22 * R2 + beta * c21                    = alpha*M4 + beta*c21 */
        fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0,
                     a22, lda, r2, ldr2, beta, c21, ldc, d_aux, i_naux);

        /* R1 = alpha * (a11 - a21)                       = alpha*S3 */
        matrix_sub(rows_a, cols_a, alpha, a11, lda, a21, lda, r1, ldr1);
        /* R2 = b22 - b12                                 = T3 */
        matrix_sub(rows_b, cols_b, 1.0, b22, ldb, b12, ldb, r2, ldr2);
        /* R3 = R3 + (R1 * R2)                      = alpha*(M1+M6+M7) */
        fmm_internal(c_transa, c_transb, m_half, n_half, k_half, 1.0,
                     r1, ldr1, r2, ldr2, 1.0, r3, ldr3, d_aux, i_naux);
        /* c21 = c21 + R3                                 -> final C21 */
        matrix_add(m_half, n_half, 1.0, r3, ldr3, c21, ldc, c21, ldc);
        /* c22 = c22 + R3                                 -> final C22 */
        matrix_add(m_half, n_half, 1.0, r3, ldr3, c22, ldc, c22, ldc);
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -