📄 allreduce.c

📁 MPI parallel computing C++ code. Compiles with VC or gcc. Can be used to set up an experimental parallel computing environment.
💻 C
📖 Page 1 of 2
/* -*- Mode: C; c-basic-offset:4 ; -*- */
/*  $Id: allreduce.c,v 1.64 2005/11/07 23:37:24 thakur Exp $
 *
 *  (C) 2001 by Argonne National Laboratory.
 *      See COPYRIGHT in top-level directory.
 */

#include "mpiimpl.h"

/* -- Begin Profiling Symbol Block for routine MPI_Allreduce */
#if defined(HAVE_PRAGMA_WEAK)
#pragma weak MPI_Allreduce = PMPI_Allreduce
#elif defined(HAVE_PRAGMA_HP_SEC_DEF)
#pragma _HP_SECONDARY_DEF PMPI_Allreduce  MPI_Allreduce
#elif defined(HAVE_PRAGMA_CRI_DUP)
#pragma _CRI duplicate MPI_Allreduce as PMPI_Allreduce
#endif
/* -- End Profiling Symbol Block */

/* Define MPICH_MPI_FROM_PMPI if weak symbols are not supported to build
   the MPI routines */
#ifndef MPICH_MPI_FROM_PMPI
#define MPI_Allreduce PMPI_Allreduce

MPI_User_function *MPIR_Op_table[] = { MPIR_MAXF, MPIR_MINF, MPIR_SUM,
                                       MPIR_PROD, MPIR_LAND,
                                       MPIR_BAND, MPIR_LOR, MPIR_BOR,
                                       MPIR_LXOR, MPIR_BXOR,
                                       MPIR_MINLOC, MPIR_MAXLOC, };

MPIR_Op_check_dtype_fn *MPIR_Op_check_dtype_table[] = {
    MPIR_MAXF_check_dtype, MPIR_MINF_check_dtype,
    MPIR_SUM_check_dtype,
    MPIR_PROD_check_dtype, MPIR_LAND_check_dtype,
    MPIR_BAND_check_dtype, MPIR_LOR_check_dtype, MPIR_BOR_check_dtype,
    MPIR_LXOR_check_dtype, MPIR_BXOR_check_dtype,
    MPIR_MINLOC_check_dtype, MPIR_MAXLOC_check_dtype, };

/* This is the default implementation of allreduce. The algorithm is:

   Algorithm: MPI_Allreduce

   For the heterogeneous case, we call MPI_Reduce followed by MPI_Bcast
   in order to meet the requirement that all processes must have the
   same result. For the homogeneous case, we use the following algorithms.

   For long messages and for builtin ops and if count >= pof2 (where
   pof2 is the nearest power-of-two less than or equal to the number
   of processes), we use Rabenseifner's algorithm (see
   http://www.hlrs.de/organization/par/services/models/mpi/myreduce.html ).
   This algorithm implements the allreduce in two steps: first a
   reduce-scatter, followed by an allgather. A recursive-halving
   algorithm (beginning with processes that are distance 1 apart) is
   used for the reduce-scatter, and a recursive doubling
   algorithm is used for the allgather. The non-power-of-two case is
   handled by dropping to the nearest lower power-of-two: the first
   few even-numbered processes send their data to their right neighbors
   (rank+1), and the reduce-scatter and allgather happen among the remaining
   power-of-two processes. At the end, the first few even-numbered
   processes get the result from their right neighbors.

   For the power-of-two case, the cost for the reduce-scatter is
   lgp.alpha + n.((p-1)/p).beta + n.((p-1)/p).gamma. The cost for the
   allgather lgp.alpha + n.((p-1)/p).beta. Therefore, the
   total cost is:
   Cost = 2.lgp.alpha + 2.n.((p-1)/p).beta + n.((p-1)/p).gamma

   For the non-power-of-two case,
   Cost = (2.floor(lgp)+2).alpha + (2.((p-1)/p) + 2).n.beta + n.(1+(p-1)/p).gamma

   For short messages, for user-defined ops, and for count < pof2
   we use a recursive doubling algorithm (similar to the one in
   MPI_Allgather). We use this algorithm in the case of user-defined ops
   because in this case derived datatypes are allowed, and the user
   could pass basic datatypes on one process and derived on another as
   long as the type maps are the same. Breaking up derived datatypes
   to do the reduce-scatter is tricky.

   Cost = lgp.alpha + n.lgp.beta + n.lgp.gamma

   Possible improvements:

   End Algorithm: MPI_Allreduce
*/

/* not declared static because a machine-specific function may call this one
   in some cases */
int MPIR_Allreduce (
    void *sendbuf,
    void *recvbuf,
    int count,
    MPI_Datatype datatype,
    MPI_Op op,
    MPID_Comm *comm_ptr )
{
    static const char FCNAME[] = "MPIR_Allreduce";
    int is_homogeneous;
#ifdef MPID_HAS_HETERO
    int rc;
#endif
    int        comm_size, rank, type_size;
    int        mpi_errno = MPI_SUCCESS;
    int mask, dst, is_commutative, pof2, newrank, rem, newdst, i,
        send_idx, recv_idx, last_idx, send_cnt, recv_cnt, *cnts, *disps;
    MPI_Aint true_extent, true_lb, extent;
    void *tmp_buf;
    MPI_User_function *uop;
    MPID_Op *op_ptr;
    MPI_Comm comm;
    MPICH_PerThread_t *p;
#ifdef HAVE_CXX_BINDING
    int is_cxx_uop = 0;
#endif
    MPIU_CHKLMEM_DECL(3);

    if (count == 0) return MPI_SUCCESS;
    comm = comm_ptr->handle;

    MPIR_Nest_incr();

    is_homogeneous = 1;
#ifdef MPID_HAS_HETERO
    if (comm_ptr->is_hetero)
        is_homogeneous = 0;
#endif

#ifdef MPID_HAS_HETERO
    if (!is_homogeneous) {
        /* heterogeneous. To get the same result on all processes, we
           do a reduce to 0 and then broadcast. */
        mpi_errno = NMPI_Reduce ( sendbuf, recvbuf, count, datatype,
                                  op, 0, comm );
        /* FIXME: mpi_errno is error CODE, not necessarily the error
           class MPI_ERR_OP.  In MPICH2, we can get the error class
           with
               errorclass = mpi_errno & ERROR_CLASS_MASK;
        */
        if (mpi_errno == MPI_ERR_OP || mpi_errno == MPI_SUCCESS) {
            /* Allow MPI_ERR_OP since we can continue from this error */
            rc = NMPI_Bcast  ( recvbuf, count, datatype, 0, comm );
            if (rc) mpi_errno = rc;
        }
    }
    else
#endif /* MPID_HAS_HETERO */
    {
        /* homogeneous */

        /* set op_errno to 0. stored in perthread structure */
        MPIR_GetPerThread(&p);
        p->op_errno = 0;

        comm_size = comm_ptr->local_size;
        rank = comm_ptr->rank;

        if (HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) {
            is_commutative = 1;
            /* get the function by indexing into the op table */
            uop = MPIR_Op_table[op%16 - 1];
        }
        else {
            MPID_Op_get_ptr(op, op_ptr);
            if (op_ptr->kind == MPID_OP_USER_NONCOMMUTE)
                is_commutative = 0;
            else
                is_commutative = 1;
#ifdef HAVE_CXX_BINDING
            if (op_ptr->language == MPID_LANG_CXX) {
                uop = (MPI_User_function *) op_ptr->function.c_function;
                is_cxx_uop = 1;
            }
            else
#endif
            if ((op_ptr->language == MPID_LANG_C))
                uop = (MPI_User_function *) op_ptr->function.c_function;
            else
                uop = (MPI_User_function *) op_ptr->function.f77_function;
        }

        /* need to allocate temporary buffer to store incoming data*/
        mpi_errno = NMPI_Type_get_true_extent(datatype, &true_lb,
                                              &true_extent);
        MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail");
        MPID_Datatype_get_extent_macro(datatype, extent);

        MPIU_CHKLMEM_MALLOC(tmp_buf, void *, count*(MPIR_MAX(extent,true_extent)), mpi_errno, "temporary buffer");

        /* adjust for potential negative lower bound in datatype */
        tmp_buf = (void *)((char*)tmp_buf - true_lb);

        /* copy local data into recvbuf */
        if (sendbuf != MPI_IN_PLACE) {
            mpi_errno = MPIR_Localcopy(sendbuf, count, datatype, recvbuf,
                                       count, datatype);
            MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail");
        }

        MPID_Datatype_get_size_macro(datatype, type_size);

        /* find nearest power-of-two less than or equal to comm_size */
        pof2 = 1;
        while (pof2 <= comm_size) pof2 <<= 1;
        pof2 >>=1;

        rem = comm_size - pof2;

        /* In the non-power-of-two case, all even-numbered
           processes of rank < 2*rem send their data to
           (rank+1). These even-numbered processes no longer
           participate in the algorithm until the very end. The
           remaining processes form a nice power-of-two. */

        /* check if multiple threads are calling this collective function */
        MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr );

        if (rank < 2*rem) {
            if (rank % 2 == 0) { /* even */
                mpi_errno = MPIC_Send(recvbuf, count,
                                      datatype, rank+1,
                                      MPIR_ALLREDUCE_TAG, comm);
                MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail");

                /* temporarily set the rank to -1 so that this
                   process does not participate in recursive
                   doubling */
                newrank = -1;
            }
            else { /* odd */
                mpi_errno = MPIC_Recv(tmp_buf, count,
                                      datatype, rank-1,
                                      MPIR_ALLREDUCE_TAG, comm,
                                      MPI_STATUS_IGNORE);
                MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail");

                /* do the reduction on received data. since the
                   ordering is right, it doesn't matter whether
                   the operation is commutative or not. */
#ifdef HAVE_CXX_BINDING
                if (is_cxx_uop) {
                    (*MPIR_Process.cxx_call_op_fn)( tmp_buf, recvbuf,
                                                    count,
                                                    datatype,
                                                    uop );
                }
                else
#endif
                    (*uop)(tmp_buf, recvbuf, &count, &datatype);

                /* change the rank */
                newrank = rank / 2;
            }
        }
        else  /* rank >= 2*rem */
            newrank = rank - rem;

        /* If op is user-defined or count is less than pof2, use
           recursive doubling algorithm. Otherwise do a reduce-scatter
           followed by allgather. (If op is user-defined,
           derived datatypes are allowed and the user could pass basic
           datatypes on one process and derived on another as long as
           the type maps are the same. Breaking up derived
           datatypes to do the reduce-scatter is tricky, therefore
           using recursive doubling in that case.) */

        if (newrank != -1) {
            if ((count*type_size <= MPIR_ALLREDUCE_SHORT_MSG) ||
                (HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) ||
                (count < pof2)) { /* use recursive doubling */
                mask = 0x1;
                while (mask < pof2) {
                    newdst = newrank ^ mask;
                    /* find real rank of dest */
                    dst = (newdst < rem) ? newdst*2 + 1 : newdst + rem;

                    /* Send the most current data, which is in recvbuf. Recv
                       into tmp_buf */
                    mpi_errno = MPIC_Sendrecv(recvbuf, count, datatype,
                                              dst, MPIR_ALLREDUCE_TAG, tmp_buf,
                                              count, datatype, dst,
                                              MPIR_ALLREDUCE_TAG, comm,
                                              MPI_STATUS_IGNORE);
                    MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail");

                    /* tmp_buf contains data received in this step.
                       recvbuf contains data accumulated so far */

                    if (is_commutative  || (dst < rank)) {
                        /* op is commutative OR the order is already right */
#ifdef HAVE_CXX_BINDING
                        if (is_cxx_uop) {
                            (*MPIR_Process.cxx_call_op_fn)( tmp_buf, recvbuf,
                                                            count,
                                                            datatype,
                                                            uop );
                        }
                        else
#endif
                            (*uop)(tmp_buf, recvbuf, &count, &datatype);
                    }
                    else {
                        /* op is noncommutative and the order is not right */
#ifdef HAVE_CXX_BINDING
                        if (is_cxx_uop) {
                            (*MPIR_Process.cxx_call_op_fn)( recvbuf, tmp_buf,
                                                            count,
                                                            datatype,
                                                            uop );
                        }
                        else
#endif
                            (*uop)(recvbuf, tmp_buf, &count, &datatype);

                        /* copy result back into recvbuf */
                        mpi_errno = MPIR_Localcopy(tmp_buf, count, datatype,
                                                   recvbuf, count, datatype);
                        MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail");
                    }
                    mask <<= 1;
                }
            }

            else {

                /* do a reduce-scatter followed by allgather */

                /* for the reduce-scatter, calculate the count that
                   each process receives and the displacement within
                   the buffer */

                MPIU_CHKLMEM_MALLOC(cnts, int *, pof2*sizeof(int), mpi_errno, "counts");
                MPIU_CHKLMEM_MALLOC(disps, int *, pof2*sizeof(int), mpi_errno, "displacements");

                for (i=0; i<(pof2-1); i++)
                    cnts[i] = count/pof2;
                cnts[pof2-1] = count - (count/pof2)*(pof2-1);

                disps[0] = 0;
                for (i=1; i<pof2; i++)
                    disps[i] = disps[i-1] + cnts[i-1];

                mask = 0x1;
                send_idx = recv_idx = 0;
                last_idx = pof2;
                while (mask < pof2) {
                    newdst = newrank ^ mask;
                    /* find real rank of dest */
                    dst = (newdst < rem) ? newdst*2 + 1 : newdst + rem;

                    send_cnt = recv_cnt = 0;
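
The rank folding used for non-power-of-two communicators can be tried out without MPI. The sketch below is a standalone illustration rather than part of allreduce.c: the function names `nearest_pof2`, `real_to_new_rank`, and `partner_real_rank` are hypothetical helpers introduced here, and they only repeat the integer arithmetic that MPIR_Allreduce performs on `pof2`, `rem`, `newrank`, and `dst`.

```c
#include <assert.h>

/* Hypothetical helper: largest power of two <= comm_size,
   computed the same way as pof2 in MPIR_Allreduce. */
int nearest_pof2(int comm_size)
{
    int pof2 = 1;
    assert(comm_size >= 1);
    while (pof2 <= comm_size) pof2 <<= 1;
    pof2 >>= 1;
    return pof2;
}

/* Hypothetical helper: rank inside the power-of-two group, or -1 for
   the even ranks below 2*rem that hand their data to rank+1. */
int real_to_new_rank(int rank, int comm_size)
{
    int pof2 = nearest_pof2(comm_size);
    int rem = comm_size - pof2;

    if (rank < 2 * rem)
        return (rank % 2 == 0) ? -1 : rank / 2;
    return rank - rem;
}

/* Hypothetical helper: real rank of the partner of newrank for the
   step identified by mask (mask = 1, 2, 4, ... < pof2). */
int partner_real_rank(int newrank, int mask, int comm_size)
{
    int pof2 = nearest_pof2(comm_size);
    int rem = comm_size - pof2;
    int newdst = newrank ^ mask;

    return (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
}
```

With 6 processes, pof2 is 4 and rem is 2: ranks 0 and 2 are folded out (newrank -1), and ranks 1, 3, 4, 5 become newrank 0, 1, 2, 3. In the first step (mask 1), newrank 0, which is real rank 1, exchanges with real rank 3.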

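The block layout that the reduce-scatter branch computes into `cnts` and `disps` can also be checked in isolation. This is a minimal sketch, not code from allreduce.c: `block_layout` is a hypothetical name, the caller supplies both arrays instead of using MPIU_CHKLMEM_MALLOC, and the assertion restates the `count >= pof2` condition under which that branch is taken.

```c
#include <assert.h>

/* Hypothetical helper: split count elements into pof2 blocks.
   Every block gets count/pof2 elements except the last, which also
   absorbs the remainder. disps[i] is the offset of block i. */
void block_layout(int count, int pof2, int *cnts, int *disps)
{
    int i;

    assert(pof2 >= 1 && count >= pof2);

    for (i = 0; i < pof2 - 1; i++)
        cnts[i] = count / pof2;
    cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);

    disps[0] = 0;
    for (i = 1; i < pof2; i++)
        disps[i] = disps[i - 1] + cnts[i - 1];
}
```

For count = 10 and pof2 = 4, this gives cnts = {2, 2, 2, 4} and disps = {0, 2, 4, 6}, so the last block is the only one of uneven size.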