📄 allreduce.c
字号:
if (newrank < newdst) { send_idx = recv_idx + pof2/(mask*2); for (i=send_idx; i<last_idx; i++) send_cnt += cnts[i]; for (i=recv_idx; i<send_idx; i++) recv_cnt += cnts[i]; } else { recv_idx = send_idx + pof2/(mask*2); for (i=send_idx; i<recv_idx; i++) send_cnt += cnts[i]; for (i=recv_idx; i<last_idx; i++) recv_cnt += cnts[i]; }/* printf("Rank %d, send_idx %d, recv_idx %d, send_cnt %d, recv_cnt %d, last_idx %d\n", newrank, send_idx, recv_idx, send_cnt, recv_cnt, last_idx); */ /* Send data from recvbuf. Recv into tmp_buf */ mpi_errno = MPIC_Sendrecv((char *) recvbuf + disps[send_idx]*extent, send_cnt, datatype, dst, MPIR_ALLREDUCE_TAG, (char *) tmp_buf + disps[recv_idx]*extent, recv_cnt, datatype, dst, MPIR_ALLREDUCE_TAG, comm, MPI_STATUS_IGNORE); MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail"); /* tmp_buf contains data received in this step. recvbuf contains data accumulated so far */ /* This algorithm is used only for predefined ops and predefined ops are always commutative. */ (*uop)((char *) tmp_buf + disps[recv_idx]*extent, (char *) recvbuf + disps[recv_idx]*extent, &recv_cnt, &datatype); /* update send_idx for next iteration */ send_idx = recv_idx; mask <<= 1; /* update last_idx, but not in last iteration because the value is needed in the allgather step below. */ if (mask < pof2) last_idx = recv_idx + pof2/mask; } /* now do the allgather */ mask >>= 1; while (mask > 0) { newdst = newrank ^ mask; /* find real rank of dest */ dst = (newdst < rem) ? newdst*2 + 1 : newdst + rem; send_cnt = recv_cnt = 0; if (newrank < newdst) { /* update last_idx except on first iteration */ if (mask != pof2/2) last_idx = last_idx + pof2/(mask*2); recv_idx = send_idx + pof2/(mask*2); for (i=send_idx; i<recv_idx; i++) send_cnt += cnts[i]; for (i=recv_idx; i<last_idx; i++) recv_cnt += cnts[i]; } else { recv_idx = send_idx - pof2/(mask*2); for (i=send_idx; i<last_idx; i++) send_cnt += cnts[i]; for (i=recv_idx; i<send_idx; i++) recv_cnt += cnts[i]; } mpi_errno = MPIC_Sendrecv((char *) recvbuf + disps[send_idx]*extent, send_cnt, datatype, dst, MPIR_ALLREDUCE_TAG, (char *) recvbuf + disps[recv_idx]*extent, recv_cnt, datatype, dst, MPIR_ALLREDUCE_TAG, comm, MPI_STATUS_IGNORE); MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail"); if (newrank > newdst) send_idx = recv_idx; mask >>= 1; } } } /* In the non-power-of-two case, all odd-numbered processes of rank < 2*rem send the result to (rank-1), the ranks who didn't participate above. */ if (rank < 2*rem) { if (rank % 2) /* odd */ mpi_errno = MPIC_Send(recvbuf, count, datatype, rank-1, MPIR_ALLREDUCE_TAG, comm); else /* even */ mpi_errno = MPIC_Recv(recvbuf, count, datatype, rank+1, MPIR_ALLREDUCE_TAG, comm, MPI_STATUS_IGNORE); MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail"); } /* check if multiple threads are calling this collective function */ MPIDU_ERR_CHECK_MULTIPLE_THREADS_EXIT( comm_ptr ); if (p->op_errno) mpi_errno = p->op_errno; } fn_exit: MPIU_CHKLMEM_FREEALL(); MPIR_Nest_decr(); return (mpi_errno); fn_fail: goto fn_exit;}/* not declared static because a machine-specific function may call this one in some cases */int MPIR_Allreduce_inter ( void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPID_Comm *comm_ptr ){/* Intercommunicator Allreduce. We first do an intercommunicator reduce to rank 0 on left group, then an intercommunicator reduce to rank 0 on right group, followed by local intracommunicator broadcasts in each group. We don't do local reduces first and then intercommunicator broadcasts because it would require allocation of a temporary buffer. */ static const char FCNAME[] = "MPIR_Allreduce_inter"; int rank, mpi_errno, root; MPID_Comm *newcomm_ptr = NULL; MPIR_Nest_incr(); rank = comm_ptr->rank; /* first do a reduce from right group to rank 0 in left group, then from left group to rank 0 in right group*/ if (comm_ptr->is_low_group) { /* reduce from right group to rank 0*/ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Reduce_inter(sendbuf, recvbuf, count, datatype, op, root, comm_ptr); MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail"); /* reduce to rank 0 of right group */ root = 0; mpi_errno = MPIR_Reduce_inter(sendbuf, recvbuf, count, datatype, op, root, comm_ptr); MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail"); } else { /* reduce to rank 0 of left group */ root = 0; mpi_errno = MPIR_Reduce_inter(sendbuf, recvbuf, count, datatype, op, root, comm_ptr); MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail"); /* reduce from right group to rank 0 */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Reduce_inter(sendbuf, recvbuf, count, datatype, op, root, comm_ptr); MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail"); } /* Get the local intracommunicator */ if (!comm_ptr->local_comm) MPIR_Setup_intercomm_localcomm( comm_ptr ); newcomm_ptr = comm_ptr->local_comm; mpi_errno = MPIR_Bcast(recvbuf, count, datatype, 0, newcomm_ptr); MPIU_ERR_CHKANDJUMP((mpi_errno), mpi_errno, MPI_ERR_OTHER, "**fail"); fn_exit: MPIR_Nest_decr(); return mpi_errno; fn_fail: goto fn_exit;}#endif#undef FUNCNAME#define FUNCNAME MPI_Allreduce/*@MPI_Allreduce - Combines values from all processes and distributes the result back to all processesInput Parameters:+ sendbuf - starting address of send buffer (choice) . count - number of elements in send buffer (integer) . datatype - data type of elements of send buffer (handle) . op - operation (handle) - comm - communicator (handle) Output Parameter:. recvbuf - starting address of receive buffer (choice) .N ThreadSafe.N Fortran.N collops.N Errors.N MPI_ERR_BUFFER.N MPI_ERR_COUNT.N MPI_ERR_TYPE.N MPI_ERR_OP.N MPI_ERR_COMM@*/int MPI_Allreduce ( void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm ){ static const char FCNAME[] = "MPI_Allreduce"; int mpi_errno = MPI_SUCCESS; MPID_Comm *comm_ptr = NULL; MPID_MPI_STATE_DECL(MPID_STATE_MPI_ALLREDUCE); MPIR_ERRTEST_INITIALIZED_ORDIE(); MPID_CS_ENTER(); MPID_MPI_COLL_FUNC_ENTER(MPID_STATE_MPI_ALLREDUCE); /* Validate parameters, especially handles needing to be converted */# ifdef HAVE_ERROR_CHECKING { MPID_BEGIN_ERROR_CHECKS; { MPIR_ERRTEST_COMM(comm, mpi_errno); if (mpi_errno != MPI_SUCCESS) goto fn_fail; } MPID_END_ERROR_CHECKS; }# endif /* HAVE_ERROR_CHECKING */ /* Convert MPI object handles to object pointers */ MPID_Comm_get_ptr( comm, comm_ptr ); /* Validate parameters and objects (post conversion) */# ifdef HAVE_ERROR_CHECKING { MPID_BEGIN_ERROR_CHECKS; { MPID_Datatype *datatype_ptr = NULL; MPID_Op *op_ptr = NULL; MPID_Comm_valid_ptr( comm_ptr, mpi_errno ); if (mpi_errno != MPI_SUCCESS) goto fn_fail; MPIR_ERRTEST_COUNT(count, mpi_errno); MPIR_ERRTEST_DATATYPE(datatype, "datatype", mpi_errno); MPIR_ERRTEST_OP(op, mpi_errno); if (HANDLE_GET_KIND(datatype) != HANDLE_KIND_BUILTIN) { MPID_Datatype_get_ptr(datatype, datatype_ptr); MPID_Datatype_valid_ptr( datatype_ptr, mpi_errno ); MPID_Datatype_committed_ptr( datatype_ptr, mpi_errno ); } if (comm_ptr->comm_kind == MPID_INTERCOMM) MPIR_ERRTEST_SENDBUF_INPLACE(sendbuf, count, mpi_errno); if (sendbuf != MPI_IN_PLACE) MPIR_ERRTEST_USERBUFFER(sendbuf,count,datatype,mpi_errno); MPIR_ERRTEST_RECVBUF_INPLACE(recvbuf, count, mpi_errno); MPIR_ERRTEST_USERBUFFER(recvbuf,count,datatype,mpi_errno); if (mpi_errno != MPI_SUCCESS) goto fn_fail; if (HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) { MPID_Op_get_ptr(op, op_ptr); MPID_Op_valid_ptr( op_ptr, mpi_errno ); } if (HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) { mpi_errno = ( * MPIR_Op_check_dtype_table[op%16 - 1] )(datatype); } if (mpi_errno != MPI_SUCCESS) goto fn_fail; } MPIR_ERRTEST_ALIAS_COLL(sendbuf, recvbuf, mpi_errno); if (mpi_errno != MPI_SUCCESS) goto fn_fail; MPID_END_ERROR_CHECKS; }# endif /* HAVE_ERROR_CHECKING */ /* ... body of routine ... */ if (comm_ptr->coll_fns != NULL && comm_ptr->coll_fns->Allreduce != NULL) { mpi_errno = comm_ptr->coll_fns->Allreduce(sendbuf, recvbuf, count, datatype, op, comm_ptr); } else { if (comm_ptr->comm_kind == MPID_INTRACOMM) /* intracommunicator */ mpi_errno = MPIR_Allreduce(sendbuf, recvbuf, count, datatype, op, comm_ptr); else { /* intercommunicator */ mpi_errno = MPIR_Allreduce_inter(sendbuf, recvbuf, count, datatype, op, comm_ptr); } } if (mpi_errno != MPI_SUCCESS) goto fn_fail; /* ... end of body of routine ... */ fn_exit: MPID_MPI_COLL_FUNC_EXIT(MPID_STATE_MPI_ALLREDUCE); MPID_CS_EXIT(); return mpi_errno; fn_fail: /* --BEGIN ERROR HANDLING-- */# ifdef HAVE_ERROR_CHECKING { mpi_errno = MPIR_Err_create_code( mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**mpi_allreduce", "**mpi_allreduce %p %p %d %D %O %C", sendbuf, recvbuf, count, datatype, op, comm); }# endif mpi_errno = MPIR_Err_return_comm( comm_ptr, FCNAME, mpi_errno ); goto fn_exit; /* --END ERROR HANDLING-- */}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -