📄 reduce.c
字号:
/* -*- Mode: C; c-basic-offset:4 ; -*- *//* * * (C) 2001 by Argonne National Laboratory. * See COPYRIGHT in top-level directory. */#include "mpiimpl.h"/* -- Begin Profiling Symbol Block for routine MPI_Reduce */#if defined(HAVE_PRAGMA_WEAK)#pragma weak MPI_Reduce = PMPI_Reduce#elif defined(HAVE_PRAGMA_HP_SEC_DEF)#pragma _HP_SECONDARY_DEF PMPI_Reduce MPI_Reduce#elif defined(HAVE_PRAGMA_CRI_DUP)#pragma _CRI duplicate MPI_Reduce as PMPI_Reduce#endif/* -- End Profiling Symbol Block *//* Define MPICH_MPI_FROM_PMPI if weak symbols are not supported to build the MPI routines */#ifndef MPICH_MPI_FROM_PMPI#define MPI_Reduce PMPI_Reduce/* This is the default implementation of reduce. The algorithm is: Algorithm: MPI_Reduce For long messages and for builtin ops and if count >= pof2 (where pof2 is the nearest power-of-two less than or equal to the number of processes), we use Rabenseifner's algorithm (see http://www.hlrs.de/organization/par/services/models/mpi/myreduce.html ). This algorithm implements the reduce in two steps: first a reduce-scatter, followed by a gather to the root. A recursive-halving algorithm (beginning with processes that are distance 1 apart) is used for the reduce-scatter, and a binomial tree algorithm is used for the gather. The non-power-of-two case is handled by dropping to the nearest lower power-of-two: the first few odd-numbered processes send their data to their left neighbors (rank-1), and the reduce-scatter happens among the remaining power-of-two processes. If the root is one of the excluded processes, then after the reduce-scatter, rank 0 sends its result to the root and exits; the root now acts as rank 0 in the binomial tree algorithm for gather. For the power-of-two case, the cost for the reduce-scatter is lgp.alpha + n.((p-1)/p).beta + n.((p-1)/p).gamma. The cost for the gather to root is lgp.alpha + n.((p-1)/p).beta. Therefore, the total cost is: Cost = 2.lgp.alpha + 2.n.((p-1)/p).beta + n.((p-1)/p).gamma For the non-power-of-two case, assuming the root is not one of the odd-numbered processes that get excluded in the reduce-scatter, Cost = (2.floor(lgp)+1).alpha + (2.((p-1)/p) + 1).n.beta + n.(1+(p-1)/p).gamma For short messages, user-defined ops, and count < pof2, we use a binomial tree algorithm for both short and long messages. Cost = lgp.alpha + n.lgp.beta + n.lgp.gamma We use the binomial tree algorithm in the case of user-defined ops because in this case derived datatypes are allowed, and the user could pass basic datatypes on one process and derived on another as long as the type maps are the same. Breaking up derived datatypes to do the reduce-scatter is tricky. Possible improvements: End Algorithm: MPI_Reduce*//* begin:nested *//* not declared static because a machine-specific function may call this one in some cases */int MPIR_Reduce ( void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPID_Comm *comm_ptr ){ static const char FCNAME[] = "MPIR_Reduce"; MPI_Status status; int comm_size, rank, is_commutative, type_size, pof2, rem, newrank; int mask, relrank, source, lroot, *cnts, *disps, i, j, send_idx=0; int mpi_errno = MPI_SUCCESS, recv_idx, last_idx=0, newdst; int dst, send_cnt, recv_cnt, newroot, newdst_tree_root, newroot_tree_root; MPI_User_function *uop; MPI_Aint true_lb, true_extent, extent; void *tmp_buf; MPID_Op *op_ptr; MPI_Comm comm; MPICH_PerThread_t *p;#ifdef HAVE_CXX_BINDING int is_cxx_uop = 0;#endif MPIU_CHKLMEM_DECL(4); if (count == 0) return MPI_SUCCESS; MPIR_Nest_incr(); comm = comm_ptr->handle; comm_size = comm_ptr->local_size; rank = comm_ptr->rank; /* set op_errno to 0. stored in perthread structure */ MPIR_GetPerThread(&p); p->op_errno = 0; if (HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) { is_commutative = 1; /* get the function by indexing into the op table */ uop = MPIR_Op_table[op%16 - 1]; } else { MPID_Op_get_ptr(op, op_ptr); if (op_ptr->kind == MPID_OP_USER_NONCOMMUTE) is_commutative = 0; else is_commutative = 1; #ifdef HAVE_CXX_BINDING if (op_ptr->language == MPID_LANG_CXX) { uop = (MPI_User_function *) op_ptr->function.c_function; is_cxx_uop = 1; } else#endif if ((op_ptr->language == MPID_LANG_C)) uop = (MPI_User_function *) op_ptr->function.c_function; else uop = (MPI_User_function *) op_ptr->function.f77_function; } /* Create a temporary buffer */ mpi_errno = NMPI_Type_get_true_extent(datatype, &true_lb, &true_extent); MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail"); MPID_Datatype_get_extent_macro(datatype, extent); MPIU_CHKLMEM_MALLOC(tmp_buf, void *, count*(MPIR_MAX(extent,true_extent)), mpi_errno, "temporary buffer"); /* adjust for potential negative lower bound in datatype */ tmp_buf = (void *)((char*)tmp_buf - true_lb); /* If I'm not the root, then my recvbuf may not be valid, therefore I have to allocate a temporary one */ if (rank != root) { MPIU_CHKLMEM_MALLOC(recvbuf, void *, count*(MPIR_MAX(extent,true_extent)), mpi_errno, "receive buffer"); recvbuf = (void *)((char*)recvbuf - true_lb); } if ((rank != root) || (sendbuf != MPI_IN_PLACE)) { mpi_errno = MPIR_Localcopy(sendbuf, count, datatype, recvbuf, count, datatype); MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail"); } MPID_Datatype_get_size_macro(datatype, type_size); /* find nearest power-of-two less than or equal to comm_size */ pof2 = 1; while (pof2 <= comm_size) pof2 <<= 1; pof2 >>=1; /* check if multiple threads are calling this collective function */ MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr ); if ((count*type_size > MPIR_REDUCE_SHORT_MSG) && (HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) && (count >= pof2)) { /* do a reduce-scatter followed by gather to root. */ rem = comm_size - pof2; /* In the non-power-of-two case, all odd-numbered processes of rank < 2*rem send their data to (rank-1). These odd-numbered processes no longer participate in the algorithm until the very end. The remaining processes form a nice power-of-two. Note that in MPI_Allreduce we have the even-numbered processes send data to odd-numbered processes. That is better for non-commutative operations because it doesn't require a buffer copy. However, for MPI_Reduce, the most common case is commutative operations with root=0. Therefore we want even-numbered processes to participate the computation for the root=0 case, in order to avoid an extra send-to-root communication after the reduce-scatter. In MPI_Allreduce it doesn't matter because all processes must get the result. */ if (rank < 2*rem) { if (rank % 2 != 0) { /* odd */ mpi_errno = MPIC_Send(recvbuf, count, datatype, rank-1, MPIR_REDUCE_TAG, comm); MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail"); /* temporarily set the rank to -1 so that this process does not pariticipate in recursive doubling */ newrank = -1; } else { /* even */ mpi_errno = MPIC_Recv(tmp_buf, count, datatype, rank+1, MPIR_REDUCE_TAG, comm, MPI_STATUS_IGNORE); MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail"); /* do the reduction on received data. */ /* This algorithm is used only for predefined ops and predefined ops are always commutative. */#ifdef HAVE_CXX_BINDING if (is_cxx_uop) { (*MPIR_Process.cxx_call_op_fn)( tmp_buf, recvbuf, count, datatype, uop ); } else #endif (*uop)(tmp_buf, recvbuf, &count, &datatype); /* change the rank */ newrank = rank / 2; } } else /* rank >= 2*rem */ newrank = rank - rem; /* for the reduce-scatter, calculate the count that each process receives and the displacement within the buffer */ /* We allocate these arrays on all processes, even if newrank=-1, because if root is one of the excluded processes, we will need them on the root later on below. */ MPIU_CHKLMEM_MALLOC(cnts, int *, pof2*sizeof(int), mpi_errno, "counts"); MPIU_CHKLMEM_MALLOC(disps, int *, pof2*sizeof(int), mpi_errno, "displacements"); if (newrank != -1) { for (i=0; i<(pof2-1); i++) cnts[i] = count/pof2; cnts[pof2-1] = count - (count/pof2)*(pof2-1); disps[0] = 0; for (i=1; i<pof2; i++) disps[i] = disps[i-1] + cnts[i-1]; mask = 0x1; send_idx = recv_idx = 0; last_idx = pof2; while (mask < pof2) { newdst = newrank ^ mask; /* find real rank of dest */ dst = (newdst < rem) ? newdst*2 : newdst + rem; send_cnt = recv_cnt = 0; if (newrank < newdst) { send_idx = recv_idx + pof2/(mask*2); for (i=send_idx; i<last_idx; i++) send_cnt += cnts[i]; for (i=recv_idx; i<send_idx; i++) recv_cnt += cnts[i]; } else { recv_idx = send_idx + pof2/(mask*2); for (i=send_idx; i<recv_idx; i++) send_cnt += cnts[i]; for (i=recv_idx; i<last_idx; i++) recv_cnt += cnts[i]; } /* printf("Rank %d, send_idx %d, recv_idx %d, send_cnt %d, recv_cnt %d, last_idx %d\n", newrank, send_idx, recv_idx, send_cnt, recv_cnt, last_idx);*/ /* Send data from recvbuf. Recv into tmp_buf */ mpi_errno = MPIC_Sendrecv((char *) recvbuf + disps[send_idx]*extent, send_cnt, datatype,
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -