📄 red_scat.c
字号:
*/ static const char FCNAME[] = "MPIR_Reduce_scatter_inter"; int rank, mpi_errno, root, local_size, total_count, i; MPI_Aint true_extent, true_lb = 0, extent; void *tmp_buf=NULL; int *disps=NULL; MPID_Comm *newcomm_ptr = NULL; rank = comm_ptr->rank; local_size = comm_ptr->local_size; total_count = 0; for (i=0; i<local_size; i++) total_count += recvcnts[i]; if (rank == 0) { /* In each group, rank 0 allocates a temp. buffer for the reduce */ disps = MPIU_Malloc(local_size*sizeof(int)); /* --BEGIN ERROR HANDLING-- */ if (!disps) { mpi_errno = MPIR_Err_create_code( MPI_SUCCESS, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* --END ERROR HANDLING-- */ total_count = 0; for (i=0; i<local_size; i++) { disps[i] = total_count; total_count += recvcnts[i]; } mpi_errno = NMPI_Type_get_true_extent(datatype, &true_lb, &true_extent); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ MPID_Datatype_get_extent_macro(datatype, extent); tmp_buf = MPIU_Malloc(total_count*(MPIR_MAX(extent,true_extent))); /* --BEGIN ERROR HANDLING-- */ if (!tmp_buf) { mpi_errno = MPIR_Err_create_code( MPI_SUCCESS, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* --END ERROR HANDLING-- */ /* adjust for potential negative lower bound in datatype */ tmp_buf = (void *)((char*)tmp_buf - true_lb); } /* first do a reduce from right group to rank 0 in left group, then from left group to rank 0 in right group*/ if (comm_ptr->is_low_group) { /* reduce from right group to rank 0*/ root = (rank == 0) ? 
MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Reduce_inter(sendbuf, tmp_buf, total_count, datatype, op, root, comm_ptr); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ /* reduce to rank 0 of right group */ root = 0; mpi_errno = MPIR_Reduce_inter(sendbuf, tmp_buf, total_count, datatype, op, root, comm_ptr); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ } else { /* reduce to rank 0 of left group */ root = 0; mpi_errno = MPIR_Reduce_inter(sendbuf, tmp_buf, total_count, datatype, op, root, comm_ptr); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ /* reduce from right group to rank 0 */ root = (rank == 0) ? 
MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Reduce_inter(sendbuf, tmp_buf, total_count, datatype, op, root, comm_ptr); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ } /* Get the local intracommunicator */ if (!comm_ptr->local_comm) MPIR_Setup_intercomm_localcomm( comm_ptr ); newcomm_ptr = comm_ptr->local_comm; mpi_errno = MPIR_Scatterv(tmp_buf, recvcnts, disps, datatype, recvbuf, recvcnts[rank], datatype, 0, newcomm_ptr); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ if (rank == 0) { MPIU_Free(disps); MPIU_Free((char*)tmp_buf+true_lb); } return mpi_errno;}/* end:nested */#endif#undef FUNCNAME#define FUNCNAME MPI_Reduce_scatter/*@MPI_Reduce_scatter - Combines values and scatters the resultsInput Parameters:+ sendbuf - starting address of send buffer (choice) . recvcounts - integer array specifying the number of elements in result distributed to each process.Array must be identical on all calling processes. . datatype - data type of elements of input buffer (handle) . op - operation (handle) - comm - communicator (handle) Output Parameter:. 
recvbuf - starting address of receive buffer (choice)

.N ThreadSafe

.N Fortran

.N collops

.N Errors
.N MPI_SUCCESS
.N MPI_ERR_COMM
.N MPI_ERR_COUNT
.N MPI_ERR_TYPE
.N MPI_ERR_BUFFER
.N MPI_ERR_OP
.N MPI_ERR_BUFFER_ALIAS
@*/
/*
 * Implementation notes:
 *  - All argument validation is compiled in only under HAVE_ERROR_CHECKING.
 *  - Every validation step that may leave mpi_errno set is followed by an
 *    early 'goto fn_fail' before any later step that consumes the value it
 *    validated (a derived datatype pointer, the recvcnts sum, an op handle).
 *  - The collective itself is dispatched to comm_ptr->coll_fns->Reduce_scatter
 *    when the communicator provides one; otherwise to MPIR_Reduce_scatter
 *    (intracommunicator) or MPIR_Reduce_scatter_inter (intercommunicator),
 *    bracketed by MPIR_Nest_incr/MPIR_Nest_decr.
 *  - Returns MPI_SUCCESS or the code produced by MPIR_Err_return_comm.
 */
int MPI_Reduce_scatter(void *sendbuf, void *recvbuf, int *recvcnts,
                       MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
{
    static const char FCNAME[] = "MPI_Reduce_scatter";
    int mpi_errno = MPI_SUCCESS;
    MPID_Comm *comm_ptr = NULL;
    MPID_MPI_STATE_DECL(MPID_STATE_MPI_REDUCE_SCATTER);

    MPIR_ERRTEST_INITIALIZED_ORDIE();

    MPID_CS_ENTER();
    MPID_MPI_COLL_FUNC_ENTER(MPID_STATE_MPI_REDUCE_SCATTER);

    /* Validate parameters, especially handles needing to be converted */
#   ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS;
        {
            MPIR_ERRTEST_COMM(comm, mpi_errno);
            if (mpi_errno != MPI_SUCCESS) goto fn_fail;
        }
        MPID_END_ERROR_CHECKS;
    }
#   endif /* HAVE_ERROR_CHECKING */

    /* Convert MPI object handles to object pointers */
    MPID_Comm_get_ptr( comm, comm_ptr );

    /* Validate parameters and objects (post conversion) */
#   ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS;
        {
            MPID_Datatype *datatype_ptr = NULL;
            MPID_Op *op_ptr = NULL;
            int i, size, sum;

            MPID_Comm_valid_ptr( comm_ptr, mpi_errno );
            if (mpi_errno != MPI_SUCCESS) goto fn_fail;

            /* even in the intercomm. case, recvcnts is of size local_size */
            size = comm_ptr->local_size;
            sum = 0;
            for (i = 0; i < size; i++) {
                MPIR_ERRTEST_COUNT(recvcnts[i], mpi_errno);
                sum += recvcnts[i];
            }
            /* 'sum' is used as a buffer length below; it is meaningless if
               any individual count was rejected. */
            if (mpi_errno != MPI_SUCCESS) goto fn_fail;

            MPIR_ERRTEST_DATATYPE(datatype, "datatype", mpi_errno);
            /* Do not go on to look up / inspect a handle that already failed
               validation. */
            if (mpi_errno != MPI_SUCCESS) goto fn_fail;
            if (HANDLE_GET_KIND(datatype) != HANDLE_KIND_BUILTIN) {
                MPID_Datatype_get_ptr(datatype, datatype_ptr);
                MPID_Datatype_valid_ptr( datatype_ptr, mpi_errno );
                /* NOTE(review): MPID_Datatype_committed_ptr presumably reads
                   through datatype_ptr, so it must not see a pointer that
                   failed the validity check above. */
                if (mpi_errno != MPI_SUCCESS) goto fn_fail;
                MPID_Datatype_committed_ptr( datatype_ptr, mpi_errno );
                if (mpi_errno != MPI_SUCCESS) goto fn_fail;
            }

            MPIR_ERRTEST_RECVBUF_INPLACE(recvbuf, recvcnts[comm_ptr->rank], mpi_errno);
            if (comm_ptr->comm_kind == MPID_INTERCOMM) {
                MPIR_ERRTEST_SENDBUF_INPLACE(sendbuf, sum, mpi_errno);
            }
            MPIR_ERRTEST_USERBUFFER(recvbuf, recvcnts[comm_ptr->rank], datatype, mpi_errno);
            MPIR_ERRTEST_USERBUFFER(sendbuf, sum, datatype, mpi_errno);
            MPIR_ERRTEST_OP(op, mpi_errno);
            if (mpi_errno != MPI_SUCCESS) goto fn_fail;

            if (HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) {
                /* user-defined op: only the handle itself can be checked */
                MPID_Op_get_ptr(op, op_ptr);
                MPID_Op_valid_ptr( op_ptr, mpi_errno );
            }
            else {
                /* builtin op: check that it is defined for this datatype */
                mpi_errno = ( * MPIR_Op_check_dtype_table[op%16 - 1] )(datatype);
            }
            if (mpi_errno != MPI_SUCCESS) goto fn_fail;
        }
        MPID_END_ERROR_CHECKS;
    }
#   endif /* HAVE_ERROR_CHECKING */

    /* ... body of routine ...  */

    if (comm_ptr->coll_fns != NULL && comm_ptr->coll_fns->Reduce_scatter != NULL) {
        /* communicator supplies its own collective implementation */
        mpi_errno = comm_ptr->coll_fns->Reduce_scatter(sendbuf, recvbuf, recvcnts,
                                                       datatype, op, comm_ptr);
    }
    else {
        MPIR_Nest_incr();
        if (comm_ptr->comm_kind == MPID_INTRACOMM) {
            /* intracommunicator */
            mpi_errno = MPIR_Reduce_scatter(sendbuf, recvbuf, recvcnts,
                                            datatype, op, comm_ptr);
        }
        else {
            /* intercommunicator */
            mpi_errno = MPIR_Reduce_scatter_inter(sendbuf, recvbuf, recvcnts,
                                                  datatype, op, comm_ptr);
        }
        MPIR_Nest_decr();
    }
    if (mpi_errno != MPI_SUCCESS) goto fn_fail;

    /* ... end of body of routine ... */

  fn_exit:
    MPID_MPI_COLL_FUNC_EXIT(MPID_STATE_MPI_REDUCE_SCATTER);
    MPID_CS_EXIT();
    return mpi_errno;

  fn_fail:
    /* --BEGIN ERROR HANDLING-- */
#   ifdef HAVE_ERROR_CHECKING
    {
        mpi_errno = MPIR_Err_create_code(
            mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER,
            "**mpi_reduce_scatter", "**mpi_reduce_scatter %p %p %p %D %O %C",
            sendbuf, recvbuf, recvcnts, datatype, op, comm);
    }
#   endif
    mpi_errno = MPIR_Err_return_comm( comm_ptr, FCNAME, mpi_errno );
    goto fn_exit;
    /* --END ERROR HANDLING-- */
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -