📄 scatter.c
字号:
NMPI_Get_count(&status, MPI_BYTE, &curr_cnt); break; } mask <<= 1; } /* This process is responsible for all processes that have bits set from the LSB upto (but not including) mask. Because of the "not including", we start by shifting mask back down one. */ mask >>= 1; while (mask > 0) { if (relative_rank + mask < comm_size) { dst = rank + mask; if (dst >= comm_size) dst -= comm_size; send_subtree_cnt = curr_cnt - nbytes * mask; /* mask is also the size of this process's subtree */ mpi_errno = MPIC_Send (((char *)tmp_buf + nbytes*mask), send_subtree_cnt, MPI_BYTE, dst, MPIR_SCATTER_TAG, comm); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ curr_cnt -= send_subtree_cnt; } mask >>= 1; } /* copy local data into recvbuf */ position = 0; if (recvbuf != MPI_IN_PLACE) NMPI_Unpack(tmp_buf, tmp_buf_size, &position, recvbuf, recvcnt, recvtype, comm); MPIU_Free(tmp_buf); }#endif /* MPID_HAS_HETERO */ /* check if multiple threads are calling this collective function */ MPIDU_ERR_CHECK_MULTIPLE_THREADS_EXIT( comm_ptr ); return (mpi_errno);}/* end:nested *//* begin:nested *//* not declared static because a machine-specific function may call this one in some cases */int MPIR_Scatter_inter ( void *sendbuf, int sendcnt, MPI_Datatype sendtype, void *recvbuf, int recvcnt, MPI_Datatype recvtype, int root, MPID_Comm *comm_ptr ){/* Intercommunicator scatter. For short messages, root sends to rank 0 in remote group. rank 0 does local intracommunicator scatter (binomial tree). Cost: (lgp+1).alpha + n.((p-1)/p).beta + n.beta For long messages, we use linear scatter to avoid the extra n.beta. 
Cost: p.alpha + n.beta*/ static const char FCNAME[] = "MPIR_Scatter_inter"; int rank, local_size, remote_size, mpi_errno=MPI_SUCCESS; int i, nbytes, sendtype_size, recvtype_size; MPI_Status status; MPI_Aint extent, true_extent, true_lb = 0; void *tmp_buf=NULL; MPID_Comm *newcomm_ptr = NULL; MPI_Comm comm; if (root == MPI_PROC_NULL) { /* local processes other than root do nothing */ return MPI_SUCCESS; } comm = comm_ptr->handle; remote_size = comm_ptr->remote_size; local_size = comm_ptr->local_size; if (root == MPI_ROOT) { MPID_Datatype_get_size_macro(sendtype, sendtype_size); nbytes = sendtype_size * sendcnt * remote_size; } else { /* remote side */ MPID_Datatype_get_size_macro(recvtype, recvtype_size); nbytes = recvtype_size * recvcnt * local_size; } if (nbytes < MPIR_SCATTER_SHORT_MSG) { if (root == MPI_ROOT) { /* root sends all data to rank 0 on remote group and returns */ MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr ); mpi_errno = MPIC_Send(sendbuf, sendcnt*remote_size, sendtype, 0, MPIR_SCATTER_TAG, comm); MPIDU_ERR_CHECK_MULTIPLE_THREADS_EXIT( comm_ptr ); return mpi_errno; } else { /* remote group. rank 0 receives data from root. need to allocate temporary buffer to store this data. 
*/ rank = comm_ptr->rank; if (rank == 0) { mpi_errno = NMPI_Type_get_true_extent(recvtype, &true_lb, &true_extent); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ MPID_Datatype_get_extent_macro(recvtype, extent); tmp_buf = MPIU_Malloc(recvcnt*local_size*(MPIR_MAX(extent,true_extent))); /* --BEGIN ERROR HANDLING-- */ if (!tmp_buf) { mpi_errno = MPIR_Err_create_code( MPI_SUCCESS, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* --END ERROR HANDLING-- */ /* adjust for potential negative lower bound in datatype */ tmp_buf = (void *)((char*)tmp_buf - true_lb); MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr ); mpi_errno = MPIC_Recv(tmp_buf, recvcnt*local_size, recvtype, root, MPIR_SCATTER_TAG, comm, &status); MPIDU_ERR_CHECK_MULTIPLE_THREADS_EXIT( comm_ptr ); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ } /* Get the local intracommunicator */ if (!comm_ptr->local_comm) MPIR_Setup_intercomm_localcomm( comm_ptr ); newcomm_ptr = comm_ptr->local_comm; /* now do the usual scatter on this intracommunicator */ mpi_errno = MPIR_Scatter(tmp_buf, recvcnt, recvtype, recvbuf, recvcnt, recvtype, 0, newcomm_ptr); if (rank == 0) MPIU_Free(((char*)tmp_buf+true_lb)); } } else { /* long message. use linear algorithm. 
*/ MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr ); if (root == MPI_ROOT) { MPID_Datatype_get_extent_macro(sendtype, extent); for (i=0; i<remote_size; i++) { mpi_errno = MPIC_Send(((char *)sendbuf+sendcnt*i*extent), sendcnt, sendtype, i, MPIR_SCATTER_TAG, comm); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ } } else { mpi_errno = MPIC_Recv(recvbuf,recvcnt,recvtype,root, MPIR_SCATTER_TAG,comm,&status); } MPIDU_ERR_CHECK_MULTIPLE_THREADS_EXIT( comm_ptr ); } return mpi_errno;}/* end:nested */#endif#undef FUNCNAME#define FUNCNAME MPI_Scatter/*@MPI_Scatter - Sends data from one process to all other processes in a communicatorInput Parameters:+ sendbuf - address of send buffer (choice, significant only at 'root') . sendcount - number of elements sent to each process (integer, significant only at 'root') . sendtype - data type of send buffer elements (significant only at 'root') (handle) . recvcount - number of elements in receive buffer (integer) . recvtype - data type of receive buffer elements (handle) . root - rank of sending process (integer) - comm - communicator (handle) Output Parameter:. 
recvbuf - address of receive buffer (choice)

.N ThreadSafe

.N Fortran

.N Errors
.N MPI_SUCCESS
.N MPI_ERR_COMM
.N MPI_ERR_COUNT
.N MPI_ERR_TYPE
.N MPI_ERR_BUFFER
@*/
int MPI_Scatter(void *sendbuf, int sendcnt, MPI_Datatype sendtype,
                void *recvbuf, int recvcnt, MPI_Datatype recvtype, int root,
                MPI_Comm comm)
{
    static const char FCNAME[] = "MPI_Scatter";
    int mpi_errno = MPI_SUCCESS;
    MPID_Comm *comm_ptr = NULL;
    MPID_MPI_STATE_DECL(MPID_STATE_MPI_SCATTER);

    MPIR_ERRTEST_INITIALIZED_ORDIE();

    MPID_CS_ENTER();
    MPID_MPI_COLL_FUNC_ENTER(MPID_STATE_MPI_SCATTER);

    /* Validate parameters, especially handles needing to be converted */
#   ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS;
        {
            MPIR_ERRTEST_COMM(comm, mpi_errno);
            if (mpi_errno != MPI_SUCCESS) goto fn_fail;
        }
        MPID_END_ERROR_CHECKS;
    }
#   endif /* HAVE_ERROR_CHECKING */

    /* Convert MPI object handles to object pointers */
    MPID_Comm_get_ptr( comm, comm_ptr );

    /* Validate parameters and objects (post conversion) */
#   ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS;
        {
            MPID_Datatype *sendtype_ptr=NULL, *recvtype_ptr=NULL;
            int rank;

            MPID_Comm_valid_ptr( comm_ptr, mpi_errno );
            if (mpi_errno != MPI_SUCCESS) goto fn_fail;

            if (comm_ptr->comm_kind == MPID_INTRACOMM) {
                MPIR_ERRTEST_INTRA_ROOT(comm_ptr, root, mpi_errno);

                rank = comm_ptr->rank;
                if (rank == root) {
                    /* send-side arguments matter only at the root */
                    MPIR_ERRTEST_COUNT(sendcnt, mpi_errno);
                    MPIR_ERRTEST_DATATYPE(sendtype, "sendtype", mpi_errno);
                    if (HANDLE_GET_KIND(sendtype) != HANDLE_KIND_BUILTIN) {
                        MPID_Datatype_get_ptr(sendtype, sendtype_ptr);
                        MPID_Datatype_valid_ptr( sendtype_ptr, mpi_errno );
                        /* stop before the commit check inspects a datatype
                           pointer that just failed validation */
                        if (mpi_errno != MPI_SUCCESS) goto fn_fail;
                        MPID_Datatype_committed_ptr( sendtype_ptr, mpi_errno );
                    }
                    MPIR_ERRTEST_USERBUFFER(sendbuf,sendcnt,sendtype,mpi_errno);
                    MPIR_ERRTEST_SENDBUF_INPLACE(sendbuf, sendcnt, mpi_errno);
                }
                else {
                    MPIR_ERRTEST_RECVBUF_INPLACE(recvbuf, recvcnt, mpi_errno);
                }

                if (recvbuf != MPI_IN_PLACE) {
                    MPIR_ERRTEST_COUNT(recvcnt, mpi_errno);
                    MPIR_ERRTEST_DATATYPE(recvtype, "recvtype", mpi_errno);
                    if (HANDLE_GET_KIND(recvtype) != HANDLE_KIND_BUILTIN) {
                        MPID_Datatype_get_ptr(recvtype, recvtype_ptr);
                        MPID_Datatype_valid_ptr( recvtype_ptr, mpi_errno );
                        if (mpi_errno != MPI_SUCCESS) goto fn_fail;
                        MPID_Datatype_committed_ptr( recvtype_ptr, mpi_errno );
                    }
                    MPIR_ERRTEST_USERBUFFER(recvbuf,recvcnt,recvtype,mpi_errno);
                }
            }

            if (comm_ptr->comm_kind == MPID_INTERCOMM) {
                MPIR_ERRTEST_INTER_ROOT(comm_ptr, root, mpi_errno);

                if (root == MPI_ROOT) {
                    /* this process is the root: only send args are used */
                    MPIR_ERRTEST_COUNT(sendcnt, mpi_errno);
                    MPIR_ERRTEST_DATATYPE(sendtype, "sendtype", mpi_errno);
                    if (HANDLE_GET_KIND(sendtype) != HANDLE_KIND_BUILTIN) {
                        MPID_Datatype_get_ptr(sendtype, sendtype_ptr);
                        MPID_Datatype_valid_ptr( sendtype_ptr, mpi_errno );
                        if (mpi_errno != MPI_SUCCESS) goto fn_fail;
                        MPID_Datatype_committed_ptr( sendtype_ptr, mpi_errno );
                    }
                    MPIR_ERRTEST_SENDBUF_INPLACE(sendbuf, sendcnt, mpi_errno);
                    MPIR_ERRTEST_USERBUFFER(sendbuf,sendcnt,sendtype,mpi_errno);
                }
                else if (root != MPI_PROC_NULL) {
                    /* receiving group: only recv args are used */
                    MPIR_ERRTEST_COUNT(recvcnt, mpi_errno);
                    MPIR_ERRTEST_DATATYPE(recvtype, "recvtype", mpi_errno);
                    if (HANDLE_GET_KIND(recvtype) != HANDLE_KIND_BUILTIN) {
                        MPID_Datatype_get_ptr(recvtype, recvtype_ptr);
                        MPID_Datatype_valid_ptr( recvtype_ptr, mpi_errno );
                        if (mpi_errno != MPI_SUCCESS) goto fn_fail;
                        MPID_Datatype_committed_ptr( recvtype_ptr, mpi_errno );
                    }
                    MPIR_ERRTEST_RECVBUF_INPLACE(recvbuf, recvcnt, mpi_errno);
                    MPIR_ERRTEST_USERBUFFER(recvbuf,recvcnt,recvtype,mpi_errno);
                }
                /* root == MPI_PROC_NULL: no buffer arguments are used */
            }

            if (mpi_errno != MPI_SUCCESS) goto fn_fail;
        }
        MPID_END_ERROR_CHECKS;
    }
#   endif /* HAVE_ERROR_CHECKING */

    /* ... body of routine ...  */

    if (comm_ptr->coll_fns != NULL && comm_ptr->coll_fns->Scatter != NULL)
    {
        /* a Scatter supplied in the communicator's coll_fns table is used
           in place of the generic algorithms below */
        mpi_errno = comm_ptr->coll_fns->Scatter(sendbuf, sendcnt,
                                                sendtype, recvbuf, recvcnt,
                                                recvtype, root, comm_ptr);
    }
    else
    {
        MPIR_Nest_incr();
        if (comm_ptr->comm_kind == MPID_INTRACOMM) {
            /* intracommunicator */
            mpi_errno = MPIR_Scatter(sendbuf, sendcnt, sendtype,
                                     recvbuf, recvcnt, recvtype, root,
                                     comm_ptr);
        }
        else {
            /* intercommunicator */
            mpi_errno = MPIR_Scatter_inter(sendbuf, sendcnt, sendtype,
                                           recvbuf, recvcnt, recvtype, root,
                                           comm_ptr);
        }
        MPIR_Nest_decr();
    }

    if (mpi_errno != MPI_SUCCESS) goto fn_fail;

    /* ... end of body of routine ... */

  fn_exit:
    MPID_MPI_COLL_FUNC_EXIT(MPID_STATE_MPI_SCATTER);
    MPID_CS_EXIT();
    return mpi_errno;

  fn_fail:
    /* --BEGIN ERROR HANDLING-- */
#   ifdef HAVE_ERROR_CHECKING
    {
        mpi_errno = MPIR_Err_create_code(
            mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**mpi_scatter",
            "**mpi_scatter %p %d %D %p %d %D %d %C", sendbuf, sendcnt, sendtype, recvbuf, recvcnt, recvtype, root, comm);
    }
#   endif
    mpi_errno = MPIR_Err_return_comm( comm_ptr, FCNAME, mpi_errno );
    goto fn_exit;
    /* --END ERROR HANDLING-- */
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -