cvmatmul.cpp.svn-base

来自「非结构化路识别」· SVN-BASE 代码 · 共 1,586 行 · 第 1/5 页

SVN-BASE
1,586
字号

                if( data3 )
                {
                    t[0] = (float)(A(0,0)*X(0) + A(0,1)*X(1) + A(0,2)*X(2) + B(0));
                    t[1] = (float)(A(1,0)*X(0) + A(1,1)*X(1) + A(1,2)*X(2) + B(1));
                    t[2] = (float)(A(2,0)*X(0) + A(2,1)*X(1) + A(2,2)*X(2) + B(2));
                }
                else
                {
                    t[0] = (float)(A(0,0)*X(0) + A(0,1)*X(1) + A(0,2)*X(2));
                    t[1] = (float)(A(1,0)*X(0) + A(1,1)*X(1) + A(1,2)*X(2));
                    t[2] = (float)(A(2,0)*X(0) + A(2,1)*X(1) + A(2,2)*X(2));
                }

                Y(0) = t[0];
                Y(1) = t[1];
                Y(2) = t[2];
            }
            else
            {
                Y(0) = data3 ? (float)(A(0,0)*X(0) + B(0)) : (float)(A(0,0)*X(0));
            }
            EXIT;
        }
        
        if( type == CV_64FC1 )
        {
            if( src1->width == 2 )
            {
                #undef N
                #define N  2

                #undef arrtype                
                #define arrtype double

                double t[2];

                if( data3 )
                {
                    t[0] = A(0,0)*X(0) + A(0,1)*X(1) + B(0);
                    t[1] = A(1,0)*X(0) + A(1,1)*X(1) + B(1);
                }
                else
                {
                    t[0] = A(0,0)*X(0) + A(0,1)*X(1);
                    t[1] = A(1,0)*X(0) + A(1,1)*X(1);
                }

                Y(0) = t[0];
                Y(1) = t[1];
            }
            else if( src1->width == 3 )
            {
                #undef N
                #define N  3

                double t[3];

                if( data3 )
                {
                    t[0] = A(0,0)*X(0) + A(0,1)*X(1) + A(0,2)*X(2) + B(0);
                    t[1] = A(1,0)*X(0) + A(1,1)*X(1) + A(1,2)*X(2) + B(1);
                    t[2] = A(2,0)*X(0) + A(2,1)*X(1) + A(2,2)*X(2) + B(2);
                }
                else
                {
                    t[0] = A(0,0)*X(0) + A(0,1)*X(1) + A(0,2)*X(2);
                    t[1] = A(1,0)*X(0) + A(1,1)*X(1) + A(1,2)*X(2);
                    t[2] = A(2,0)*X(0) + A(2,1)*X(1) + A(2,2)*X(2);
                }

                Y(0) = t[0];
                Y(1) = t[1];
                Y(2) = t[2];
            }
            else
            {
                Y(0) = data3 ? A(0,0)*X(0) + B(0) : A(0,0)*X(0);
            }

            EXIT;
        }
    }

    // general case
    {
        CvMatMulAddFunc func;
        CvSize size = icvGetMatSize( src1 );
        CvSize dstsize = icvGetMatSize( dst );
        CvMat tmat, *tdst = dst;
        
        if( !inittab )
        {
            icvInitMatMulAddTable( &mmuladd_tab );
            inittab = 1;
        }

        if( dst->data.ptr == src1->data.ptr || dst->data.ptr == src2->data.ptr )
        {
            int buf_size = dstsize.width*dstsize.height*icvPixSize[type];
            if( buf_size <= CV_MAX_LOCAL_SIZE )
            {
                buffer = (uchar*)alloca( buf_size + 8 );
                buffer = (uchar*)icvAlignPtr( buffer, 8 );
                local_alloc = 1;
            }
            else
            {
                CV_CALL( buffer = (uchar*)cvAlloc( buf_size ));
            }

            CV_CALL( cvInitMatHeader( &tmat, dstsize.height,
                                      dstsize.width, type, buffer ));
            tdst = &tmat;
        }

        func = (CvMatMulAddFunc)(mmuladd_tab.fn_2d[type]);
        if( !func )
            CV_ERROR( CV_StsUnsupportedFormat, "" );

        IPPI_CALL( func( src1->data.ptr, src1->step, src2->data.ptr, src2->step,
                         src3->data.ptr, src3->step, tdst->data.ptr, tdst->step,
                         size, dstsize ));

        if( tdst != dst )
        {
            CV_CALL( cvCopy( tdst, dst ));
        }
    }

    CV_CHECK_NANS( dst );

    __END__;

    if( buffer && !local_alloc )
        cvFree( (void**)&buffer );
}


