matrix_mult.hpp

来自「矩阵运算源码最新版本」· HPP 代码 · 共 1,560 行 · 第 1/4 页

HPP
1,560
字号
// Software License for MTL// // Copyright (c) 2007 The Trustees of Indiana University. All rights reserved.// Authors: Peter Gottschling and Andrew Lumsdaine// // This file is part of the Matrix Template Library// // See also license.mtl.txt in the distribution.// Written mostly by Michael Adams// Modified by Peter Gottschling#ifndef MTL_OPTERON_MATRIX_MULT_INCLUDE#define MTL_OPTERON_MATRIX_MULT_INCLUDE#if defined MTL_USE_OPTERON_OPTIMIZATION && defined __GNUC__ && !defined __INTEL_COMPILER#include <boost/numeric/mtl/operation/assign_mode.hpp>#include <boost/numeric/mtl/operation/set_to_zero.hpp>#include <boost/numeric/mtl/recursion/bit_masking.hpp>namespace mtl {namespace detail {    template <unsigned long MaskA, unsigned long MaskB, unsigned long MaskC>    struct opteron_shark_teeth    {	static const unsigned long base_case_bits= 5, tooth_length = 1;        static const bool value= is_k_power_base_case_row_major_t_shark<base_case_bits, tooth_length, MaskA>::value	                         && is_k_power_base_case_col_major_t_shark<base_case_bits, tooth_length, MaskB>::value	                         && is_k_power_base_case_row_major_t_shark<base_case_bits, tooth_length, MaskC>::value;    };    template <typename Assign, unsigned long MaskA, typename PA,	      unsigned long MaskB, typename PB,	      unsigned long MaskC, typename PC>    inline void     opteron_shark_teeth_mult(const morton_dense<double, MaskA, PA>& a, const morton_dense<double, MaskB, PB>& b, 			     morton_dense<double, MaskC, PC>& c)    {	// Never sets the matrix to zero; this is supposed to be done before if necessary	typedef typename morton_dense<double, MaskA, PA>::size_type  size_type;	size_type i_max= c.num_rows(), i_block= 2 * (i_max / 2),	          j_max= c.num_cols(), j_block= 2 * (j_max / 2),	          k_max= a.num_cols();	const int stride= 32;	double *ap= &const_cast<morton_dense<double, MaskA, PA>&>(a)[0][0],               *bp= &const_cast<morton_dense<double, MaskB, PB>&>(b)[0][0], *cp= &c[0][0];	// C_nw += A_n * B_n	for (size_type i= 0; i < i_block; i+=2)	    for (int j = 0; j < j_block; j+=2) {		double tmp00= 0.0, tmp01= 0.0, tmp10= 0.0, tmp11= 0.0;		for (int k = 0; k < k_max; k++) {		    tmp00 += ap[0+(i)*stride+2*k] * bp[0+(j)*stride+2*k];		    tmp01 += ap[0+(i)*stride+2*k] * bp[1+(j)*stride+2*k];		    tmp10 += ap[1+(i)*stride+2*k] * bp[0+(j)*stride+2*k];		    tmp11 += ap[1+(i)*stride+2*k] * bp[1+(j)*stride+2*k];		}		Assign::update(cp[0+(i)*stride+2*(j+0)], tmp00);		Assign::update(cp[0+(i)*stride+2*(j+1)], tmp01);		Assign::update(cp[1+(i)*stride+2*(j+0)], tmp10);		Assign::update(cp[1+(i)*stride+2*(j+1)], tmp11);	    }	// C_ne += A_n * B_e	for (size_type i= 0; i < i_block; i+=2)	    for (int j = j_block; j < j_max; j++) {		double tmp00= 0.0, tmp10= 0.0;		for (int k = 0; k < k_max; k++) {		    tmp00 += ap[0+(i)*stride+2*k] * bp[0+(j)*stride+2*k];		    tmp10 += ap[1+(i)*stride+2*k] * bp[0+(j)*stride+2*k];		}		Assign::update(cp[0+(i)*stride+2*(j+0)], tmp00);		Assign::update(cp[1+(i)*stride+2*(j+0)], tmp10);	    }	// C_s += A_s * B	for (size_type i= i_block; i < i_max; i++)	    for (int j = 0; j < j_max; j++) {		double tmp00= 0.0;		for (int k = 0; k < k_max; k++) 		    tmp00 += ap[0+(i)*stride+2*k] * bp[0+(j)*stride+2*k];		Assign::update(cp[0+(i)*stride+2*(j+0)], tmp00);	    }    }} // namespace detail// for C= AB and C+= ABtemplate <unsigned long MaskA, typename PA,	  unsigned long MaskB, typename PB,	  unsigned long MaskC, typename PC,	  typename Assign, typename Backup>struct gen_platform_dmat_dmat_mult_ft<morton_dense<double, MaskA, PA>, morton_dense<double, MaskB, PB>, 					 morton_dense<double, MaskC, PC>, Assign, Backup>{    void mult_ass(double * D, double * C, double * BT) const;    void operator()(const morton_dense<double, MaskA, PA>& a, const morton_dense<double, MaskB, PB>& b, 		    morton_dense<double, MaskC, PC>& c) const    {	// std::cout << "use Assembly\n";	if (detail::opteron_shark_teeth<MaskA, MaskB, MaskC>::value) {	    if (Assign::init_to_zero) 		set_to_zero(c);	    if (a.num_rows() == 32 && a.num_cols() == 32 && b.num_cols() == 32) {		double *ap= const_cast<morton_dense<double, MaskA, PA>&>(a).elements(),		       *bp= const_cast<morton_dense<double, MaskB, PB>&>(b).elements(), cp= &c.elements();		mult_ass(cp, ap, bp);	    } else 		detail::opteron_shark_teeth_mult<Assign>(a, b, c);	    return;	}	Backup()(a, b, c);    }};// for C-= ABtemplate <unsigned long MaskA, typename PA,	  unsigned long MaskB, typename PB,	  unsigned long MaskC, typename PC,	  typename Backup>struct gen_platform_dmat_dmat_mult_ft<morton_dense<double, MaskA, PA>, morton_dense<double, MaskB, PB>, 					 morton_dense<double, MaskC, PC>, assign::minus_sum, Backup>{    void mult_ass(double * D, double * C, double * BT) const;    void operator()(const morton_dense<double, MaskA, PA>& a, const morton_dense<double, MaskB, PB>& b, 		    morton_dense<double, MaskC, PC>& c) const    {	// std::cout << "use Assembly\n";	if (detail::opteron_shark_teeth<MaskA, MaskB, MaskC>::value) {	    if (a.num_rows() == 32 && a.num_cols() == 32 && b.num_cols() == 32) {		double ap= &const_cast<morton_dense<double, MaskA, PA>&>(a).elements(),		       bp= &const_cast<morton_dense<double, MaskB, PB>&>(b).elements(), cp= &c.elements();		mult_ass(cp, ap, bp);	    } else 		detail::opteron_shark_teeth_mult<assign::minus_sum>(a, b, c);	    return;	}	Backup()(a, b, c);    }};template <unsigned long MaskA, typename PA,	  unsigned long MaskB, typename PB,	  unsigned long MaskC, typename PC,	  typename Assign, typename Backup>void gen_platform_dmat_dmat_mult_ft<morton_dense<double, MaskA, PA>, morton_dense<double, MaskB, PB>, 				       morton_dense<double, MaskC, PC>, Assign, Backup>::mult_ass(double * D, double * C, double * BT) const{    // std::cout << "in Assembly\n";    const int baseOrder= 32,              stride = baseOrder;     /*    double * restrict D  =  aa + ((d*baseSize)&(rowMask|colMask));    double * restrict C  =  aa + ((c*baseSize)&(rowMask|colMask));    double * restrict BT =  aa + ((d*baseSize)&colMask)/2      + ((c*baseSize)&colMask);    */  #if 0    for (int i = 0; i < baseOrder; i+=2)      for (int j = 0; j < baseOrder; j+=2)        for (int k = 0; k < baseOrder; k++)        {  	D[0+(i)*stride+2*(j+0)] += C[0+(i)*stride+2*k] * BT[0+(j)*stride+2*k];  	D[0+(i)*stride+2*(j+1)] += C[0+(i)*stride+2*k] * BT[1+(j)*stride+2*k];  	D[1+(i)*stride+2*(j+0)] += C[1+(i)*stride+2*k] * BT[0+(j)*stride+2*k];  	D[1+(i)*stride+2*(j+1)] += C[1+(i)*stride+2*k] * BT[1+(j)*stride+2*k];        }  #endif  #if 0    // Reorder loops (Ordering based on where the target code is going).    for (int j = 0; j < baseOrder; j+=2)      for (int i = 0; i < baseOrder; i+=16)      {        for (int k = 0; k < baseOrder; k++)        {          for (int i2 = i; i2 < i+16; i2+=2)	    {	      D[0+(i2)*stride+2*(j+0)] += C[0+(i2)*stride+2*k] * BT[0+(j)*stride+2*k];	      D[1+(i2)*stride+2*(j+0)] += C[1+(i2)*stride+2*k] * BT[0+(j)*stride+2*k];	    }        }        for (int k = 0; k < baseOrder; k++)        {          for (int i2 = i; i2 < i+16; i2+=2)	    {	      D[0+(i2)*stride+2*(j+1)] += C[0+(i2)*stride+2*k] * BT[1+(j)*stride+2*k];	      D[1+(i2)*stride+2*(j+1)] += C[1+(i2)*stride+2*k] * BT[1+(j)*stride+2*k];	    }        }      }  #endif  #if 0    // Unroll i2    for (int j = 0; j < baseOrder; j+=2)      for (int i = 0; i < baseOrder; i+=16)      {        for (int k = 0; k < baseOrder; k++)        {          D[0+(i+ 0)*stride+2*(j+0)]+=C[0+(i+ 0)*stride+2*k]*BT[0+j*stride+2*k];          D[1+(i+ 0)*stride+2*(j+0)]+=C[1+(i+ 0)*stride+2*k]*BT[0+j*stride+2*k];          D[0+(i+ 2)*stride+2*(j+0)]+=C[0+(i+ 2)*stride+2*k]*BT[0+j*stride+2*k];          D[1+(i+ 2)*stride+2*(j+0)]+=C[1+(i+ 2)*stride+2*k]*BT[0+j*stride+2*k];          D[0+(i+ 4)*stride+2*(j+0)]+=C[0+(i+ 4)*stride+2*k]*BT[0+j*stride+2*k];          D[1+(i+ 4)*stride+2*(j+0)]+=C[1+(i+ 4)*stride+2*k]*BT[0+j*stride+2*k];          D[0+(i+ 6)*stride+2*(j+0)]+=C[0+(i+ 6)*stride+2*k]*BT[0+j*stride+2*k];          D[1+(i+ 6)*stride+2*(j+0)]+=C[1+(i+ 6)*stride+2*k]*BT[0+j*stride+2*k];          D[0+(i+ 8)*stride+2*(j+0)]+=C[0+(i+ 8)*stride+2*k]*BT[0+j*stride+2*k];          D[1+(i+ 8)*stride+2*(j+0)]+=C[1+(i+ 8)*stride+2*k]*BT[0+j*stride+2*k];          D[0+(i+10)*stride+2*(j+0)]+=C[0+(i+10)*stride+2*k]*BT[0+j*stride+2*k];          D[1+(i+10)*stride+2*(j+0)]+=C[1+(i+10)*stride+2*k]*BT[0+j*stride+2*k];          D[0+(i+12)*stride+2*(j+0)]+=C[0+(i+12)*stride+2*k]*BT[0+j*stride+2*k];          D[1+(i+12)*stride+2*(j+0)]+=C[1+(i+12)*stride+2*k]*BT[0+j*stride+2*k];          D[0+(i+14)*stride+2*(j+0)]+=C[0+(i+14)*stride+2*k]*BT[0+j*stride+2*k];          D[1+(i+14)*stride+2*(j+0)]+=C[1+(i+14)*stride+2*k]*BT[0+j*stride+2*k];        }        for (int k = 0; k < baseOrder; k++)        {          D[0+(i+ 0)*stride+2*(j+1)]+=C[0+(i+ 0)*stride+2*k]*BT[1+j*stride+2*k];          D[1+(i+ 0)*stride+2*(j+1)]+=C[1+(i+ 0)*stride+2*k]*BT[1+j*stride+2*k];          D[0+(i+ 2)*stride+2*(j+1)]+=C[0+(i+ 2)*stride+2*k]*BT[1+j*stride+2*k];          D[1+(i+ 2)*stride+2*(j+1)]+=C[1+(i+ 2)*stride+2*k]*BT[1+j*stride+2*k];          D[0+(i+ 4)*stride+2*(j+1)]+=C[0+(i+ 4)*stride+2*k]*BT[1+j*stride+2*k];          D[1+(i+ 4)*stride+2*(j+1)]+=C[1+(i+ 4)*stride+2*k]*BT[1+j*stride+2*k];          D[0+(i+ 6)*stride+2*(j+1)]+=C[0+(i+ 6)*stride+2*k]*BT[1+j*stride+2*k];          D[1+(i+ 6)*stride+2*(j+1)]+=C[1+(i+ 6)*stride+2*k]*BT[1+j*stride+2*k];          D[0+(i+ 8)*stride+2*(j+1)]+=C[0+(i+ 8)*stride+2*k]*BT[1+j*stride+2*k];          D[1+(i+ 8)*stride+2*(j+1)]+=C[1+(i+ 8)*stride+2*k]*BT[1+j*stride+2*k];          D[0+(i+10)*stride+2*(j+1)]+=C[0+(i+10)*stride+2*k]*BT[1+j*stride+2*k];          D[1+(i+10)*stride+2*(j+1)]+=C[1+(i+10)*stride+2*k]*BT[1+j*stride+2*k];          D[0+(i+12)*stride+2*(j+1)]+=C[0+(i+12)*stride+2*k]*BT[1+j*stride+2*k];          D[1+(i+12)*stride+2*(j+1)]+=C[1+(i+12)*stride+2*k]*BT[1+j*stride+2*k];          D[0+(i+14)*stride+2*(j+1)]+=C[0+(i+14)*stride+2*k]*BT[1+j*stride+2*k];          D[1+(i+14)*stride+2*(j+1)]+=C[1+(i+14)*stride+2*k]*BT[1+j*stride+2*k];        }      }  #endif  #if 0    // Prep k    for (int j = 0; j < baseOrder; j+=2)      for (int i = 0; i < baseOrder; i+=16)      {        {          double d00 = D[0+(i+ 0)*stride+2*(j+0)];          double d01 = D[1+(i+ 0)*stride+2*(j+0)];          double d02 = D[0+(i+ 2)*stride+2*(j+0)];          double d03 = D[1+(i+ 2)*stride+2*(j+0)];          double d04 = D[0+(i+ 4)*stride+2*(j+0)];          double d05 = D[1+(i+ 4)*stride+2*(j+0)];          double d06 = D[0+(i+ 6)*stride+2*(j+0)];          double d07 = D[1+(i+ 6)*stride+2*(j+0)];          double d08 = D[0+(i+ 8)*stride+2*(j+0)];          double d09 = D[1+(i+ 8)*stride+2*(j+0)];          double d10 = D[0+(i+10)*stride+2*(j+0)];          double d11 = D[1+(i+10)*stride+2*(j+0)];          double d12 = D[0+(i+12)*stride+2*(j+0)];          double d13 = D[1+(i+12)*stride+2*(j+0)];          double d14 = D[0+(i+14)*stride+2*(j+0)];          double d15 = D[1+(i+14)*stride+2*(j+0)];        for (int k = 0; k < baseOrder; k++)        {          d00+=C[0+(i+ 0)*stride+2*k]*BT[0+j*stride+2*k];          d01+=C[1+(i+ 0)*stride+2*k]*BT[0+j*stride+2*k];          d02+=C[0+(i+ 2)*stride+2*k]*BT[0+j*stride+2*k];          d03+=C[1+(i+ 2)*stride+2*k]*BT[0+j*stride+2*k];          d04+=C[0+(i+ 4)*stride+2*k]*BT[0+j*stride+2*k];          d05+=C[1+(i+ 4)*stride+2*k]*BT[0+j*stride+2*k];          d06+=C[0+(i+ 6)*stride+2*k]*BT[0+j*stride+2*k];          d07+=C[1+(i+ 6)*stride+2*k]*BT[0+j*stride+2*k];          d08+=C[0+(i+ 8)*stride+2*k]*BT[0+j*stride+2*k];          d09+=C[1+(i+ 8)*stride+2*k]*BT[0+j*stride+2*k];          d10+=C[0+(i+10)*stride+2*k]*BT[0+j*stride+2*k];          d11+=C[1+(i+10)*stride+2*k]*BT[0+j*stride+2*k];          d12+=C[0+(i+12)*stride+2*k]*BT[0+j*stride+2*k];          d13+=C[1+(i+12)*stride+2*k]*BT[0+j*stride+2*k];          d14+=C[0+(i+14)*stride+2*k]*BT[0+j*stride+2*k];          d15+=C[1+(i+14)*stride+2*k]*BT[0+j*stride+2*k];        }          D[0+(i+ 0)*stride+2*(j+0)] = d00;          D[1+(i+ 0)*stride+2*(j+0)] = d01;          D[0+(i+ 2)*stride+2*(j+0)] = d02;          D[1+(i+ 2)*stride+2*(j+0)] = d03;          D[0+(i+ 4)*stride+2*(j+0)] = d04;          D[1+(i+ 4)*stride+2*(j+0)] = d05;          D[0+(i+ 6)*stride+2*(j+0)] = d06;          D[1+(i+ 6)*stride+2*(j+0)] = d07;          D[0+(i+ 8)*stride+2*(j+0)] = d08;          D[1+(i+ 8)*stride+2*(j+0)] = d09;          D[0+(i+10)*stride+2*(j+0)] = d10;          D[1+(i+10)*stride+2*(j+0)] = d11;          D[0+(i+12)*stride+2*(j+0)] = d12;          D[1+(i+12)*stride+2*(j+0)] = d13;          D[0+(i+14)*stride+2*(j+0)] = d14;          D[1+(i+14)*stride+2*(j+0)] = d15;        }        {          double d00 = D[0+(i+ 0)*stride+2*(j+1)];          double d01 = D[1+(i+ 0)*stride+2*(j+1)];          double d02 = D[0+(i+ 2)*stride+2*(j+1)];          double d03 = D[1+(i+ 2)*stride+2*(j+1)];          double d04 = D[0+(i+ 4)*stride+2*(j+1)];          double d05 = D[1+(i+ 4)*stride+2*(j+1)];          double d06 = D[0+(i+ 6)*stride+2*(j+1)];          double d07 = D[1+(i+ 6)*stride+2*(j+1)];          double d08 = D[0+(i+ 8)*stride+2*(j+1)];          double d09 = D[1+(i+ 8)*stride+2*(j+1)];          double d10 = D[0+(i+10)*stride+2*(j+1)];          double d11 = D[1+(i+10)*stride+2*(j+1)];          double d12 = D[0+(i+12)*stride+2*(j+1)];          double d13 = D[1+(i+12)*stride+2*(j+1)];          double d14 = D[0+(i+14)*stride+2*(j+1)];          double d15 = D[1+(i+14)*stride+2*(j+1)];        for (int k = 0; k < baseOrder; k++)        {          d00+=C[0+(i+ 0)*stride+2*k]*BT[1+j*stride+2*k];          d01+=C[1+(i+ 0)*stride+2*k]*BT[1+j*stride+2*k];          d02+=C[0+(i+ 2)*stride+2*k]*BT[1+j*stride+2*k];          d03+=C[1+(i+ 2)*stride+2*k]*BT[1+j*stride+2*k];          d04+=C[0+(i+ 4)*stride+2*k]*BT[1+j*stride+2*k];          d05+=C[1+(i+ 4)*stride+2*k]*BT[1+j*stride+2*k];          d06+=C[0+(i+ 6)*stride+2*k]*BT[1+j*stride+2*k];          d07+=C[1+(i+ 6)*stride+2*k]*BT[1+j*stride+2*k];          d08+=C[0+(i+ 8)*stride+2*k]*BT[1+j*stride+2*k];          d09+=C[1+(i+ 8)*stride+2*k]*BT[1+j*stride+2*k];          d10+=C[0+(i+10)*stride+2*k]*BT[1+j*stride+2*k];          d11+=C[1+(i+10)*stride+2*k]*BT[1+j*stride+2*k];          d12+=C[0+(i+12)*stride+2*k]*BT[1+j*stride+2*k];          d13+=C[1+(i+12)*stride+2*k]*BT[1+j*stride+2*k];          d14+=C[0+(i+14)*stride+2*k]*BT[1+j*stride+2*k];          d15+=C[1+(i+14)*stride+2*k]*BT[1+j*stride+2*k];        }          D[0+(i+ 0)*stride+2*(j+1)] = d00;          D[1+(i+ 0)*stride+2*(j+1)] = d01;          D[0+(i+ 2)*stride+2*(j+1)] = d02;          D[1+(i+ 2)*stride+2*(j+1)] = d03;          D[0+(i+ 4)*stride+2*(j+1)] = d04;          D[1+(i+ 4)*stride+2*(j+1)] = d05;          D[0+(i+ 6)*stride+2*(j+1)] = d06;          D[1+(i+ 6)*stride+2*(j+1)] = d07;          D[0+(i+ 8)*stride+2*(j+1)] = d08;          D[1+(i+ 8)*stride+2*(j+1)] = d09;          D[0+(i+10)*stride+2*(j+1)] = d10;          D[1+(i+10)*stride+2*(j+1)] = d11;          D[0+(i+12)*stride+2*(j+1)] = d12;          D[1+(i+12)*stride+2*(j+1)] = d13;          D[0+(i+14)*stride+2*(j+1)] = d14;          D[1+(i+14)*stride+2*(j+1)] = d15;        }      }  #endif  #if 0    // Begin SSE    for (int j = 0; j < baseOrder; j+=2)      for (int i = 0; i < baseOrder; i+=16)      {        {          __m128d d00 = _mm_load_pd(&D[0+(i+ 0)*stride+2*(j+0)]);          __m128d d02 = _mm_load_pd(&D[0+(i+ 2)*stride+2*(j+0)]);          __m128d d04 = _mm_load_pd(&D[0+(i+ 4)*stride+2*(j+0)]);          __m128d d06 = _mm_load_pd(&D[0+(i+ 6)*stride+2*(j+0)]);          __m128d d08 = _mm_load_pd(&D[0+(i+ 8)*stride+2*(j+0)]);          __m128d d10 = _mm_load_pd(&D[0+(i+10)*stride+2*(j+0)]);          __m128d d12 = _mm_load_pd(&D[0+(i+12)*stride+2*(j+0)]);          __m128d d14 = _mm_load_pd(&D[0+(i+14)*stride+2*(j+0)]);        for (int k = 0; k < baseOrder; k++)        {          __m128d bt0 = _mm_load1_pd(&BT[0+j*stride+2*k]);  	d00+=_mm_load_pd(&C[0+(i+ 0)*stride+2*k])*bt0;

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?