red_scat.c
/* -*- Mode: C; c-basic-offset:4 ; -*- */
/*
 *  (C) 2001 by Argonne National Laboratory.
 *      See COPYRIGHT in top-level directory.
 */

#include "mpiimpl.h"

/* -- Begin Profiling Symbol Block for routine MPI_Reduce_scatter */
#if defined(HAVE_PRAGMA_WEAK)
#pragma weak MPI_Reduce_scatter = PMPI_Reduce_scatter
#elif defined(HAVE_PRAGMA_HP_SEC_DEF)
#pragma _HP_SECONDARY_DEF PMPI_Reduce_scatter  MPI_Reduce_scatter
#elif defined(HAVE_PRAGMA_CRI_DUP)
#pragma _CRI duplicate MPI_Reduce_scatter as PMPI_Reduce_scatter
#endif
/* -- End Profiling Symbol Block */

/* Define MPICH_MPI_FROM_PMPI if weak symbols are not supported to build
   the MPI routines */
#ifndef MPICH_MPI_FROM_PMPI
#define MPI_Reduce_scatter PMPI_Reduce_scatter

/* This is the default implementation of reduce_scatter. The algorithm is:

   Algorithm: MPI_Reduce_scatter

   For long messages, we use a pairwise exchange algorithm similar to
   the one used in MPI_Alltoall. At step i, each process sends n/p
   amount of data to (rank+i) and receives n/p amount of data from
   (rank-i).

   Cost = (p-1).alpha + n.((p-1)/p).beta + n.((p-1)/p).gamma

   For short messages, we use a recursive doubling algorithm, which
   takes lgp steps. At step 1, processes exchange (n-n/p) amount of
   data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
   amount of data, and so forth.

   Cost = lgp.alpha + n.(lgp-(p-1)/p).beta + n.(lgp-(p-1)/p).gamma

   (Standalone illustrative sketches of both schedules appear below,
   guarded by #if 0.)

   Possible improvements:

   End Algorithm: MPI_Reduce_scatter
*/

/* begin:nested */
PMPI_LOCAL int MPIR_Reduce_scatter ( void *sendbuf, void *recvbuf,
                                     int *recvcnts, MPI_Datatype datatype,
                                     MPI_Op op, MPID_Comm *comm_ptr )
{
    int rank, comm_size, i;
    MPI_Aint extent, true_extent, true_lb;
    int *displs;
    void *tmp_recvbuf, *tmp_results;
    int mpi_errno = MPI_SUCCESS;
    int type_size, dis[2], blklens[2], total_count, nbytes, src, dst;
    int mask, dst_tree_root, my_tree_root, j, k;
    MPI_Datatype sendtype, recvtype;
    int nprocs_completed, tmp_mask, tree_root, is_commutative;
    MPI_User_function *uop;
    MPID_Op *op_ptr;
    MPI_Status status;
    MPI_Comm comm;
    MPICH_PerThread_t *p;

    comm = comm_ptr->handle;
    comm_size = comm_ptr->local_size;
    rank = comm_ptr->rank;
    /* set op_errno to 0. stored in perthread structure */
    MPID_GetPerThread(p);
    p->op_errno = 0;

    MPID_Datatype_get_size_macro(datatype, type_size);
    MPID_Datatype_get_extent_macro(datatype, extent);

    mpi_errno = NMPI_Type_get_true_extent(datatype, &true_lb, &true_extent);
    if (mpi_errno) return mpi_errno;

    if (HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) {
        is_commutative = 1;
        /* get the function by indexing into the op table */
        uop = MPIR_Op_table[op%16 - 1];
    }
    else {
        MPID_Op_get_ptr(op, op_ptr);
        if (op_ptr->kind == MPID_OP_USER_NONCOMMUTE)
            is_commutative = 0;
        else
            is_commutative = 1;

#ifdef HAVE_CXX_BINDING
        if ((op_ptr->language == MPID_LANG_C) ||
            (op_ptr->language == MPID_LANG_CXX))
#else
        if ((op_ptr->language == MPID_LANG_C))
#endif
            uop = (MPI_User_function *) op_ptr->function.c_function;
        else
            uop = (MPI_User_function *) op_ptr->function.f77_function;
    }

    displs = MPIU_Malloc(comm_size*sizeof(int));
    if (!displs) {
        mpi_errno = MPIR_Err_create_code( MPI_ERR_OTHER, "**nomem", 0 );
        return mpi_errno;
    }

    total_count = 0;
    for (i=0; i<comm_size; i++) {
        displs[i] = total_count;
        total_count += recvcnts[i];
    }
    nbytes = total_count * type_size;

    /* Lock for collective operation */
    MPID_Comm_thread_lock( comm_ptr );

    if (nbytes > MPIR_REDUCE_SCATTER_SHORT_MSG) {
        /* for long messages, use (p-1) pairwise exchanges */

        if (sendbuf != MPI_IN_PLACE) {
            /* copy local data into recvbuf */
            mpi_errno = MPIR_Localcopy(((char *)sendbuf+displs[rank]*extent),
                                       recvcnts[rank], datatype, recvbuf,
                                       recvcnts[rank], datatype);
            if (mpi_errno) return mpi_errno;
        }

        /* allocate temporary buffer to store incoming data */
        tmp_recvbuf = MPIU_Malloc(true_extent*recvcnts[rank]);
        if (!tmp_recvbuf) {
            mpi_errno = MPIR_Err_create_code( MPI_ERR_OTHER, "**nomem", 0 );
            return mpi_errno;
        }
        /* adjust for potential negative lower bound in datatype */
        tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);

        for (i=1; i<comm_size; i++) {
            src = (rank - i + comm_size) % comm_size;
            dst = (rank + i) % comm_size;

            /* send the data that dst needs. recv data that this process
               needs from src into tmp_recvbuf */
            if (sendbuf != MPI_IN_PLACE)
                mpi_errno = MPIC_Sendrecv(((char *)sendbuf+displs[dst]*extent),
                                          recvcnts[dst], datatype, dst,
                                          MPIR_REDUCE_SCATTER_TAG, tmp_recvbuf,
                                          recvcnts[rank], datatype, src,
                                          MPIR_REDUCE_SCATTER_TAG, comm,
                                          &status);
            else
                mpi_errno = MPIC_Sendrecv(((char *)recvbuf+displs[dst]*extent),
                                          recvcnts[dst], datatype, dst,
                                          MPIR_REDUCE_SCATTER_TAG, tmp_recvbuf,
                                          recvcnts[rank], datatype, src,
                                          MPIR_REDUCE_SCATTER_TAG, comm,
                                          &status);
            if (mpi_errno) return mpi_errno;

            if (is_commutative || (src < rank)) {
                if (sendbuf != MPI_IN_PLACE)
                    (*uop)(tmp_recvbuf, recvbuf, &recvcnts[rank], &datatype);
                else {
                    (*uop)(tmp_recvbuf, ((char *)recvbuf+displs[rank]*extent),
                           &recvcnts[rank], &datatype);
                    /* we can't store the result at the beginning of recvbuf
                       right here because there is useful data there that
                       other process/processes need. at the end, we will
                       copy back the result to the beginning of recvbuf. */
                }
            }
            else {
                if (sendbuf != MPI_IN_PLACE) {
                    (*uop)(recvbuf, tmp_recvbuf, &recvcnts[rank], &datatype);
                    /* copy result back into recvbuf */
                    mpi_errno = MPIR_Localcopy(tmp_recvbuf, recvcnts[rank],
                                               datatype, recvbuf,
                                               recvcnts[rank], datatype);
                }
                else {
                    (*uop)(((char *)recvbuf+displs[rank]*extent),
                           tmp_recvbuf, &recvcnts[rank], &datatype);
                    /* copy result back into recvbuf */
                    mpi_errno = MPIR_Localcopy(tmp_recvbuf, recvcnts[rank],
                                               datatype,
                                               ((char *)recvbuf +
                                                displs[rank]*extent),
                                               recvcnts[rank], datatype);
                }
                if (mpi_errno) return mpi_errno;
            }
        }

        MPIU_Free((char *)tmp_recvbuf+true_lb);

        /* if MPI_IN_PLACE, move output data to the beginning of
           recvbuf. already done for rank 0. */
        if ((sendbuf == MPI_IN_PLACE) && (rank != 0)) {
            mpi_errno = MPIR_Localcopy(((char *)recvbuf +
                                        displs[rank]*extent),
                                       recvcnts[rank], datatype,
                                       recvbuf, recvcnts[rank], datatype);
            if (mpi_errno) return mpi_errno;
        }
    }
    else {
        /* for short messages, use recursive doubling. */

        /* need to allocate temporary buffer to receive incoming data */
        tmp_recvbuf = MPIU_Malloc(true_extent*total_count);
        if (!tmp_recvbuf) {
            mpi_errno = MPIR_Err_create_code( MPI_ERR_OTHER, "**nomem", 0 );
            return mpi_errno;
        }
        /* adjust for potential negative lower bound in datatype */
        tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);

        /* need to allocate another temporary buffer to accumulate
           results */
        tmp_results = MPIU_Malloc(true_extent*total_count);
        if (!tmp_results) {
            mpi_errno = MPIR_Err_create_code( MPI_ERR_OTHER, "**nomem", 0 );
            return mpi_errno;
        }
        /* adjust for potential negative lower bound in datatype */
        tmp_results = (void *)((char*)tmp_results - true_lb);

        /* copy sendbuf into tmp_results */
        if (sendbuf != MPI_IN_PLACE)
            mpi_errno = MPIR_Localcopy(sendbuf, total_count, datatype,
                                       tmp_results, total_count, datatype);
        else
            mpi_errno = MPIR_Localcopy(recvbuf, total_count, datatype,
                                       tmp_results, total_count, datatype);
        if (mpi_errno) return mpi_errno;

        mask = 0x1;
        i = 0;
        while (mask < comm_size) {
            dst = rank ^ mask;

            dst_tree_root = dst >> i;
            dst_tree_root <<= i;

            my_tree_root = rank >> i;
            my_tree_root <<= i;

            /* At step 1, processes exchange (n-n/p) amount of data;
               at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
               amount of data, and so forth. We use derived datatypes
               for this.

               At each step, a process does not need to send data
               indexed from my_tree_root to my_tree_root+mask-1.
               Similarly, a process won't receive data indexed from
               dst_tree_root to dst_tree_root+mask-1. */

            /* calculate sendtype */
            blklens[0] = blklens[1] = 0;
            for (j=0; j<my_tree_root; j++)
                blklens[0] += recvcnts[j];
            for (j=my_tree_root+mask; j<comm_size; j++)
                blklens[1] += recvcnts[j];

            dis[0] = 0;
            dis[1] = blklens[0];
            for (j=my_tree_root; j<my_tree_root+mask; j++)
                dis[1] += recvcnts[j];

            NMPI_Type_indexed(2, blklens, dis, datatype, &sendtype);
            NMPI_Type_commit(&sendtype);

            /* calculate recvtype */
            blklens[0] = blklens[1] = 0;
            for (j=0; j<dst_tree_root; j++)
                blklens[0] += recvcnts[j];
            for (j=dst_tree_root+mask; j<comm_size; j++)
                blklens[1] += recvcnts[j];

            dis[0] = 0;
            dis[1] = blklens[0];
            for (j=dst_tree_root; j<dst_tree_root+mask; j++)
                dis[1] += recvcnts[j];

            NMPI_Type_indexed(2, blklens, dis, datatype, &recvtype);
            NMPI_Type_commit(&recvtype);

            if (dst < comm_size) {
                /* tmp_results contains data to be sent in each step.
                   Data is received in tmp_recvbuf and then accumulated
                   into tmp_results. */
                mpi_errno = MPIC_Sendrecv(tmp_results, 1, sendtype, dst,
                                          MPIR_REDUCE_SCATTER_TAG,
                                          tmp_recvbuf, 1, recvtype, dst,
                                          MPIR_REDUCE_SCATTER_TAG, comm,
                                          &status);
                if (mpi_errno) return mpi_errno;

                if (is_commutative || (dst_tree_root < my_tree_root)) {
                    (*uop)(tmp_recvbuf, tmp_results, &blklens[0], &datatype);
                    (*uop)(((char *)tmp_recvbuf + dis[1]*extent),
                           ((char *)tmp_results + dis[1]*extent),
                           &blklens[1], &datatype);
                }
                else {
                    (*uop)(tmp_results, tmp_recvbuf, &blklens[0], &datatype);
                    (*uop)(((char *)tmp_results + dis[1]*extent),
                           ((char *)tmp_recvbuf + dis[1]*extent),
                           &blklens[1], &datatype);
                    /* copy result back into tmp_results */
                    mpi_errno = MPIC_Sendrecv(tmp_recvbuf, 1, recvtype, rank,
                                              MPIR_REDUCE_SCATTER_TAG,
                                              tmp_results, 1, recvtype, rank,
                                              MPIR_REDUCE_SCATTER_TAG,
                                              comm, &status);
                    if (mpi_errno) return mpi_errno;
                }
            }

            /* if some processes in this process's subtree in this step
               did not have any destination process to communicate with
               because of non-power-of-two, we need to send them the
               result. We use a logarithmic recursive-halving algorithm
               for this. */
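/* Illustrative sketch, not part of the MPICH implementation: the
   long-message branch above pairs each process with (rank+i) % p and
   (rank-i+p) % p at step i, which is the pairwise-exchange schedule
   described in the algorithm comment at the top of this file.  The
   standalone helper below (compiled out, hypothetical name) only prints
   that schedule. */
#if 0
#include <stdio.h>

static void print_pairwise_schedule(int rank, int comm_size)
{
    int i;
    for (i = 1; i < comm_size; i++) {
        int dst = (rank + i) % comm_size;              /* partner we send to   */
        int src = (rank - i + comm_size) % comm_size;  /* partner we recv from */
        printf("step %d: rank %d sends block %d to rank %d, "
               "receives its own block %d from rank %d\n",
               i, rank, dst, dst, rank, src);
    }
}
#endif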
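/* Illustrative sketch, not part of the MPICH implementation: at each
   recursive-doubling step above, the two-block indexed datatypes skip
   the count window [tree_root, tree_root+mask).  For the special case of
   equal counts (cnt elements per process) and a power-of-two comm_size,
   the block lengths and displacements computed by the loops above reduce
   to the closed forms below (hypothetical helper, compiled out). */
#if 0
static void exclusion_window(int tree_root, int mask, int comm_size, int cnt,
                             int blklens[2], int dis[2])
{
    blklens[0] = tree_root * cnt;                       /* counts before the window */
    blklens[1] = (comm_size - tree_root - mask) * cnt;  /* counts after the window  */
    dis[0] = 0;
    dis[1] = (tree_root + mask) * cnt;                  /* first element after the window */
}
#endif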
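/* Hedged usage sketch, not part of red_scat.c: MPI_Reduce_scatter reduces
   a full vector of sum(recvcnts) elements contributed by every process and
   leaves only block i of the result on process i.  The fragment below
   (compiled out, error handling and MPI_Init/Finalize omitted) assumes two
   elements per process and MPI_SUM. */
#if 0
#include <mpi.h>
#include <stdlib.h>

static void reduce_scatter_example(void)
{
    int rank, size, i;
    int *recvcnts;
    double *sendbuf;
    double recvbuf[2];

    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    recvcnts = (int *) malloc(size * sizeof(int));
    sendbuf  = (double *) malloc(2 * size * sizeof(double));

    for (i = 0; i < size; i++)
        recvcnts[i] = 2;                /* every process receives 2 elements */
    for (i = 0; i < 2 * size; i++)
        sendbuf[i] = (double) rank;     /* this process's full contribution  */

    /* afterwards, recvbuf on process i holds the sum over all ranks of
       elements 2*i and 2*i+1 of each rank's sendbuf */
    MPI_Reduce_scatter(sendbuf, recvbuf, recvcnts, MPI_DOUBLE, MPI_SUM,
                       MPI_COMM_WORLD);

    free(sendbuf);
    free(recvcnts);
}
#endif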