📄 red_scat.c
字号:
/* -*- Mode: C; c-basic-offset:4 ; -*- *//* * * (C) 2001 by Argonne National Laboratory. * See COPYRIGHT in top-level directory. */#include "mpiimpl.h"/* -- Begin Profiling Symbol Block for routine MPI_Reduce_scatter */#if defined(HAVE_PRAGMA_WEAK)#pragma weak MPI_Reduce_scatter = PMPI_Reduce_scatter#elif defined(HAVE_PRAGMA_HP_SEC_DEF)#pragma _HP_SECONDARY_DEF PMPI_Reduce_scatter MPI_Reduce_scatter#elif defined(HAVE_PRAGMA_CRI_DUP)#pragma _CRI duplicate MPI_Reduce_scatter as PMPI_Reduce_scatter#endif/* -- End Profiling Symbol Block *//* Define MPICH_MPI_FROM_PMPI if weak symbols are not supported to build the MPI routines */#ifndef MPICH_MPI_FROM_PMPI#define MPI_Reduce_scatter PMPI_Reduce_scatter/* This is the default implementation of reduce_scatter. The algorithm is: Algorithm: MPI_Reduce_scatter If the operation is commutative, for short and medium-size messages, we use a recursive-halving algorithm in which the first p/2 processes send the second n/2 data to their counterparts in the other half and receive the first n/2 data from them. This procedure continues recursively, halving the data communicated at each step, for a total of lgp steps. If the number of processes is not a power-of-two, we convert it to the nearest lower power-of-two by having the first few even-numbered processes send their data to the neighboring odd-numbered process at (rank+1). Those odd-numbered processes compute the result for their left neighbor as well in the recursive halving algorithm, and then at the end send the result back to the processes that didn't participate. Therefore, if p is a power-of-two, Cost = lgp.alpha + n.((p-1)/p).beta + n.((p-1)/p).gamma If p is not a power-of-two, Cost = (floor(lgp)+2).alpha + n.(1+(p-1+n)/p).beta + n.(1+(p-1)/p).gamma The above cost in the non power-of-two case is approximate because there is some imbalance in the amount of work each process does because some processes do the work of their neighbors as well. 
For commutative operations and very long messages we use a pairwise exchange algorithm similar to the one used in MPI_Alltoall. At step i, each process sends n/p amount of data to (rank+i) and receives n/p amount of data from (rank-i). Cost = (p-1).alpha + n.((p-1)/p).beta + n.((p-1)/p).gamma If the operation is not commutative, we do the following: For very short messages, we use a recursive doubling algorithm, which takes lgp steps. At step 1, processes exchange (n-n/p) amount of data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p) amount of data, and so forth. Cost = lgp.alpha + n.(lgp-(p-1)/p).beta + n.(lgp-(p-1)/p).gamma For medium and long messages, we use pairwise exchange as above. Possible improvements: End Algorithm: MPI_Reduce_scatter*//* begin:nested *//* not declared static because a machine-specific function may call this one in some cases */int MPIR_Reduce_scatter ( void *sendbuf, void *recvbuf, int *recvcnts, MPI_Datatype datatype, MPI_Op op, MPID_Comm *comm_ptr ){ static const char FCNAME[] = "MPIR_Reduce_scatter"; int rank, comm_size, i; MPI_Aint extent, true_extent, true_lb; int *disps; void *tmp_recvbuf, *tmp_results; int mpi_errno = MPI_SUCCESS; int type_size, dis[2], blklens[2], total_count, nbytes, src, dst; int mask, dst_tree_root, my_tree_root, j, k; int *newcnts, *newdisps, rem, newdst, send_idx, recv_idx, last_idx, send_cnt, recv_cnt; int pof2, old_i, newrank, received; MPI_Datatype sendtype, recvtype; int nprocs_completed, tmp_mask, tree_root, is_commutative; MPI_User_function *uop; MPID_Op *op_ptr; MPI_Comm comm; MPICH_PerThread_t *p;#ifdef HAVE_CXX_BINDING int is_cxx_uop = 0;#endif comm = comm_ptr->handle; comm_size = comm_ptr->local_size; rank = comm_ptr->rank; /* set op_errno to 0. 
stored in perthread structure */ MPIR_GetPerThread(&p); p->op_errno = 0; MPID_Datatype_get_extent_macro(datatype, extent); mpi_errno = NMPI_Type_get_true_extent(datatype, &true_lb, &true_extent); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ if (HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) { is_commutative = 1; /* get the function by indexing into the op table */ uop = MPIR_Op_table[op%16 - 1]; } else { MPID_Op_get_ptr(op, op_ptr); if (op_ptr->kind == MPID_OP_USER_NONCOMMUTE) is_commutative = 0; else is_commutative = 1;#ifdef HAVE_CXX_BINDING if (op_ptr->language == MPID_LANG_CXX) { uop = (MPI_User_function *) op_ptr->function.c_function; is_cxx_uop = 1; } else#endif if ((op_ptr->language == MPID_LANG_C)) uop = (MPI_User_function *) op_ptr->function.c_function; else uop = (MPI_User_function *) op_ptr->function.f77_function; } disps = MPIU_Malloc(comm_size*sizeof(int)); /* --BEGIN ERROR HANDLING-- */ if (!disps) { mpi_errno = MPIR_Err_create_code( MPI_SUCCESS, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* --END ERROR HANDLING-- */ total_count = 0; for (i=0; i<comm_size; i++) { disps[i] = total_count; total_count += recvcnts[i]; } if (total_count == 0) { MPIU_Free(disps); return MPI_SUCCESS; } MPID_Datatype_get_size_macro(datatype, type_size); nbytes = total_count * type_size; /* check if multiple threads are calling this collective function */ MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr ); MPIR_Nest_incr(); if ((is_commutative) && (nbytes < MPIR_REDSCAT_COMMUTATIVE_LONG_MSG)) { /* commutative and short. use recursive halving algorithm */ /* allocate temp. 
buffer to receive incoming data */ tmp_recvbuf = MPIU_Malloc(total_count*(MPIR_MAX(true_extent,extent))); /* --BEGIN ERROR HANDLING-- */ if (!tmp_recvbuf) { mpi_errno = MPIR_Err_create_code( MPI_SUCCESS, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* --END ERROR HANDLING-- */ /* adjust for potential negative lower bound in datatype */ tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb); /* need to allocate another temporary buffer to accumulate results because recvbuf may not be big enough */ tmp_results = MPIU_Malloc(total_count*(MPIR_MAX(true_extent,extent))); /* --BEGIN ERROR HANDLING-- */ if (!tmp_results) { mpi_errno = MPIR_Err_create_code( MPI_SUCCESS, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* --END ERROR HANDLING-- */ /* adjust for potential negative lower bound in datatype */ tmp_results = (void *)((char*)tmp_results - true_lb); /* copy sendbuf into tmp_results */ if (sendbuf != MPI_IN_PLACE) mpi_errno = MPIR_Localcopy(sendbuf, total_count, datatype, tmp_results, total_count, datatype); else mpi_errno = MPIR_Localcopy(recvbuf, total_count, datatype, tmp_results, total_count, datatype); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ pof2 = 1; while (pof2 <= comm_size) pof2 <<= 1; pof2 >>=1; rem = comm_size - pof2; /* In the non-power-of-two case, all even-numbered processes of rank < 2*rem send their data to (rank+1). These even-numbered processes no longer participate in the algorithm until the very end. The remaining processes form a nice power-of-two. 
*/ if (rank < 2*rem) { if (rank % 2 == 0) { /* even */ mpi_errno = MPIC_Send(tmp_results, total_count, datatype, rank+1, MPIR_REDUCE_SCATTER_TAG, comm); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ /* temporarily set the rank to -1 so that this process does not pariticipate in recursive doubling */ newrank = -1; } else { /* odd */ mpi_errno = MPIC_Recv(tmp_recvbuf, total_count, datatype, rank-1, MPIR_REDUCE_SCATTER_TAG, comm, MPI_STATUS_IGNORE); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ /* do the reduction on received data. since the ordering is right, it doesn't matter whether the operation is commutative or not. */#ifdef HAVE_CXX_BINDING if (is_cxx_uop) { (*MPIR_Process.cxx_call_op_fn)( tmp_recvbuf, tmp_results, total_count, datatype, uop ); } else #endif (*uop)(tmp_recvbuf, tmp_results, &total_count, &datatype); /* change the rank */ newrank = rank / 2; } } else /* rank >= 2*rem */ newrank = rank - rem; if (newrank != -1) { /* recalculate the recvcnts and disps arrays because the even-numbered processes who no longer participate will have their result calculated by the process to their right (rank+1). 
*/ newcnts = (int *) MPIU_Malloc(pof2*sizeof(int)); /* --BEGIN ERROR HANDLING-- */ if (!newcnts) { mpi_errno = MPIR_Err_create_code( MPI_SUCCESS, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* --END ERROR HANDLING-- */ newdisps = (int *) MPIU_Malloc(pof2*sizeof(int)); /* --BEGIN ERROR HANDLING-- */ if (!newdisps) { mpi_errno = MPIR_Err_create_code( MPI_SUCCESS, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* --END ERROR HANDLING-- */ for (i=0; i<pof2; i++) { /* what does i map to in the old ranking? */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -