📄 red_scat.c
字号:
if (dst_tree_root + mask > comm_size) { nprocs_completed = comm_size - my_tree_root - mask; /* nprocs_completed is the number of processes in this subtree that have all the data. Send data to others in a tree fashion. First find root of current tree that is being divided into two. k is the number of least-significant bits in this process's rank that must be zeroed out to find the rank of the root */ j = mask; k = 0; while (j) { j >>= 1; k++; } k--; tmp_mask = mask >> 1; while (tmp_mask) { dst = rank ^ tmp_mask; tree_root = rank >> k; tree_root <<= k; /* send only if this proc has data and destination doesn't have data. at any step, multiple processes can send if they have the data */ if ((dst > rank) && (rank < tree_root + nprocs_completed) && (dst >= tree_root + nprocs_completed)) { /* send the current result */ mpi_errno = MPIC_Send(tmp_recvbuf, 1, recvtype, dst, MPIR_REDUCE_SCATTER_TAG, comm); if (mpi_errno) return mpi_errno; } /* recv only if this proc. doesn't have data and sender has data */ else if ((dst < rank) && (dst < tree_root + nprocs_completed) && (rank >= tree_root + nprocs_completed)) { mpi_errno = MPIC_Recv(tmp_recvbuf, 1, recvtype, dst, MPIR_REDUCE_SCATTER_TAG, comm, &status); if (mpi_errno) return mpi_errno; if (is_commutative || (dst_tree_root < my_tree_root)) { (*uop)(tmp_recvbuf, tmp_results, &blklens[0], &datatype); (*uop)(((char *)tmp_recvbuf + dis[1]*extent), ((char *)tmp_results + dis[1]*extent), &blklens[1], &datatype); } else { (*uop)(tmp_results, tmp_recvbuf, &blklens[0], &datatype); (*uop)(((char *)tmp_results + dis[1]*extent), ((char *)tmp_recvbuf + dis[1]*extent), &blklens[1], &datatype); /* copy result back into tmp_results */ mpi_errno = MPIC_Sendrecv(tmp_recvbuf, 1, recvtype, rank, MPIR_REDUCE_SCATTER_TAG, tmp_results, 1, recvtype, rank, MPIR_REDUCE_SCATTER_TAG, comm, &status); if (mpi_errno) return mpi_errno; } } tmp_mask >>= 1; k--; } } NMPI_Type_free(&sendtype); NMPI_Type_free(&recvtype); mask <<= 1; i++; } /* now copy final results from tmp_results to recvbuf */ mpi_errno = MPIR_Localcopy(((char *)tmp_results+displs[rank]*extent), recvcnts[rank], datatype, recvbuf, recvcnts[rank], datatype); if (mpi_errno) return mpi_errno; MPIU_Free((char *)tmp_recvbuf+true_lb); MPIU_Free((char *)tmp_results+true_lb); } MPIU_Free(displs); /* Unlock for collective operation */ MPID_Comm_thread_unlock( comm_ptr ); if (p->op_errno) mpi_errno = p->op_errno; return (mpi_errno);}/* end:nested *//* begin:nested */PMPI_LOCAL int MPIR_Reduce_scatter_inter ( void *sendbuf, void *recvbuf, int *recvcnts, MPI_Datatype datatype, MPI_Op op, MPID_Comm *comm_ptr ){/* Intercommunicator Reduce_scatter. We first do an intercommunicator reduce to rank 0 on left group, then an intercommunicator reduce to rank 0 on right group, followed by local intracommunicator scattervs in each group.*/ int rank, mpi_errno, root, local_size, total_count, i; MPI_Aint true_extent, true_lb; void *tmp_buf=NULL; int *displs=NULL; MPID_Comm *newcomm_ptr = NULL; rank = comm_ptr->rank; local_size = comm_ptr->local_size; total_count = 0; for (i=0; i<local_size; i++) total_count += recvcnts[i]; if (rank == 0) { /* In each group, rank 0 allocates a temp. buffer for the reduce */ displs = MPIU_Malloc(local_size*sizeof(int)); if (!displs) { mpi_errno = MPIR_Err_create_code( MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } total_count = 0; for (i=0; i<local_size; i++) { displs[i] = total_count; total_count += recvcnts[i]; } mpi_errno = NMPI_Type_get_true_extent(datatype, &true_lb, &true_extent); if (mpi_errno) return mpi_errno; tmp_buf = MPIU_Malloc(true_extent*total_count); if (!tmp_buf) { mpi_errno = MPIR_Err_create_code( MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* adjust for potential negative lower bound in datatype */ tmp_buf = (void *)((char*)tmp_buf - true_lb); } /* first do a reduce from right group to rank 0 in left group, then from left group to rank 0 in right group*/ if (comm_ptr->is_low_group) { /* reduce from right group to rank 0*/ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, total_count, datatype, op, root, comm_ptr); if (mpi_errno) return mpi_errno; /* reduce to rank 0 of right group */ root = 0; mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, total_count, datatype, op, root, comm_ptr); if (mpi_errno) return mpi_errno; } else { /* reduce to rank 0 of left group */ root = 0; mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, total_count, datatype, op, root, comm_ptr); if (mpi_errno) return mpi_errno; /* reduce from right group to rank 0 */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, total_count, datatype, op, root, comm_ptr); if (mpi_errno) return mpi_errno; } /* Get the local intracommunicator */ if (!comm_ptr->local_comm) MPIR_Setup_intercomm_localcomm( comm_ptr ); newcomm_ptr = comm_ptr->local_comm; mpi_errno = MPIR_Scatterv(tmp_buf, recvcnts, displs, datatype, recvbuf, recvcnts[rank], datatype, 0, newcomm_ptr); if (mpi_errno) return mpi_errno; if (rank == 0) { MPIU_Free(displs); MPIU_Free((char*)tmp_buf+true_lb); } return mpi_errno;}/* end:nested */#endif#undef FUNCNAME#define FUNCNAME MPI_Reduce_scatter/*@ MPI_Reduce_scatter - reduce scatter Arguments:+ void *sendbuf - send buffer. void *recvbuf - receive buffer. int *recvcnts - receive counts. MPI_Datatype datatype - datatype. MPI_Op op - operation- MPI_Comm comm - communicator Notes:.N Fortran.N Errors.N MPI_SUCCESS@*/int MPI_Reduce_scatter(void *sendbuf, void *recvbuf, int *recvcnts, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm){ static const char FCNAME[] = "MPI_Reduce_scatter"; int mpi_errno = MPI_SUCCESS; MPID_Comm *comm_ptr = NULL; MPID_MPI_STATE_DECL(MPID_STATE_MPI_REDUCE_SCATTER); MPID_MPI_COLL_FUNC_ENTER(MPID_STATE_MPI_REDUCE_SCATTER); /* Verify that MPI has been initialized */# ifdef HAVE_ERROR_CHECKING { MPID_BEGIN_ERROR_CHECKS; { MPIR_ERRTEST_INITIALIZED(mpi_errno); MPIR_ERRTEST_COMM(comm, mpi_errno); if (mpi_errno != MPI_SUCCESS) { return MPIR_Err_return_comm( 0, FCNAME, mpi_errno ); } } MPID_END_ERROR_CHECKS; }# endif /* HAVE_ERROR_CHECKING */ /* Get handles to MPI objects. */ MPID_Comm_get_ptr( comm, comm_ptr );# ifdef HAVE_ERROR_CHECKING { MPID_BEGIN_ERROR_CHECKS; { MPID_Datatype *datatype_ptr = NULL; MPID_Op *op_ptr = NULL; int rank, i, size; MPID_Comm_valid_ptr( comm_ptr, mpi_errno ); if (mpi_errno != MPI_SUCCESS) { MPID_MPI_COLL_FUNC_EXIT(MPID_STATE_MPI_REDUCE_SCATTER); return MPIR_Err_return_comm( comm_ptr, FCNAME, mpi_errno ); } /* FIXME: Intracomm collective (MPI-1) only */ size = comm_ptr->local_size; for (i=0; i<size; i++) { MPIR_ERRTEST_COUNT(recvcnts[i],mpi_errno); } rank = comm_ptr->rank; MPIR_ERRTEST_DATATYPE(recvcnts[rank], datatype, mpi_errno); MPIR_ERRTEST_OP(op, mpi_errno); if (HANDLE_GET_KIND(datatype) != HANDLE_KIND_BUILTIN) { MPID_Datatype_get_ptr(datatype, datatype_ptr); MPID_Datatype_valid_ptr( datatype_ptr, mpi_errno ); } if (HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) { MPID_Op_get_ptr(op, op_ptr); MPID_Op_valid_ptr( op_ptr, mpi_errno ); } if (mpi_errno != MPI_SUCCESS) { MPID_MPI_COLL_FUNC_EXIT(MPID_STATE_MPI_REDUCE_SCATTER); return MPIR_Err_return_comm( comm_ptr, FCNAME, mpi_errno ); } } MPID_END_ERROR_CHECKS; }# endif /* HAVE_ERROR_CHECKING */ /* ... body of routine ... */ if (comm_ptr->coll_fns != NULL && comm_ptr->coll_fns->Reduce_scatter != NULL) { mpi_errno = comm_ptr->coll_fns->Reduce_scatter(sendbuf, recvbuf, recvcnts, datatype, op, comm_ptr); } else { MPIR_Nest_incr(); if (comm_ptr->comm_kind == MPID_INTRACOMM) /* intracommunicator */ mpi_errno = MPIR_Reduce_scatter(sendbuf, recvbuf, recvcnts, datatype, op, comm_ptr); else { /* intercommunicator */ mpi_errno = MPIR_Err_create_code( MPI_ERR_COMM, "**intercommcoll", "**intercommcoll %s", FCNAME ); /*mpi_errno = MPIR_Reduce_scatter_inter(sendbuf, recvbuf, recvcnts, datatype, op, comm_ptr); */ } MPIR_Nest_decr(); } if (mpi_errno == MPI_SUCCESS) { MPID_MPI_COLL_FUNC_EXIT(MPID_STATE_MPI_REDUCE_SCATTER); return MPI_SUCCESS; } else { MPID_MPI_COLL_FUNC_EXIT(MPID_STATE_MPI_REDUCE_SCATTER); return MPIR_Err_return_comm( comm_ptr, FCNAME, mpi_errno ); } /* ... end of body of routine ... */}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -