📄 scatter.c
                src = rank - mask;
                if (src < 0) src += comm_size;

                mpi_errno = MPIC_Recv(tmp_buf, mask*nbytes, MPI_BYTE, src,
                                      MPIR_SCATTER_TAG, comm, &status);
                if (mpi_errno) return mpi_errno;

                /* the recv size is larger than what may be sent in some
                   cases. query amount of data actually received */
                NMPI_Get_count(&status, MPI_BYTE, &curr_cnt);
                break;
            }
            mask <<= 1;
        }

        /* This process is responsible for all processes that have bits
           set from the LSB up to (but not including) mask. Because of
           the "not including", we start by shifting mask back down one. */

        mask >>= 1;
        while (mask > 0) {
            if (relative_rank + mask < comm_size) {
                dst = rank + mask;
                if (dst >= comm_size) dst -= comm_size;

                /* mask is also the size of this process's subtree */
                send_subtree_cnt = curr_cnt - nbytes * mask;
                mpi_errno = MPIC_Send(((char *)tmp_buf + nbytes*mask),
                                      send_subtree_cnt, MPI_BYTE, dst,
                                      MPIR_SCATTER_TAG, comm);
                if (mpi_errno) return mpi_errno;
                curr_cnt -= send_subtree_cnt;
            }
            mask >>= 1;
        }

        /* copy local data into recvbuf */
        position = 0;
        if (recvbuf != MPI_IN_PLACE)
            NMPI_Unpack(tmp_buf, tmp_buf_size, &position, recvbuf,
                        recvcnt, recvtype, comm);
        MPIU_Free(tmp_buf);
    }

    /* Unlock for collective operation */
    MPID_Comm_thread_unlock( comm_ptr );

    return (mpi_errno);
}
/* end:nested */

/* begin:nested */
PMPI_LOCAL int MPIR_Scatter_inter ( void *sendbuf, int sendcnt,
                                    MPI_Datatype sendtype, void *recvbuf,
                                    int recvcnt, MPI_Datatype recvtype,
                                    int root, MPID_Comm *comm_ptr )
{
/*  Intercommunicator scatter.
    For short messages, root sends to rank 0 in remote group. rank 0
    does local intracommunicator scatter (MST).
    Cost: (lgp+1).alpha + n.((p-1)/p).beta + n.beta

    For long messages, we use linear scatter to avoid the extra n.beta.
    Cost: p.alpha + n.beta
*/
    int rank, local_size, remote_size, mpi_errno = MPI_SUCCESS;
    int i, nbytes, sendtype_size, recvtype_size;
    MPI_Status status;
    MPI_Aint extent, true_extent, true_lb;
    void *tmp_buf = NULL;
    MPID_Comm *newcomm_ptr = NULL;
    MPI_Comm comm;

    if (root == MPI_PROC_NULL) {
        /* local processes other than root do nothing */
        return MPI_SUCCESS;
    }

    comm = comm_ptr->handle;
    remote_size = comm_ptr->remote_size;
    local_size  = comm_ptr->local_size;

    if (root == MPI_ROOT) {
        MPID_Datatype_get_size_macro(sendtype, sendtype_size);
        nbytes = sendtype_size * sendcnt * remote_size;
    }
    else {
        /* remote side */
        MPID_Datatype_get_size_macro(recvtype, recvtype_size);
        nbytes = recvtype_size * recvcnt * local_size;
    }

    if (nbytes < MPIR_SCATTER_SHORT_MSG) {
        if (root == MPI_ROOT) {
            /* root sends all data to rank 0 on remote group and returns */
            MPID_Comm_thread_lock( comm_ptr );
            mpi_errno = MPIC_Send(sendbuf, sendcnt*remote_size, sendtype, 0,
                                  MPIR_SCATTER_TAG, comm);
            MPID_Comm_thread_unlock( comm_ptr );
            return mpi_errno;
        }
        else {
            /* remote group. rank 0 receives data from root. need to
               allocate temporary buffer to store this data. */
            rank = comm_ptr->rank;

            if (rank == 0) {
                mpi_errno = NMPI_Type_get_true_extent(recvtype, &true_lb,
                                                      &true_extent);
                if (mpi_errno) return mpi_errno;

                tmp_buf = MPIU_Malloc(true_extent*recvcnt*local_size);
                if (!tmp_buf) {
                    mpi_errno = MPIR_Err_create_code( MPI_ERR_OTHER,
                                                      "**nomem", 0 );
                    return mpi_errno;
                }
                /* adjust for potential negative lower bound in datatype */
                tmp_buf = (void *)((char*)tmp_buf - true_lb);

                MPID_Comm_thread_lock( comm_ptr );
                mpi_errno = MPIC_Recv(tmp_buf, recvcnt*local_size, recvtype,
                                      root, MPIR_SCATTER_TAG, comm, &status);
                MPID_Comm_thread_unlock( comm_ptr );
                if (mpi_errno) return mpi_errno;
            }

            /* Get the local intracommunicator */
            if (!comm_ptr->local_comm)
                MPIR_Setup_intercomm_localcomm( comm_ptr );

            newcomm_ptr = comm_ptr->local_comm;

            /* now do the usual scatter on this intracommunicator */
            mpi_errno = MPIR_Scatter(tmp_buf, recvcnt, recvtype,
                                     recvbuf, recvcnt, recvtype,
                                     0, newcomm_ptr);
            if (rank == 0)
                MPIU_Free(((char*)tmp_buf+true_lb));
        }
    }
    else {
        /* long message. use linear algorithm. */
        MPID_Comm_thread_lock( comm_ptr );
        if (root == MPI_ROOT) {
            MPID_Datatype_get_extent_macro(sendtype, extent);
            for (i=0; i<remote_size; i++) {
                mpi_errno = MPIC_Send(((char *)sendbuf+sendcnt*i*extent),
                                      sendcnt, sendtype, i,
                                      MPIR_SCATTER_TAG, comm);
                if (mpi_errno) return mpi_errno;
            }
        }
        else {
            mpi_errno = MPIC_Recv(recvbuf, recvcnt, recvtype, root,
                                  MPIR_SCATTER_TAG, comm, &status);
        }
        MPID_Comm_thread_unlock( comm_ptr );
    }

    return mpi_errno;
}
/* end:nested */
#endif

#undef FUNCNAME
#define FUNCNAME MPI_Scatter

/*@
   MPI_Scatter - scatter

   Arguments:
+  void *sendbuf - send buffer
.  int sendcnt - send count
.  MPI_Datatype sendtype - send type
.  void *recvbuf - receive buffer
.  int recvcnt - receive count
.  MPI_Datatype recvtype - receive datatype
.  int root - root
-  MPI_Comm comm - communicator

   Notes:

.N Fortran

.N Errors
.N MPI_SUCCESS
@*/
int MPI_Scatter(void *sendbuf, int sendcnt, MPI_Datatype sendtype,
                void *recvbuf, int recvcnt, MPI_Datatype recvtype,
                int root, MPI_Comm comm)
{
    static const char FCNAME[] = "MPI_Scatter";
    int mpi_errno = MPI_SUCCESS;
    MPID_Comm *comm_ptr = NULL;
    MPID_MPI_STATE_DECL(MPID_STATE_MPI_SCATTER);

    MPID_MPI_COLL_FUNC_ENTER(MPID_STATE_MPI_SCATTER);

    /* Verify that MPI has been initialized */
#   ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS;
        {
            MPIR_ERRTEST_INITIALIZED(mpi_errno);
            MPIR_ERRTEST_COMM(comm, mpi_errno);
            if (mpi_errno != MPI_SUCCESS) {
                return MPIR_Err_return_comm( 0, FCNAME, mpi_errno );
            }
        }
        MPID_END_ERROR_CHECKS;
    }
#   endif /* HAVE_ERROR_CHECKING */

    /* Get handles to MPI objects. */
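    /* When HAVE_ERROR_CHECKING is compiled in, the block below also
       validates the communicator pointer, the recv count/datatype on all
       processes, the send count/datatype on the root, and the root rank
       itself before the collective is run. */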
    MPID_Comm_get_ptr( comm, comm_ptr );

#   ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS;
        {
            MPID_Datatype *sendtype_ptr = NULL, *recvtype_ptr = NULL;
            int rank;

            MPID_Comm_valid_ptr( comm_ptr, mpi_errno );
            if (mpi_errno != MPI_SUCCESS) {
                MPID_MPI_COLL_FUNC_EXIT(MPID_STATE_MPI_SCATTER);
                return MPIR_Err_return_comm( NULL, FCNAME, mpi_errno );
            }

            rank = comm_ptr->rank;
            if (rank == root) {
                MPIR_ERRTEST_COUNT(sendcnt, mpi_errno);
                MPIR_ERRTEST_DATATYPE(sendcnt, sendtype, mpi_errno);
                if (HANDLE_GET_KIND(sendtype) != HANDLE_KIND_BUILTIN) {
                    MPID_Datatype_get_ptr(sendtype, sendtype_ptr);
                    MPID_Datatype_valid_ptr( sendtype_ptr, mpi_errno );
                }
            }
            MPIR_ERRTEST_COUNT(recvcnt, mpi_errno);
            MPIR_ERRTEST_DATATYPE(recvcnt, recvtype, mpi_errno);
            MPIR_ERRTEST_INTRA_ROOT(comm_ptr, root, mpi_errno);
            if (HANDLE_GET_KIND(recvtype) != HANDLE_KIND_BUILTIN) {
                MPID_Datatype_get_ptr(recvtype, recvtype_ptr);
                MPID_Datatype_valid_ptr( recvtype_ptr, mpi_errno );
            }
            if (mpi_errno != MPI_SUCCESS) {
                MPID_MPI_COLL_FUNC_EXIT(MPID_STATE_MPI_SCATTER);
                return MPIR_Err_return_comm( comm_ptr, FCNAME, mpi_errno );
            }
        }
        MPID_END_ERROR_CHECKS;
    }
#   endif /* HAVE_ERROR_CHECKING */

    /* ... body of routine ... */
    if (comm_ptr->coll_fns != NULL && comm_ptr->coll_fns->Scatter != NULL) {
        mpi_errno = comm_ptr->coll_fns->Scatter(sendbuf, sendcnt, sendtype,
                                                recvbuf, recvcnt, recvtype,
                                                root, comm_ptr);
    }
    else {
        MPIR_Nest_incr();
        if (comm_ptr->comm_kind == MPID_INTRACOMM)
            /* intracommunicator */
            mpi_errno = MPIR_Scatter(sendbuf, sendcnt, sendtype,
                                     recvbuf, recvcnt, recvtype,
                                     root, comm_ptr);
        else {
            /* intercommunicator */
            mpi_errno = MPIR_Err_create_code( MPI_ERR_COMM,
                                              "**intercommcoll",
                                              "**intercommcoll %s", FCNAME );
            /* mpi_errno = MPIR_Scatter_inter(sendbuf, sendcnt, sendtype,
               recvbuf, recvcnt, recvtype, root, comm_ptr); */
        }
        MPIR_Nest_decr();
    }

    if (mpi_errno == MPI_SUCCESS) {
        MPID_MPI_COLL_FUNC_EXIT(MPID_STATE_MPI_SCATTER);
        return MPI_SUCCESS;
    }
    else {
        MPID_MPI_COLL_FUNC_EXIT(MPID_STATE_MPI_SCATTER);
        return MPIR_Err_return_comm( comm_ptr, FCNAME, mpi_errno );
    }
    /* ... end of body of routine ... */
}
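The receive and forward loops at the top of this listing are the tail of the binomial-tree (MST) scatter: a non-root process receives its entire subtree's data from the rank that differs from it in its lowest set bit, then repeatedly halves its mask and forwards the upper portion of what it holds. The standalone sketch below is illustrative only and is not part of scatter.c; the file name, the print_tree helper, and the 8-process, root-0 inputs are made up for demonstration. It walks the same relative_rank/mask arithmetic as MPIR_Scatter and prints, for every rank, the parent it would receive from and the subtree sizes it would forward.

/* binomial_tree_demo.c -- illustrative sketch, not part of scatter.c.
 * Mirrors the relative_rank/mask walk used by MPIR_Scatter. */
#include <stdio.h>

static void print_tree(int comm_size, int root)
{
    for (int rank = 0; rank < comm_size; rank++) {
        int relative_rank = (rank >= root) ? rank - root
                                           : rank - root + comm_size;
        int mask = 1;

        /* Receive phase: the lowest set bit of relative_rank picks the parent. */
        while (mask < comm_size) {
            if (relative_rank & mask) {
                int src = rank - mask;
                if (src < 0) src += comm_size;
                printf("rank %d receives its subtree from rank %d\n", rank, src);
                break;
            }
            mask <<= 1;
        }

        /* Send phase: this rank owns all ranks with bits set below mask,
         * so it peels off and forwards one half-sized subtree per step. */
        mask >>= 1;
        while (mask > 0) {
            if (relative_rank + mask < comm_size) {
                int dst = rank + mask;
                if (dst >= comm_size) dst -= comm_size;
                printf("rank %d forwards a %d-process subtree to rank %d\n",
                       rank, mask, dst);
            }
            mask >>= 1;
        }
    }
}

int main(void)
{
    print_tree(8, 0);   /* hypothetical inputs: 8 processes, root 0 */
    return 0;
}

With these inputs the root makes lg p = 3 sends (to ranks 4, 2, and 1), which is where the lg(p)*alpha term behind the cost comments comes from, and each forwarded message carries only the receiver's subtree. In the MPIR_Scatter_inter comment, the extra "+1" alpha and the extra n.beta account for the additional root-to-remote-rank-0 transfer that precedes the local intracommunicator scatter.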