/* alltoall.c — MPICH implementation of MPI_Alltoall (extracted chunk) */
/* -*- Mode: C; c-basic-offset:4 ; -*- *//* * * (C) 2001 by Argonne National Laboratory. * See COPYRIGHT in top-level directory. */#include "mpiimpl.h"/* -- Begin Profiling Symbol Block for routine MPI_Alltoall */#if defined(HAVE_PRAGMA_WEAK)#pragma weak MPI_Alltoall = PMPI_Alltoall#elif defined(HAVE_PRAGMA_HP_SEC_DEF)#pragma _HP_SECONDARY_DEF PMPI_Alltoall MPI_Alltoall#elif defined(HAVE_PRAGMA_CRI_DUP)#pragma _CRI duplicate MPI_Alltoall as PMPI_Alltoall#endif/* -- End Profiling Symbol Block *//* Define MPICH_MPI_FROM_PMPI if weak symbols are not supported to build the MPI routines */#ifndef MPICH_MPI_FROM_PMPI#define MPI_Alltoall PMPI_Alltoall/* This is the default implementation of alltoall. The algorithm is: Algorithm: MPI_Alltoall We use four algorithms for alltoall. For short messages and (comm_size >= 8), we use the algorithm by Jehoshua Bruck et al, IEEE TPDS, Nov. 1997. It is a store-and-forward algorithm that takes lgp steps. Because of the extra communication, the bandwidth requirement is (n/2).lgp.beta. Cost = lgp.alpha + (n/2).lgp.beta where n is the total amount of data a process needs to send to all other processes. For medium size messages and (short messages for comm_size < 8), we use an algorithm that posts all irecvs and isends and then does a waitall. We scatter the order of sources and destinations among the processes, so that all processes don't try to send/recv to/from the same process at the same time. For long messages and power-of-two number of processes, we use a pairwise exchange algorithm, which takes p-1 steps. We calculate the pairs by using an exclusive-or algorithm: for (i=1; i<comm_size; i++) dest = rank ^ i; This algorithm doesn't work if the number of processes is not a power of two. For a non-power-of-two number of processes, we use an algorithm in which, in step i, each process receives from (rank-i) and sends to (rank+i). 
Cost = (p-1).alpha + n.beta where n is the total amount of data a process needs to send to all other processes. Possible improvements: End Algorithm: MPI_Alltoall*//* begin:nested *//* not declared static because a machine-specific function may call this one in some cases */int MPIR_Alltoall( void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPID_Comm *comm_ptr ){ static const char FCNAME[] = "MPIR_Alltoall"; int comm_size, i, j, pof2; MPI_Aint sendtype_extent, recvtype_extent; MPI_Aint recvtype_true_extent, recvbuf_extent, recvtype_true_lb; int mpi_errno=MPI_SUCCESS, src, dst, rank, nbytes; MPI_Status status; int sendtype_size, pack_size, block, position, *displs, count; MPI_Datatype newtype; void *tmp_buf; MPI_Comm comm; MPI_Request *reqarray; MPI_Status *starray;#ifdef OLD MPI_Aint sendtype_true_extent, sendbuf_extent, sendtype_true_lb; int k, p, curr_cnt, dst_tree_root, my_tree_root; int last_recv_cnt, mask, tmp_mask, tree_root, nprocs_completed;#endif if (sendcount == 0) return MPI_SUCCESS; comm = comm_ptr->handle; comm_size = comm_ptr->local_size; rank = comm_ptr->rank; /* Get extent of send and recv types */ MPID_Datatype_get_extent_macro(recvtype, recvtype_extent); MPID_Datatype_get_extent_macro(sendtype, sendtype_extent); MPID_Datatype_get_size_macro(sendtype, sendtype_size); nbytes = sendtype_size * sendcount; /* check if multiple threads are calling this collective function */ MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr ); if ((nbytes <= MPIR_ALLTOALL_SHORT_MSG) && (comm_size >= 8)) { /* use the indexing algorithm by Jehoshua Bruck et al, * IEEE TPDS, Nov. 
97 */ /* allocate temporary buffer */ NMPI_Pack_size(recvcount*comm_size, recvtype, comm, &pack_size); tmp_buf = MPIU_Malloc(pack_size); /* --BEGIN ERROR HANDLING-- */ if (!tmp_buf) { mpi_errno = MPIR_Err_create_code( MPI_SUCCESS, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* --END ERROR HANDLING-- */ /* Do Phase 1 of the algorithim. Shift the data blocks on process i * upwards by a distance of i blocks. Store the result in recvbuf. */ mpi_errno = MPIR_Localcopy((char *) sendbuf + rank*sendcount*sendtype_extent, (comm_size - rank)*sendcount, sendtype, recvbuf, (comm_size - rank)*recvcount, recvtype); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ mpi_errno = MPIR_Localcopy(sendbuf, rank*sendcount, sendtype, (char *) recvbuf + (comm_size-rank)*recvcount*recvtype_extent, rank*recvcount, recvtype); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ /* Input data is now stored in recvbuf with datatype recvtype */ /* Now do Phase 2, the communication phase. It takes ceiling(lg p) steps. In each step i, each process sends to rank+2^i and receives from rank-2^i, and exchanges all data blocks whose ith bit is 1. 
*/ /* allocate displacements array for indexed datatype used in communication */ displs = MPIU_Malloc(comm_size * sizeof(int)); /* --BEGIN ERROR HANDLING-- */ if (!displs) { mpi_errno = MPIR_Err_create_code( MPI_SUCCESS, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* --END ERROR HANDLING-- */ pof2 = 1; while (pof2 < comm_size) { dst = (rank + pof2) % comm_size; src = (rank - pof2 + comm_size) % comm_size; /* Exchange all data blocks whose ith bit is 1 */ /* Create an indexed datatype for the purpose */ count = 0; for (block=1; block<comm_size; block++) { if (block & pof2) { displs[count] = block * recvcount; count++; } } mpi_errno = NMPI_Type_create_indexed_block(count, recvcount, displs, recvtype, &newtype); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ mpi_errno = NMPI_Type_commit(&newtype); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ position = 0; mpi_errno = NMPI_Pack(recvbuf, 1, newtype, tmp_buf, pack_size, &position, comm); mpi_errno = MPIC_Sendrecv(tmp_buf, position, MPI_PACKED, dst, MPIR_ALLTOALL_TAG, recvbuf, 1, newtype, src, MPIR_ALLTOALL_TAG, comm, MPI_STATUS_IGNORE); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ mpi_errno = NMPI_Type_free(&newtype); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ pof2 *= 2; } MPIU_Free(displs); MPIU_Free(tmp_buf); /* Rotate blocks 
in recvbuf upwards by (rank + 1) blocks. Need * a temporary buffer of the same size as recvbuf. */ /* get true extent of recvtype */ mpi_errno = NMPI_Type_get_true_extent(recvtype, &recvtype_true_lb, &recvtype_true_extent); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ recvbuf_extent = recvcount * comm_size * (MPIR_MAX(recvtype_true_extent, recvtype_extent)); tmp_buf = MPIU_Malloc(recvbuf_extent); /* --BEGIN ERROR HANDLING-- */ if (!tmp_buf) { mpi_errno = MPIR_Err_create_code( MPI_SUCCESS, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**nomem", 0 ); return mpi_errno; } /* --END ERROR HANDLING-- */ /* adjust for potential negative lower bound in datatype */ tmp_buf = (void *)((char*)tmp_buf - recvtype_true_lb); mpi_errno = MPIR_Localcopy((char *) recvbuf + (rank+1)*recvcount*recvtype_extent, (comm_size - rank - 1)*recvcount, recvtype, tmp_buf, (comm_size - rank - 1)*recvcount, recvtype); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */ mpi_errno = MPIR_Localcopy(recvbuf, (rank+1)*recvcount, recvtype, (char *) tmp_buf + (comm_size-rank-1)*recvcount*recvtype_extent, (rank+1)*recvcount, recvtype); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno) { mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER, "**fail", 0); return mpi_errno; } /* --END ERROR HANDLING-- */