/****************************************************************************************\
*                                         cvGEMM                                         *
\****************************************************************************************/

/*
 * Forward declarations of the type-specific GEMM kernels.
 *
 * IPCV_GEMM is a throwaway helper: it is defined, used once per element type,
 * and removed again right below.  Each use hands IPCVAPI the return type, the
 * kernel name icvGEMM_<type_suffix> and the parenthesised parameter list.
 */
#define IPCV_GEMM( type_suffix )                                        \
IPCVAPI( CvStatus,                                                      \
         icvGEMM_##type_suffix,                                         \
         ( const uchar* src1, int step1,                                \
           const uchar* src2, int step2, double alpha,                  \
           const uchar* src3, int step3, double beta,                   \
           uchar* dst, int dststep,                                     \
           CvSize srcsize, CvSize dstsize, int flags ))

IPCV_GEMM( 32f_C1R )    /* declares icvGEMM_32f_C1R */
IPCV_GEMM( 64f_C1R )    /* declares icvGEMM_64f_C1R */

/* the helper is not needed past this point */
#undef IPCV_GEMM


/*
 * ICV_DEF_GEMM_FUNC generates the kernel icvGEMM_<flavor>.
 *
 * For every destination element (row, col) the generated function runs
 * _mul_add_macro_ over `inner_len` pairs of src1/src2 elements, collecting the
 * result in a one-element accumulator of type `temptype`, and then stores
 *
 *     dst(row, col) = (arrtype)( accumulator*alpha + src3_element*beta )
 *
 * Macro parameters:
 *   flavor           - suffix appended to "icvGEMM_" to form the function name
 *   arrtype          - element type used to read src1/src2/src3 and write dst
 *   temptype         - type of the accumulator
 *   _mul_add_macro_  - invoked as _mul_add_macro_( a_ptr, b_ptr, accumulator )
 *                      (presumably adds the product of the two elements to
 *                      accumulator[0] - confirm against its definition)
 *   _load_macro_,
 *   _store_macro_    - accepted but not referenced by the body
 *   cn               - multiplier applied to the element size and to the
 *                      destination index
 *
 * Parameters of the generated function:
 *   src1, src2, src3 - byte pointers to the inputs; src3 may be NULL, in which
 *                      case a local zero is used for the "+ src3*beta" term
 *   step1..3, step   - byte strides between rows of src1, src2, src3 and dst
 *   alpha, beta      - scale factors, see the formula above
 *   srcsize          - its width (or its height when flags bit 0 is set) is
 *                      the number of terms accumulated per destination element
 *   dstsize          - number of destination rows and columns produced
 *   flags            - bit 0 (1): swap row/element strides for src1
 *                      bit 1 (2): swap row/element strides for src2
 *                      bit 2 (4): swap row/element strides for src3
 *                      (i.e. the corresponding input is walked as transposed)
 *
 * The generated function always returns CV_OK.
 */
#define ICV_DEF_GEMM_FUNC( flavor, arrtype, temptype,                                   \
                           _mul_add_macro_, _load_macro_, _store_macro_, cn )           \
IPCVAPI_IMPL( CvStatus,                                                                 \
icvGEMM_##flavor,( const uchar* src1, int step1, const uchar* src2, int step2,          \
                   double alpha, const uchar* src3, int step3, double beta,             \
                   uchar* dst, int step, CvSize srcsize, CvSize dstsize, int flags ))   \
{                                                                                       \
    int elem_step = sizeof(arrtype)*(cn);                                               \
    int a_inner_step = elem_step;   /* src1: advance per accumulated term */            \
    int b_inner_step = step2;       /* src2: advance per accumulated term */            \
    int b_outer_step = elem_step;   /* src2: advance per destination column */          \
    int c_outer_step = elem_step;   /* src3: advance per destination column */          \
    int inner_len = srcsize.width;  /* terms accumulated per destination element */     \
    int rows_left;                                                                      \
    arrtype zero = 0;                                                                   \
                                                                                        \
    /* bit 0: exchange the two strides of src1; the inner length becomes */             \
    /* the source height instead of the source width */                                 \
    if( flags & 1 )                                                                     \
    {                                                                                   \
        a_inner_step = step1;                                                           \
        step1 = elem_step;                                                              \
        inner_len = srcsize.height;                                                     \
    }                                                                                   \
                                                                                        \
    /* bit 1: exchange the two strides of src2 */                                       \
    if( flags & 2 )                                                                     \
    {                                                                                   \
        b_outer_step = step2;                                                           \
        b_inner_step = elem_step;                                                       \
    }                                                                                   \
                                                                                        \
    if( !src3 )                                                                         \
    {                                                                                   \
        /* no third operand: keep reading one local zero, never advancing */            \
        src3 = (uchar*)&zero;                                                           \
        c_outer_step = 0;                                                               \
        step3 = 0;                                                                      \
    }                                                                                   \
    else if( flags & 4 )                                                                \
    {                                                                                   \
        /* bit 2: exchange the two strides of src3 */                                   \
        c_outer_step = step3;                                                           \
        step3 = elem_step;                                                              \
    }                                                                                   \
                                                                                        \
    /* one iteration per destination row */                                             \
    for( rows_left = dstsize.height; rows_left--;                                       \
         src1 += step1, src3 += step3, dst += step )                                    \
    {                                                                                   \
        arrtype* out_row = (arrtype*)dst;                                               \
        const uchar* b_start = src2;  /* src2 is rescanned from its origin */           \
        const uchar* c_elem = src3;                                                     \
        int col;                                                                        \
                                                                                        \
        for( col = 0; col < dstsize.width; col++ )                                      \
        {                                                                               \
            temptype sum[1] = {0};                                                      \
            const uchar* a_ptr = src1;                                                  \
            const uchar* b_ptr = b_start;                                               \
            int k;                                                                      \
                                                                                        \
            for( k = 0; k < inner_len; k++ )                                            \
            {                                                                           \
                _mul_add_macro_( (arrtype*)a_ptr, (arrtype*)b_ptr, sum );               \
                a_ptr += a_inner_step;                                                  \
                b_ptr += b_inner_step;                                                  \
            }                                                                           \
                                                                                        \
            out_row[col*(cn)] = (arrtype)( sum[0]*alpha + *(arrtype*)c_elem*beta );     \
                                                                                        \
            b_start += b_outer_step;                                                    \
            c_elem += c_outer_step;                                                     \
        }                                                                               \
    }                                                                                   \
                                                                                        \
    return CV_OK;                                                                       \
}


/*
 * Kernel instantiations (cn = 1 in both cases):
 *   icvGEMM_32f_C1R - float elements, accumulator of type double
 *   icvGEMM_64f_C1R - double elements, accumulator of type double
 */
ICV_DEF_GEMM_FUNC( 32f_C1R, float,  double,
                   _mul_add_real_, _load_real_, _store_real_, 1 )
ICV_DEF_GEMM_FUNC( 64f_C1R, double, double,
                   _mul_add_real_, _load_real_, _store_real_, 1 )

/*
 * Pointer-to-kernel type with the same parameter order as the icvGEMM_*
 * functions, except that the data pointers are untyped (void*) instead of
 * uchar*.  Presumably used to call the entries that icvInitGEMMTable stores
 * in the function table - confirm against the dispatch code in cvGEMM.
 */
typedef CvStatus (CV_STDCALL *CvGEMMFunc)(
                   const void* src1, int step1,
                   const void* src2, int step2,
                   double alpha,
                   const void* src3, int step3,
                   double beta,
                   void* dst, int dststep,
                   CvSize srcsize, CvSize dstsize,
                   int flags );

/*
 * ICV_DEF_INIT_GEMM_TAB( KERNEL ) defines the file-local function
 *
 *     static void icvInitGEMMTable( CvFuncTable* table )
 *
 * which stores the addresses of icv<KERNEL>_32f_C1R and icv<KERNEL>_64f_C1R
 * in table->fn_2d at indices CV_32F and CV_64F.  No other entry of the table
 * is written.
 */
#define ICV_DEF_INIT_GEMM_TAB( KERNEL )                             \
static void icvInitGEMMTable( CvFuncTable* table )                  \
{                                                                   \
    /* double-precision kernel */                                   \
    table->fn_2d[CV_64F] = (void*)icv##KERNEL##_64f_C1R;            \
    /* single-precision kernel */                                   \
    table->fn_2d[CV_32F] = (void*)icv##KERNEL##_32f_C1R;            \
}


/* registers icvGEMM_32f_C1R and icvGEMM_64f_C1R */
ICV_DEF_INIT_GEMM_TAB( GEMM )


CV_IMPL  void
cvGEMM( const CvArr* src1arr, const CvArr* src2arr, double alpha,
        const CvArr* src3arr, double beta, CvArr* dstarr, int tABC )
{
    static CvFuncTable gemm_tab;
    static int inittab = 0;
    
    uchar* buffer = 0;
    int local_alloc = 0;
    
    CV_FUNCNAME( "cvGEMM" );

    __BEGIN__;

    CvMat stub1, *src1 = (CvMat*)src1arr;
    CvMat stub2, *src2 = (CvMat*)src2arr;
    CvMat stub3, *src3 = (CvMat*)src3arr;
    CvMat stub, *dst = (CvMat*)dstarr;
    int type;

    if( !CV_IS_MAT( src1 ))
    {
        int coi = 0;
        CV_CALL( src1 = cvGetMat( src1, &stub1, &coi ));

        if( coi != 0 )
            CV_ERROR( CV_BadCOI, "" );
    }

    if( !CV_IS_MAT( src2 ))
    {
        int coi = 0;
        CV_CALL( src2 = cvGetMat( src2, &stub2, &coi ));

        if( coi != 0 )
            CV_ERROR( CV_BadCOI, "" );
    }

    if( !CV_IS_MAT( dst ))
    {
        int coi = 0;
        CV_CALL( dst = cvGetMat( dst, &stub, &coi ));

        if( coi != 0 )
            CV_ERROR( CV_BadCOI, "" );
    }

    if( src3 )
    {
        if( !CV_IS_MAT( src3 ))
        {
            int coi = 0;
            CV_CALL( src3 = cvGetMat( src3, &stub3, &coi ));

            if( coi != 0 )
                CV_ERROR( CV_BadCOI, "" );
        }

        if( !CV_ARE_TYPES_EQ( src3, dst ))
            CV_ERROR( CV_StsUnmatchedFormats, "" );

        if( (tABC&4) == 0 && (src3->width != dst->width || src3->height != dst->height) ||
            (tABC&4) != 0 && (src3->height != dst->width || src3->width != dst->height))
            CV_ERROR( CV_StsUnmatchedSizes, "" );
    }
    else
    {
        src3 = &stub3;
        src3->data.ptr = 0;
        src3->step = 0;
        src3->type = CV_MAT_CONT_FLAG;
    }

    if( !CV_ARE_TYPES_EQ( src1, src2 ))
        CV_ERROR( CV_StsUnmatchedFormats, "" );

    if( !CV_ARE_TYPES_EQ( src1, dst ))
        CV_ERROR( CV_StsUnmatchedFormats, "" );

    switch( tABC & 3 )
    {
    case 0:
        if( src1->width != src2->height ||

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?