📄 mpido_allgatherv.c
字号:
dt_null, recv_true_lb); if (sendbuf != MPI_IN_PLACE) MPIDI_Datatype_get_info(sendcount, sendtype, config.send_contig, send_size, dt_null, send_true_lb); int buffer_sum = 0; { int i = 0; if (0 != displs[0]) config.recv_continuous = 0; for (i = 1; i < comm_ptr->local_size; ++i) { buffer_sum += recvcounts[i - 1]; if (buffer_sum != displs[i]) config.recv_continuous = 0; if (!config.recv_continuous) break; } buffer_sum += recvcounts[comm_ptr->local_size - 1]; } buffer_sum *= recv_size; MPIDO_Allreduce(MPI_IN_PLACE, &config, 3, MPI_INT, MPI_BAND, comm_ptr); /* determine which protocol to use */ /* 1) Tree allreduce * a) Need tree allreduce for this communicator * b) User must be ok with allgatherv via allreduce * c) Datatypes must be continguous * d) Count must be a multiple of 4 since tree doesn't support * chars right now */ int treereduce = comm_ptr->dcmf.allreducetree && MPIDI_CollectiveProtocols.allgatherv.useallreduce && config.recv_contig && config.send_contig && config.recv_continuous && buffer_sum % 4 ==0; /* 2) Tree bcast * a) Need tree bcast for this communicator * b) User must be ok with allgatherv via bcast */ int treebcast = comm_ptr->dcmf.bcasttree && MPIDI_CollectiveProtocols.allgatherv.usebcast; /* 3) Alltoall * a) Need torus alltoall for this communicator * b) User must be ok with allgatherv via alltoall * c) Need contiguous datatypes */ int usealltoall = comm_ptr->dcmf.alltoalls && MPIDI_CollectiveProtocols.allgatherv.usealltoallv && config.recv_contig && config.send_contig;#warning assume same cutoff for allgather if(treereduce && treebcast && sendcount > 65536) {// if(comm_ptr->rank ==0 )fprintf(stderr,"sendcount: %d, calling bcast\n", sendcount); result = MPIDO_Allgatherv_Bcast(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm_ptr); } else if(treereduce && treebcast && sendcount <= 65536) {// if(comm_ptr->rank ==0 )fprintf(stderr,"sendcount: %d, calling allreduce\n", sendcount); result = MPIDO_Allgatherv_Allreduce(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm_ptr, send_true_lb, recv_true_lb, send_size, recv_size, buffer_sum); } else if(treereduce) {// if(comm_ptr->rank ==0 )fprintf(stderr,"sendcount: %d, only tree allreduce\n", sendcount); result = MPIDO_Allgatherv_Allreduce(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm_ptr, send_true_lb, recv_true_lb, send_size, recv_size, buffer_sum); } else if(treebcast) {// if(comm_ptr->rank ==0 )fprintf(stderr,"sendcount: %d, only tree bcast\n", sendcount); result = MPIDO_Allgatherv_Bcast(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm_ptr); } else if(usealltoall) result = MPIDO_Allgatherv_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm_ptr, send_true_lb, recv_true_lb, recv_size); else return MPIR_Allgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm_ptr); return result;}#if 0 /* not worth doing on the torus */ if (MPIDI_CollectiveProtocols.allgatherv.useallreduce && comm_ptr->dcmf.allreducetree && config.recv_contig && config.send_contig && config.recv_continuous && buffer_sum % 4 == 0) { //if (0==comm_ptr->rank) puts("allreduce allgatherv"); result = MPIDO_Allgatherv_Allreduce(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm_ptr, send_true_lb, recv_true_lb, send_size, recv_size, buffer_sum); } /* again, too slow if we only have a rectangle bcast */ else if (MPIDI_CollectiveProtocols.allgatherv.usebcast && comm_ptr->dcmf.bcasttree) { //if (0==comm_ptr->rank) puts("bcast allgatherv"); result = MPIDO_Allgatherv_Bcast(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm_ptr); } else if (MPIDI_CollectiveProtocols.allgatherv.usealltoallv && comm_ptr->dcmf.alltoalls && config.recv_contig && config.send_contig) { //if (0==comm_ptr->rank) puts("all2all allgatherv"); result = MPIDO_Allgatherv_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm_ptr, send_true_lb, recv_true_lb, recv_size); } else { //if (0==comm_ptr->rank) puts("mpich2 allgatherv"); return MPIR_Allgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm_ptr); } return result;}#endif
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -