📄 scatter.c
字号:
/* -*- Mode: C; c-basic-offset:4 ; -*- */
/*
 *
 *  (C) 2001 by Argonne National Laboratory.
 *      See COPYRIGHT in top-level directory.
 */

#include "mpiimpl.h"

/* -- Begin Profiling Symbol Block for routine MPI_Scatter */
/* At most one of the three compiler-specific pragmas below is emitted,
   selected by whichever HAVE_PRAGMA_* macro is defined.  Each one ties the
   public name MPI_Scatter to the profiling name PMPI_Scatter.  If none of
   the macros is defined, no pragma is emitted at all. */
#if defined(HAVE_PRAGMA_WEAK)
#pragma weak MPI_Scatter = PMPI_Scatter
#elif defined(HAVE_PRAGMA_HP_SEC_DEF)
#pragma _HP_SECONDARY_DEF PMPI_Scatter MPI_Scatter
#elif defined(HAVE_PRAGMA_CRI_DUP)
#pragma _CRI duplicate MPI_Scatter as PMPI_Scatter
#endif
/* -- End Profiling Symbol Block */

/* Define MPICH_MPI_FROM_PMPI if weak symbols are not supported to build
   the MPI routines */
/* NOTE(review): the #endif matching this #ifndef is presumably further down
   the file -- confirm when editing the end of the file. */
#ifndef MPICH_MPI_FROM_PMPI
/* While this macro is in effect, every later use of the token MPI_Scatter
   in this file expands to PMPI_Scatter. */
#define MPI_Scatter PMPI_Scatter

/* This is the default implementation of scatter. The algorithm is:

   Algorithm: MPI_Scatter

   We use a binomial tree algorithm for both short and
   long messages. At nodes other than leaf nodes we need to allocate
   a temporary buffer to store the incoming message. If the root is
   not rank 0, we reorder the sendbuf in order of relative ranks by
   copying it into a temporary buffer, so that all the sends from the
   root are contiguous and in the right order. In the heterogeneous
   case, we first pack the buffer by using MPI_Pack and then do the
   scatter.

   Cost = lgp.alpha + n.((p-1)/p).beta
   where n is the total size of the data to be scattered from the root.

Possible improvements: End Algorithm: MPI_Scatter*//* begin:nested *//* not declared static because a machine-specific function may call this one in some cases */int MPIR_Scatter ( void *sendbuf, int sendcnt, MPI_Datatype sendtype, void *recvbuf, int recvcnt, MPI_Datatype recvtype, int root, MPID_Comm *comm_ptr ){ static const char FCNAME[] = "MPIR_Scatter"; MPI_Status status; MPI_Aint extent=0; int rank, comm_size, is_homogeneous, sendtype_size; int curr_cnt, relative_rank, nbytes, send_subtree_cnt; int mask, recvtype_size=0, src, dst, position;#ifdef MPID_HAS_HETERO int tmp_buf_size;#endif void *tmp_buf=NULL; int mpi_errno = MPI_SUCCESS; MPI_Comm comm; comm = comm_ptr->handle; comm_size = comm_ptr->local_size; rank = comm_ptr->rank; if ( ((rank == root) && (sendcnt == 0)) || ((rank != root) && (recvcnt == 0)) ) return MPI_SUCCESS; is_homogeneous = 1;#ifdef MPID_HAS_HETERO if (comm_ptr->is_hetero) is_homogeneous = 0;#endif/* Use binomial tree algorithm */ if (rank == root) MPID_Datatype_get_extent_macro(sendtype, extent); relative_rank = (rank >= root) ? 
rank - root : rank - root + comm_size; /* check if multiple threads are calling this collective function */ MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr ); if (is_homogeneous) { /* communicator is homogeneous */ if (rank == root) { /* We separate the two cases (root and non-root) because in the event of recvbuf=MPI_IN_PLACE on the root, recvcnt and recvtype are not valid */ MPID_Datatype_get_size_macro(sendtype, sendtype_size); nbytes = sendtype_size * sendcnt; } else { MPID_Datatype_get_size_macro(recvtype, recvtype_size); nbytes = recvtype_size * recvcnt; } curr_cnt = 0; /* all even nodes other than root need a temporary buffer to receive data of max size (nbytes*comm_size)/2 */ if (relative_rank && !(relative_rank % 2)) { tmp_buf = MPIU_Malloc((nbytes*comm_size)/2); /* --BEGIN ERROR HANDLING-- */ if (!tmp_buf) { mpi_errno = MPIR_Err_create_code( MPI_SUCCESS, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* --END ERROR HANDLING-- */ } /* if the root is not rank 0, we reorder the sendbuf in order of relative ranks and copy it into a temporary buffer, so that all the sends from the root are contiguous and in the right order. 
*/ if (rank == root) { if (root != 0) { tmp_buf = MPIU_Malloc(nbytes*comm_size); /* --BEGIN ERROR HANDLING-- */ if (!tmp_buf) { mpi_errno = MPIR_Err_create_code( MPI_SUCCESS, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* --END ERROR HANDLING-- */ position = 0; if (recvbuf != MPI_IN_PLACE) mpi_errno = MPIR_Localcopy(((char *) sendbuf + extent*sendcnt*rank), sendcnt*(comm_size-rank), sendtype, tmp_buf, nbytes*(comm_size-rank), MPI_BYTE); else mpi_errno = MPIR_Localcopy(((char *) sendbuf + extent*sendcnt*(rank+1)), sendcnt*(comm_size-rank-1), sendtype, (char *)tmp_buf + nbytes, nbytes*(comm_size-rank-1), MPI_BYTE); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ mpi_errno = MPIR_Localcopy(sendbuf, sendcnt*rank, sendtype, ((char *) tmp_buf + nbytes*(comm_size-rank)), nbytes*rank, MPI_BYTE); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ curr_cnt = nbytes*comm_size; } else curr_cnt = sendcnt*comm_size; } /* root has all the data; others have zero so far */ mask = 0x1; while (mask < comm_size) { if (relative_rank & mask) { src = rank - mask; if (src < 0) src += comm_size; /* The leaf nodes receive directly into recvbuf because they don't have to forward data to anyone. Others receive data into a temporary buffer. 
*/ if (relative_rank % 2) { mpi_errno = MPIC_Recv(recvbuf, recvcnt, recvtype, src, MPIR_SCATTER_TAG, comm, &status); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ } else { mpi_errno = MPIC_Recv(tmp_buf, mask * recvcnt * recvtype_size, MPI_BYTE, src, MPIR_SCATTER_TAG, comm, &status); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ /* the recv size is larger than what may be sent in some cases. query amount of data actually received */ NMPI_Get_count(&status, MPI_BYTE, &curr_cnt); } break; } mask <<= 1; } /* This process is responsible for all processes that have bits set from the LSB upto (but not including) mask. Because of the "not including", we start by shifting mask back down one. 
*/ mask >>= 1; while (mask > 0) { if (relative_rank + mask < comm_size) { dst = rank + mask; if (dst >= comm_size) dst -= comm_size; if ((rank == root) && (root == 0)) { send_subtree_cnt = curr_cnt - sendcnt * mask; /* mask is also the size of this process's subtree */ mpi_errno = MPIC_Send (((char *)sendbuf + extent * sendcnt * mask), send_subtree_cnt, sendtype, dst, MPIR_SCATTER_TAG, comm); } else { /* non-zero root and others */ send_subtree_cnt = curr_cnt - nbytes*mask; /* mask is also the size of this process's subtree */ mpi_errno = MPIC_Send (((char *)tmp_buf + nbytes*mask), send_subtree_cnt, MPI_BYTE, dst, MPIR_SCATTER_TAG, comm); } /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ curr_cnt -= send_subtree_cnt; } mask >>= 1; } if ((rank == root) && (root == 0) && (recvbuf != MPI_IN_PLACE)) { /* for root=0, put root's data in recvbuf if not MPI_IN_PLACE */ mpi_errno = MPIR_Localcopy ( sendbuf, sendcnt, sendtype, recvbuf, recvcnt, recvtype ); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ } else if (!(relative_rank % 2) && (recvbuf != MPI_IN_PLACE)) { /* for non-zero root and non-leaf nodes, copy from tmp_buf into recvbuf */ mpi_errno = MPIR_Localcopy ( tmp_buf, nbytes, MPI_BYTE, recvbuf, recvcnt, recvtype); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ } if (tmp_buf != NULL) MPIU_Free(tmp_buf); } #ifdef MPID_HAS_HETERO else { /* communicator is heterogeneous */ if (rank == root) { NMPI_Pack_size(sendcnt*comm_size, sendtype, comm, &tmp_buf_size); tmp_buf = 
MPIU_Malloc(tmp_buf_size); /* --BEGIN ERROR HANDLING-- */ if (!tmp_buf) { mpi_errno = MPIR_Err_create_code( MPI_SUCCESS, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* --END ERROR HANDLING-- */ /* calculate the value of nbytes, the number of bytes in packed representation that each process receives. We can't accurately calculate that from tmp_buf_size because MPI_Pack_size returns an upper bound on the amount of memory required. (For example, for a single integer, MPICH-1 returns pack_size=12.) Therefore, we actually pack some data into tmp_buf and see by how much 'position' is incremented. */ position = 0; NMPI_Pack(sendbuf, 1, sendtype, tmp_buf, tmp_buf_size, &position, comm); nbytes = position*sendcnt; curr_cnt = nbytes*comm_size; if (root == 0) { if (recvbuf != MPI_IN_PLACE) { position = 0; NMPI_Pack(sendbuf, sendcnt*comm_size, sendtype, tmp_buf, tmp_buf_size, &position, comm); } else { position = nbytes; NMPI_Pack(((char *) sendbuf + extent*sendcnt), sendcnt*(comm_size-1), sendtype, tmp_buf, tmp_buf_size, &position, comm); } } else { if (recvbuf != MPI_IN_PLACE) { position = 0; NMPI_Pack(((char *) sendbuf + extent*sendcnt*rank), sendcnt*(comm_size-rank), sendtype, tmp_buf, tmp_buf_size, &position, comm); } else { position = nbytes; NMPI_Pack(((char *) sendbuf + extent*sendcnt*(rank+1)), sendcnt*(comm_size-rank-1), sendtype, tmp_buf, tmp_buf_size, &position, comm); } NMPI_Pack(sendbuf, sendcnt*rank, sendtype, tmp_buf, tmp_buf_size, &position, comm); } } else { NMPI_Pack_size(recvcnt*(comm_size/2), recvtype, comm, &tmp_buf_size); tmp_buf = MPIU_Malloc(tmp_buf_size); /* --BEGIN ERROR HANDLING-- */ if (!tmp_buf) { mpi_errno = MPIR_Err_create_code( MPI_SUCCESS, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* --END ERROR HANDLING-- */ /* calculate nbytes */ position = 0; NMPI_Pack(recvbuf, 1, recvtype, tmp_buf, tmp_buf_size, &position, comm); nbytes = position*recvcnt; 
curr_cnt = 0; } mask = 0x1; while (mask < comm_size) { if (relative_rank & mask) { src = rank - mask; if (src < 0) src += comm_size; mpi_errno = MPIC_Recv(tmp_buf, mask*nbytes, MPI_BYTE, src, MPIR_SCATTER_TAG, comm, &status); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ /* the recv size is larger than what may be sent in some cases. query amount of data actually received */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -