📄 scatter.c
字号:
/* -*- Mode: C; c-basic-offset:4 ; -*- */
/*
 *
 *  (C) 2001 by Argonne National Laboratory.
 *      See COPYRIGHT in top-level directory.
 */

#include "mpiimpl.h"

/* -- Begin Profiling Symbol Block for routine MPI_Scatter */
/* At most one of the pragmas below is emitted, chosen by the HAVE_PRAGMA_*
   configuration macros.  Each one ties the symbol MPI_Scatter to
   PMPI_Scatter; the exact linkage semantics are compiler-specific.  If none
   of the macros is defined, no pragma is emitted here. */
#if defined(HAVE_PRAGMA_WEAK)
#pragma weak MPI_Scatter = PMPI_Scatter
#elif defined(HAVE_PRAGMA_HP_SEC_DEF)
#pragma _HP_SECONDARY_DEF PMPI_Scatter MPI_Scatter
#elif defined(HAVE_PRAGMA_CRI_DUP)
#pragma _CRI duplicate MPI_Scatter as PMPI_Scatter
#endif
/* -- End Profiling Symbol Block */

/* Define MPICH_MPI_FROM_PMPI if weak symbols are not supported to build
   the MPI routines */
/* When MPICH_MPI_FROM_PMPI is not defined, the macro below makes the
   preprocessor replace every later use of the name MPI_Scatter in this
   file with PMPI_Scatter. */
#ifndef MPICH_MPI_FROM_PMPI
#define MPI_Scatter PMPI_Scatter

/* This is the default implementation of scatter. The algorithm is:

   Algorithm: MPI_Scatter

   We use a minimum spanning tree (MST) algorithm for both short and
   long messages. At nodes other than leaf nodes we need to allocate
   a temporary buffer to store the incoming message. If the root is
   not rank 0, we reorder the sendbuf in order of relative ranks by
   copying it into a temporary buffer, so that all the sends from the
   root are contiguous and in the right order. In the heterogeneous
   case, we first pack the buffer by using MPI_Pack and then do the
   scatter.

   Cost = lgp.alpha + n.((p-1)/p).beta
   where n is the total size of the data to be scattered from the root.
Possible improvements: End Algorithm: MPI_Scatter*//* begin:nested */PMPI_LOCAL int MPIR_Scatter ( void *sendbuf, int sendcnt, MPI_Datatype sendtype, void *recvbuf, int recvcnt, MPI_Datatype recvtype, int root, MPID_Comm *comm_ptr ){ MPI_Status status; MPI_Aint extent=0; int rank, comm_size, is_homogeneous, sendtype_size; int curr_cnt, relative_rank, nbytes, send_subtree_cnt; int mask, recvtype_size=0, src, dst, position, tmp_buf_size; void *tmp_buf=NULL; int mpi_errno = MPI_SUCCESS; MPI_Comm comm; if (recvcnt == 0) return MPI_SUCCESS; comm = comm_ptr->handle; comm_size = comm_ptr->local_size; rank = comm_ptr->rank; is_homogeneous = 1;#ifdef MPID_HAS_HETERO if (comm_ptr->is_hetero) is_homogeneous = 0;#endif/* Use MST algorithm */ if (rank == root) MPID_Datatype_get_extent_macro(sendtype, extent); relative_rank = (rank >= root) ? rank - root : rank - root + comm_size; /* Lock for collective operation */ MPID_Comm_thread_lock( comm_ptr ); if (is_homogeneous) { /* communicator is homogeneous */ if (rank == root) { /* We separate the two cases (root and non-root) because in the event of recvbuf=MPI_IN_PLACE on the root, recvcnt and recvtype are not valid */ MPID_Datatype_get_size_macro(sendtype, sendtype_size); nbytes = sendtype_size * sendcnt; } else { MPID_Datatype_get_size_macro(recvtype, recvtype_size); nbytes = recvtype_size * recvcnt; } curr_cnt = 0; /* all even nodes other than root need a temporary buffer to receive data of max size (nbytes*comm_size)/2 */ if (relative_rank && !(relative_rank % 2)) { tmp_buf = MPIU_Malloc((nbytes*comm_size)/2); if (!tmp_buf) { mpi_errno = MPIR_Err_create_code( MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } } /* if the root is not rank 0, we reorder the sendbuf in order of relative ranks and copy it into a temporary buffer, so that all the sends from the root are contiguous and in the right order. 
*/ if (rank == root) { if (root != 0) { tmp_buf = MPIU_Malloc(nbytes*comm_size); if (!tmp_buf) { mpi_errno = MPIR_Err_create_code( MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } position = 0; if (recvbuf != MPI_IN_PLACE) MPIR_Localcopy(((char *) sendbuf + extent*sendcnt*rank), sendcnt*(comm_size-rank), sendtype, tmp_buf, nbytes*(comm_size-rank), MPI_BYTE); else MPIR_Localcopy(((char *) sendbuf + extent*sendcnt*(rank+1)), sendcnt*(comm_size-rank-1), sendtype, tmp_buf, nbytes*(comm_size-rank-1), MPI_BYTE); MPIR_Localcopy(sendbuf, sendcnt*rank, sendtype, ((char *) tmp_buf + nbytes*(comm_size-rank)), nbytes*rank, MPI_BYTE); curr_cnt = nbytes*comm_size; } else curr_cnt = sendcnt*comm_size; } /* root has all the data; others have zero so far */ mask = 0x1; while (mask < comm_size) { if (relative_rank & mask) { src = rank - mask; if (src < 0) src += comm_size; /* The leaf nodes receive directly into recvbuf because they don't have to forward data to anyone. Others receive data into a temporary buffer. */ if (relative_rank % 2) { mpi_errno = MPIC_Recv(recvbuf, recvcnt, recvtype, src, MPIR_SCATTER_TAG, comm, &status); if (mpi_errno) return mpi_errno; } else { mpi_errno = MPIC_Recv(tmp_buf, mask * recvcnt * recvtype_size, MPI_BYTE, src, MPIR_SCATTER_TAG, comm, &status); if (mpi_errno) return mpi_errno; /* the recv size is larger than what may be sent in some cases. query amount of data actually received */ NMPI_Get_count(&status, MPI_BYTE, &curr_cnt); } break; } mask <<= 1; } /* This process is responsible for all processes that have bits set from the LSB upto (but not including) mask. Because of the "not including", we start by shifting mask back down one. 
*/ mask >>= 1; while (mask > 0) { if (relative_rank + mask < comm_size) { dst = rank + mask; if (dst >= comm_size) dst -= comm_size; if ((rank == root) && (root == 0)) { send_subtree_cnt = curr_cnt - sendcnt * mask; /* mask is also the size of this process's subtree */ mpi_errno = MPIC_Send (((char *)sendbuf + extent * sendcnt * mask), send_subtree_cnt, sendtype, dst, MPIR_SCATTER_TAG, comm); } else { /* non-zero root and others */ send_subtree_cnt = curr_cnt - nbytes*mask; /* mask is also the size of this process's subtree */ mpi_errno = MPIC_Send (((char *)tmp_buf + nbytes*mask), send_subtree_cnt, MPI_BYTE, dst, MPIR_SCATTER_TAG, comm); } if (mpi_errno) return mpi_errno; curr_cnt -= send_subtree_cnt; } mask >>= 1; } if ((rank == root) && (root == 0) && (recvbuf != MPI_IN_PLACE)) { /* for root=0, put root's data in recvbuf if not MPI_IN_PLACE */ mpi_errno = MPIR_Localcopy ( sendbuf, sendcnt, sendtype, recvbuf, recvcnt, recvtype ); if (mpi_errno) return mpi_errno; } else if (!(relative_rank % 2) && (recvbuf != MPI_IN_PLACE)) { /* for non-zero root and non-leaf nodes, copy from tmp_buf into recvbuf */ mpi_errno = MPIR_Localcopy ( tmp_buf, nbytes, MPI_BYTE, recvbuf, recvcnt, recvtype); if (mpi_errno) return mpi_errno; MPIU_Free(tmp_buf); } } else { /* communicator is heterogeneous */ if (rank == root) { NMPI_Pack_size(sendcnt*comm_size, sendtype, comm, &tmp_buf_size); tmp_buf = MPIU_Malloc(tmp_buf_size); if (!tmp_buf) { mpi_errno = MPIR_Err_create_code( MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* calculate the value of nbytes, the number of bytes in packed representation that each process receives. We can't accurately calculate that from tmp_buf_size because MPI_Pack_size returns an upper bound on the amount of memory required. (For example, for a single integer, MPICH-1 returns pack_size=12.) Therefore, we actually pack some data into tmp_buf and see by how much 'position' is incremented. 
*/ position = 0; NMPI_Pack(sendbuf, 1, sendtype, tmp_buf, tmp_buf_size, &position, comm); nbytes = position*sendcnt; curr_cnt = nbytes*comm_size; if (root == 0) { if (recvbuf != MPI_IN_PLACE) { position = 0; NMPI_Pack(sendbuf, sendcnt*comm_size, sendtype, tmp_buf, tmp_buf_size, &position, comm); } else { position = nbytes; NMPI_Pack(((char *) sendbuf + extent*sendcnt), sendcnt*(comm_size-1), sendtype, tmp_buf, tmp_buf_size, &position, comm); } } else { if (recvbuf != MPI_IN_PLACE) { position = 0; NMPI_Pack(((char *) sendbuf + extent*sendcnt*rank), sendcnt*(comm_size-rank), sendtype, tmp_buf, tmp_buf_size, &position, comm); } else { position = nbytes; NMPI_Pack(((char *) sendbuf + extent*sendcnt*(rank+1)), sendcnt*(comm_size-rank-1), sendtype, tmp_buf, tmp_buf_size, &position, comm); } NMPI_Pack(sendbuf, sendcnt*rank, sendtype, tmp_buf, tmp_buf_size, &position, comm); } } else { NMPI_Pack_size(recvcnt*(comm_size/2), recvtype, comm, &tmp_buf_size); tmp_buf = MPIU_Malloc(tmp_buf_size); if (!tmp_buf) { mpi_errno = MPIR_Err_create_code( MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* calculate nbytes */ position = 0; NMPI_Pack(recvbuf, 1, recvtype, tmp_buf, tmp_buf_size, &position, comm); nbytes = position*recvcnt; curr_cnt = 0; } mask = 0x1; while (mask < comm_size) { if (relative_rank & mask) {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -