matrix_mult.hpp
来自「矩阵运算源码最新版本」· HPP 代码 · 共 1,560 行 · 第 1/4 页
HPP
1,560 行
// Software License for MTL// // Copyright (c) 2007 The Trustees of Indiana University. All rights reserved.// Authors: Peter Gottschling and Andrew Lumsdaine// // This file is part of the Matrix Template Library// // See also license.mtl.txt in the distribution.// Written mostly by Michael Adams// Modified by Peter Gottschling#ifndef MTL_OPTERON_MATRIX_MULT_INCLUDE#define MTL_OPTERON_MATRIX_MULT_INCLUDE#if defined MTL_USE_OPTERON_OPTIMIZATION && defined __GNUC__ && !defined __INTEL_COMPILER#include <boost/numeric/mtl/operation/assign_mode.hpp>#include <boost/numeric/mtl/operation/set_to_zero.hpp>#include <boost/numeric/mtl/recursion/bit_masking.hpp>namespace mtl {namespace detail { template <unsigned long MaskA, unsigned long MaskB, unsigned long MaskC> struct opteron_shark_teeth { static const unsigned long base_case_bits= 5, tooth_length = 1; static const bool value= is_k_power_base_case_row_major_t_shark<base_case_bits, tooth_length, MaskA>::value && is_k_power_base_case_col_major_t_shark<base_case_bits, tooth_length, MaskB>::value && is_k_power_base_case_row_major_t_shark<base_case_bits, tooth_length, MaskC>::value; }; template <typename Assign, unsigned long MaskA, typename PA, unsigned long MaskB, typename PB, unsigned long MaskC, typename PC> inline void opteron_shark_teeth_mult(const morton_dense<double, MaskA, PA>& a, const morton_dense<double, MaskB, PB>& b, morton_dense<double, MaskC, PC>& c) { // Never sets the matrix to zero; this is supposed to be done before if necessary typedef typename morton_dense<double, MaskA, PA>::size_type size_type; size_type i_max= c.num_rows(), i_block= 2 * (i_max / 2), j_max= c.num_cols(), j_block= 2 * (j_max / 2), k_max= a.num_cols(); const int stride= 32; double *ap= &const_cast<morton_dense<double, MaskA, PA>&>(a)[0][0], *bp= &const_cast<morton_dense<double, MaskB, PB>&>(b)[0][0], *cp= &c[0][0]; // C_nw += A_n * B_n for (size_type i= 0; i < i_block; i+=2) for (int j = 0; j < j_block; j+=2) { double tmp00= 0.0, tmp01= 0.0, tmp10= 0.0, tmp11= 0.0; for (int k = 0; k < k_max; k++) { tmp00 += ap[0+(i)*stride+2*k] * bp[0+(j)*stride+2*k]; tmp01 += ap[0+(i)*stride+2*k] * bp[1+(j)*stride+2*k]; tmp10 += ap[1+(i)*stride+2*k] * bp[0+(j)*stride+2*k]; tmp11 += ap[1+(i)*stride+2*k] * bp[1+(j)*stride+2*k]; } Assign::update(cp[0+(i)*stride+2*(j+0)], tmp00); Assign::update(cp[0+(i)*stride+2*(j+1)], tmp01); Assign::update(cp[1+(i)*stride+2*(j+0)], tmp10); Assign::update(cp[1+(i)*stride+2*(j+1)], tmp11); } // C_ne += A_n * B_e for (size_type i= 0; i < i_block; i+=2) for (int j = j_block; j < j_max; j++) { double tmp00= 0.0, tmp10= 0.0; for (int k = 0; k < k_max; k++) { tmp00 += ap[0+(i)*stride+2*k] * bp[0+(j)*stride+2*k]; tmp10 += ap[1+(i)*stride+2*k] * bp[0+(j)*stride+2*k]; } Assign::update(cp[0+(i)*stride+2*(j+0)], tmp00); Assign::update(cp[1+(i)*stride+2*(j+0)], tmp10); } // C_s += A_s * B for (size_type i= i_block; i < i_max; i++) for (int j = 0; j < j_max; j++) { double tmp00= 0.0; for (int k = 0; k < k_max; k++) tmp00 += ap[0+(i)*stride+2*k] * bp[0+(j)*stride+2*k]; Assign::update(cp[0+(i)*stride+2*(j+0)], tmp00); } }} // namespace detail// for C= AB and C+= ABtemplate <unsigned long MaskA, typename PA, unsigned long MaskB, typename PB, unsigned long MaskC, typename PC, typename Assign, typename Backup>struct gen_platform_dmat_dmat_mult_ft<morton_dense<double, MaskA, PA>, morton_dense<double, MaskB, PB>, morton_dense<double, MaskC, PC>, Assign, Backup>{ void mult_ass(double * D, double * C, double * BT) const; void operator()(const morton_dense<double, MaskA, PA>& a, const morton_dense<double, MaskB, PB>& b, morton_dense<double, MaskC, PC>& c) const { // std::cout << "use Assembly\n"; if (detail::opteron_shark_teeth<MaskA, MaskB, MaskC>::value) { if (Assign::init_to_zero) set_to_zero(c); if (a.num_rows() == 32 && a.num_cols() == 32 && b.num_cols() == 32) { double *ap= const_cast<morton_dense<double, MaskA, PA>&>(a).elements(), *bp= const_cast<morton_dense<double, MaskB, PB>&>(b).elements(), cp= &c.elements(); mult_ass(cp, ap, bp); } else detail::opteron_shark_teeth_mult<Assign>(a, b, c); return; } Backup()(a, b, c); }};// for C-= ABtemplate <unsigned long MaskA, typename PA, unsigned long MaskB, typename PB, unsigned long MaskC, typename PC, typename Backup>struct gen_platform_dmat_dmat_mult_ft<morton_dense<double, MaskA, PA>, morton_dense<double, MaskB, PB>, morton_dense<double, MaskC, PC>, assign::minus_sum, Backup>{ void mult_ass(double * D, double * C, double * BT) const; void operator()(const morton_dense<double, MaskA, PA>& a, const morton_dense<double, MaskB, PB>& b, morton_dense<double, MaskC, PC>& c) const { // std::cout << "use Assembly\n"; if (detail::opteron_shark_teeth<MaskA, MaskB, MaskC>::value) { if (a.num_rows() == 32 && a.num_cols() == 32 && b.num_cols() == 32) { double ap= &const_cast<morton_dense<double, MaskA, PA>&>(a).elements(), bp= &const_cast<morton_dense<double, MaskB, PB>&>(b).elements(), cp= &c.elements(); mult_ass(cp, ap, bp); } else detail::opteron_shark_teeth_mult<assign::minus_sum>(a, b, c); return; } Backup()(a, b, c); }};template <unsigned long MaskA, typename PA, unsigned long MaskB, typename PB, unsigned long MaskC, typename PC, typename Assign, typename Backup>void gen_platform_dmat_dmat_mult_ft<morton_dense<double, MaskA, PA>, morton_dense<double, MaskB, PB>, morton_dense<double, MaskC, PC>, Assign, Backup>::mult_ass(double * D, double * C, double * BT) const{ // std::cout << "in Assembly\n"; const int baseOrder= 32, stride = baseOrder; /* double * restrict D = aa + ((d*baseSize)&(rowMask|colMask)); double * restrict C = aa + ((c*baseSize)&(rowMask|colMask)); double * restrict BT = aa + ((d*baseSize)&colMask)/2 + ((c*baseSize)&colMask); */ #if 0 for (int i = 0; i < baseOrder; i+=2) for (int j = 0; j < baseOrder; j+=2) for (int k = 0; k < baseOrder; k++) { D[0+(i)*stride+2*(j+0)] += C[0+(i)*stride+2*k] * BT[0+(j)*stride+2*k]; D[0+(i)*stride+2*(j+1)] += C[0+(i)*stride+2*k] * BT[1+(j)*stride+2*k]; D[1+(i)*stride+2*(j+0)] += C[1+(i)*stride+2*k] * BT[0+(j)*stride+2*k]; D[1+(i)*stride+2*(j+1)] += C[1+(i)*stride+2*k] * BT[1+(j)*stride+2*k]; } #endif #if 0 // Reorder loops (Ordering based on where the target code is going). for (int j = 0; j < baseOrder; j+=2) for (int i = 0; i < baseOrder; i+=16) { for (int k = 0; k < baseOrder; k++) { for (int i2 = i; i2 < i+16; i2+=2) { D[0+(i2)*stride+2*(j+0)] += C[0+(i2)*stride+2*k] * BT[0+(j)*stride+2*k]; D[1+(i2)*stride+2*(j+0)] += C[1+(i2)*stride+2*k] * BT[0+(j)*stride+2*k]; } } for (int k = 0; k < baseOrder; k++) { for (int i2 = i; i2 < i+16; i2+=2) { D[0+(i2)*stride+2*(j+1)] += C[0+(i2)*stride+2*k] * BT[1+(j)*stride+2*k]; D[1+(i2)*stride+2*(j+1)] += C[1+(i2)*stride+2*k] * BT[1+(j)*stride+2*k]; } } } #endif #if 0 // Unroll i2 for (int j = 0; j < baseOrder; j+=2) for (int i = 0; i < baseOrder; i+=16) { for (int k = 0; k < baseOrder; k++) { D[0+(i+ 0)*stride+2*(j+0)]+=C[0+(i+ 0)*stride+2*k]*BT[0+j*stride+2*k]; D[1+(i+ 0)*stride+2*(j+0)]+=C[1+(i+ 0)*stride+2*k]*BT[0+j*stride+2*k]; D[0+(i+ 2)*stride+2*(j+0)]+=C[0+(i+ 2)*stride+2*k]*BT[0+j*stride+2*k]; D[1+(i+ 2)*stride+2*(j+0)]+=C[1+(i+ 2)*stride+2*k]*BT[0+j*stride+2*k]; D[0+(i+ 4)*stride+2*(j+0)]+=C[0+(i+ 4)*stride+2*k]*BT[0+j*stride+2*k]; D[1+(i+ 4)*stride+2*(j+0)]+=C[1+(i+ 4)*stride+2*k]*BT[0+j*stride+2*k]; D[0+(i+ 6)*stride+2*(j+0)]+=C[0+(i+ 6)*stride+2*k]*BT[0+j*stride+2*k]; D[1+(i+ 6)*stride+2*(j+0)]+=C[1+(i+ 6)*stride+2*k]*BT[0+j*stride+2*k]; D[0+(i+ 8)*stride+2*(j+0)]+=C[0+(i+ 8)*stride+2*k]*BT[0+j*stride+2*k]; D[1+(i+ 8)*stride+2*(j+0)]+=C[1+(i+ 8)*stride+2*k]*BT[0+j*stride+2*k]; D[0+(i+10)*stride+2*(j+0)]+=C[0+(i+10)*stride+2*k]*BT[0+j*stride+2*k]; D[1+(i+10)*stride+2*(j+0)]+=C[1+(i+10)*stride+2*k]*BT[0+j*stride+2*k]; D[0+(i+12)*stride+2*(j+0)]+=C[0+(i+12)*stride+2*k]*BT[0+j*stride+2*k]; D[1+(i+12)*stride+2*(j+0)]+=C[1+(i+12)*stride+2*k]*BT[0+j*stride+2*k]; D[0+(i+14)*stride+2*(j+0)]+=C[0+(i+14)*stride+2*k]*BT[0+j*stride+2*k]; D[1+(i+14)*stride+2*(j+0)]+=C[1+(i+14)*stride+2*k]*BT[0+j*stride+2*k]; } for (int k = 0; k < baseOrder; k++) { D[0+(i+ 0)*stride+2*(j+1)]+=C[0+(i+ 0)*stride+2*k]*BT[1+j*stride+2*k]; D[1+(i+ 0)*stride+2*(j+1)]+=C[1+(i+ 0)*stride+2*k]*BT[1+j*stride+2*k]; D[0+(i+ 2)*stride+2*(j+1)]+=C[0+(i+ 2)*stride+2*k]*BT[1+j*stride+2*k]; D[1+(i+ 2)*stride+2*(j+1)]+=C[1+(i+ 2)*stride+2*k]*BT[1+j*stride+2*k]; D[0+(i+ 4)*stride+2*(j+1)]+=C[0+(i+ 4)*stride+2*k]*BT[1+j*stride+2*k]; D[1+(i+ 4)*stride+2*(j+1)]+=C[1+(i+ 4)*stride+2*k]*BT[1+j*stride+2*k]; D[0+(i+ 6)*stride+2*(j+1)]+=C[0+(i+ 6)*stride+2*k]*BT[1+j*stride+2*k]; D[1+(i+ 6)*stride+2*(j+1)]+=C[1+(i+ 6)*stride+2*k]*BT[1+j*stride+2*k]; D[0+(i+ 8)*stride+2*(j+1)]+=C[0+(i+ 8)*stride+2*k]*BT[1+j*stride+2*k]; D[1+(i+ 8)*stride+2*(j+1)]+=C[1+(i+ 8)*stride+2*k]*BT[1+j*stride+2*k]; D[0+(i+10)*stride+2*(j+1)]+=C[0+(i+10)*stride+2*k]*BT[1+j*stride+2*k]; D[1+(i+10)*stride+2*(j+1)]+=C[1+(i+10)*stride+2*k]*BT[1+j*stride+2*k]; D[0+(i+12)*stride+2*(j+1)]+=C[0+(i+12)*stride+2*k]*BT[1+j*stride+2*k]; D[1+(i+12)*stride+2*(j+1)]+=C[1+(i+12)*stride+2*k]*BT[1+j*stride+2*k]; D[0+(i+14)*stride+2*(j+1)]+=C[0+(i+14)*stride+2*k]*BT[1+j*stride+2*k]; D[1+(i+14)*stride+2*(j+1)]+=C[1+(i+14)*stride+2*k]*BT[1+j*stride+2*k]; } } #endif #if 0 // Prep k for (int j = 0; j < baseOrder; j+=2) for (int i = 0; i < baseOrder; i+=16) { { double d00 = D[0+(i+ 0)*stride+2*(j+0)]; double d01 = D[1+(i+ 0)*stride+2*(j+0)]; double d02 = D[0+(i+ 2)*stride+2*(j+0)]; double d03 = D[1+(i+ 2)*stride+2*(j+0)]; double d04 = D[0+(i+ 4)*stride+2*(j+0)]; double d05 = D[1+(i+ 4)*stride+2*(j+0)]; double d06 = D[0+(i+ 6)*stride+2*(j+0)]; double d07 = D[1+(i+ 6)*stride+2*(j+0)]; double d08 = D[0+(i+ 8)*stride+2*(j+0)]; double d09 = D[1+(i+ 8)*stride+2*(j+0)]; double d10 = D[0+(i+10)*stride+2*(j+0)]; double d11 = D[1+(i+10)*stride+2*(j+0)]; double d12 = D[0+(i+12)*stride+2*(j+0)]; double d13 = D[1+(i+12)*stride+2*(j+0)]; double d14 = D[0+(i+14)*stride+2*(j+0)]; double d15 = D[1+(i+14)*stride+2*(j+0)]; for (int k = 0; k < baseOrder; k++) { d00+=C[0+(i+ 0)*stride+2*k]*BT[0+j*stride+2*k]; d01+=C[1+(i+ 0)*stride+2*k]*BT[0+j*stride+2*k]; d02+=C[0+(i+ 2)*stride+2*k]*BT[0+j*stride+2*k]; d03+=C[1+(i+ 2)*stride+2*k]*BT[0+j*stride+2*k]; d04+=C[0+(i+ 4)*stride+2*k]*BT[0+j*stride+2*k]; d05+=C[1+(i+ 4)*stride+2*k]*BT[0+j*stride+2*k]; d06+=C[0+(i+ 6)*stride+2*k]*BT[0+j*stride+2*k]; d07+=C[1+(i+ 6)*stride+2*k]*BT[0+j*stride+2*k]; d08+=C[0+(i+ 8)*stride+2*k]*BT[0+j*stride+2*k]; d09+=C[1+(i+ 8)*stride+2*k]*BT[0+j*stride+2*k]; d10+=C[0+(i+10)*stride+2*k]*BT[0+j*stride+2*k]; d11+=C[1+(i+10)*stride+2*k]*BT[0+j*stride+2*k]; d12+=C[0+(i+12)*stride+2*k]*BT[0+j*stride+2*k]; d13+=C[1+(i+12)*stride+2*k]*BT[0+j*stride+2*k]; d14+=C[0+(i+14)*stride+2*k]*BT[0+j*stride+2*k]; d15+=C[1+(i+14)*stride+2*k]*BT[0+j*stride+2*k]; } D[0+(i+ 0)*stride+2*(j+0)] = d00; D[1+(i+ 0)*stride+2*(j+0)] = d01; D[0+(i+ 2)*stride+2*(j+0)] = d02; D[1+(i+ 2)*stride+2*(j+0)] = d03; D[0+(i+ 4)*stride+2*(j+0)] = d04; D[1+(i+ 4)*stride+2*(j+0)] = d05; D[0+(i+ 6)*stride+2*(j+0)] = d06; D[1+(i+ 6)*stride+2*(j+0)] = d07; D[0+(i+ 8)*stride+2*(j+0)] = d08; D[1+(i+ 8)*stride+2*(j+0)] = d09; D[0+(i+10)*stride+2*(j+0)] = d10; D[1+(i+10)*stride+2*(j+0)] = d11; D[0+(i+12)*stride+2*(j+0)] = d12; D[1+(i+12)*stride+2*(j+0)] = d13; D[0+(i+14)*stride+2*(j+0)] = d14; D[1+(i+14)*stride+2*(j+0)] = d15; } { double d00 = D[0+(i+ 0)*stride+2*(j+1)]; double d01 = D[1+(i+ 0)*stride+2*(j+1)]; double d02 = D[0+(i+ 2)*stride+2*(j+1)]; double d03 = D[1+(i+ 2)*stride+2*(j+1)]; double d04 = D[0+(i+ 4)*stride+2*(j+1)]; double d05 = D[1+(i+ 4)*stride+2*(j+1)]; double d06 = D[0+(i+ 6)*stride+2*(j+1)]; double d07 = D[1+(i+ 6)*stride+2*(j+1)]; double d08 = D[0+(i+ 8)*stride+2*(j+1)]; double d09 = D[1+(i+ 8)*stride+2*(j+1)]; double d10 = D[0+(i+10)*stride+2*(j+1)]; double d11 = D[1+(i+10)*stride+2*(j+1)]; double d12 = D[0+(i+12)*stride+2*(j+1)]; double d13 = D[1+(i+12)*stride+2*(j+1)]; double d14 = D[0+(i+14)*stride+2*(j+1)]; double d15 = D[1+(i+14)*stride+2*(j+1)]; for (int k = 0; k < baseOrder; k++) { d00+=C[0+(i+ 0)*stride+2*k]*BT[1+j*stride+2*k]; d01+=C[1+(i+ 0)*stride+2*k]*BT[1+j*stride+2*k]; d02+=C[0+(i+ 2)*stride+2*k]*BT[1+j*stride+2*k]; d03+=C[1+(i+ 2)*stride+2*k]*BT[1+j*stride+2*k]; d04+=C[0+(i+ 4)*stride+2*k]*BT[1+j*stride+2*k]; d05+=C[1+(i+ 4)*stride+2*k]*BT[1+j*stride+2*k]; d06+=C[0+(i+ 6)*stride+2*k]*BT[1+j*stride+2*k]; d07+=C[1+(i+ 6)*stride+2*k]*BT[1+j*stride+2*k]; d08+=C[0+(i+ 8)*stride+2*k]*BT[1+j*stride+2*k]; d09+=C[1+(i+ 8)*stride+2*k]*BT[1+j*stride+2*k]; d10+=C[0+(i+10)*stride+2*k]*BT[1+j*stride+2*k]; d11+=C[1+(i+10)*stride+2*k]*BT[1+j*stride+2*k]; d12+=C[0+(i+12)*stride+2*k]*BT[1+j*stride+2*k]; d13+=C[1+(i+12)*stride+2*k]*BT[1+j*stride+2*k]; d14+=C[0+(i+14)*stride+2*k]*BT[1+j*stride+2*k]; d15+=C[1+(i+14)*stride+2*k]*BT[1+j*stride+2*k]; } D[0+(i+ 0)*stride+2*(j+1)] = d00; D[1+(i+ 0)*stride+2*(j+1)] = d01; D[0+(i+ 2)*stride+2*(j+1)] = d02; D[1+(i+ 2)*stride+2*(j+1)] = d03; D[0+(i+ 4)*stride+2*(j+1)] = d04; D[1+(i+ 4)*stride+2*(j+1)] = d05; D[0+(i+ 6)*stride+2*(j+1)] = d06; D[1+(i+ 6)*stride+2*(j+1)] = d07; D[0+(i+ 8)*stride+2*(j+1)] = d08; D[1+(i+ 8)*stride+2*(j+1)] = d09; D[0+(i+10)*stride+2*(j+1)] = d10; D[1+(i+10)*stride+2*(j+1)] = d11; D[0+(i+12)*stride+2*(j+1)] = d12; D[1+(i+12)*stride+2*(j+1)] = d13; D[0+(i+14)*stride+2*(j+1)] = d14; D[1+(i+14)*stride+2*(j+1)] = d15; } } #endif #if 0 // Begin SSE for (int j = 0; j < baseOrder; j+=2) for (int i = 0; i < baseOrder; i+=16) { { __m128d d00 = _mm_load_pd(&D[0+(i+ 0)*stride+2*(j+0)]); __m128d d02 = _mm_load_pd(&D[0+(i+ 2)*stride+2*(j+0)]); __m128d d04 = _mm_load_pd(&D[0+(i+ 4)*stride+2*(j+0)]); __m128d d06 = _mm_load_pd(&D[0+(i+ 6)*stride+2*(j+0)]); __m128d d08 = _mm_load_pd(&D[0+(i+ 8)*stride+2*(j+0)]); __m128d d10 = _mm_load_pd(&D[0+(i+10)*stride+2*(j+0)]); __m128d d12 = _mm_load_pd(&D[0+(i+12)*stride+2*(j+0)]); __m128d d14 = _mm_load_pd(&D[0+(i+14)*stride+2*(j+0)]); for (int k = 0; k < baseOrder; k++) { __m128d bt0 = _mm_load1_pd(&BT[0+j*stride+2*k]); d00+=_mm_load_pd(&C[0+(i+ 0)*stride+2*k])*bt0;
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?