diff --git a/src/collective.jl b/src/collective.jl
index 3dded32cc..a80f3f4d9 100644
--- a/src/collective.jl
+++ b/src/collective.jl
@@ -167,5 +167,52 @@ Scatter(sendbuf, T, comm; root::Integer=Cint(0)) =
 Scatter(sendbuf, ::Type{T}, root::Integer, comm::Comm) where {T} =
     Scatter!(sendbuf, Ref{T}(), root, comm)[]
 
+"""
+    Scatter(arr::Union{AbstractVector, Nothing}, comm::Comm; root::Integer=0)
+
+Split the 1D array `arr` held by the `root` process into `nprocs = Comm_size(comm)`
+smaller 1D arrays, in rank order, and return the part belonging to the calling process.
+
+The number of elements `n = length(arr)` does not have to be divisible by `nprocs`: the
+process with rank `j` receives
+`j < rem(n, nprocs) ? div(n, nprocs) + 1 : div(n, nprocs)` elements.
+
+`arr` is only read on the `root` process; every other process may pass `nothing`.
+All processes in `comm` must make this call, and each one gets back a freshly allocated
+`Vector` with the element type of the root's `arr`.
+
+# Throws
+- `ArgumentError` on the `root` process if `arr` is `nothing`.
+"""
+function Scatter(arr::Union{Nothing, AbstractVector}, comm; root::Integer=Cint(0))
+    rank = MPI.Comm_rank(comm)
+    nprocs = MPI.Comm_size(comm)
+    isroot = rank == root
+
+    # Fail with a clear message at the point where `length(nothing)` would otherwise
+    # raise a `MethodError`.
+    if isroot && arr === nothing
+        throw(ArgumentError("`arr` must be a vector on the root process, got `nothing`"))
+    end
+
+    # Only the root knows the length and element type, so it shares both with the other
+    # ranks in a single `bcast` instead of two separate collectives.
+    meta = isroot ? (length(arr), eltype(arr)) : nothing
+    arr_len, elm_t = MPI.bcast(meta, root, comm)
+
+    # The first `r` ranks receive one extra element each.
+    q, r = divrem(arr_len, nprocs)
+    local_arr = Vector{elm_t}(undef, rank < r ? q + 1 : q)
+
+    # Only the root supplies a send buffer; entry `i + 1` is the part size of rank `i`.
+    sendbuf = if isroot
+        MPI.VBuffer(arr, [i < r ? q + 1 : q for i in 0:(nprocs - 1)])
+    else
+        nothing
+    end
+    MPI.Scatterv!(sendbuf, MPI.Buffer(local_arr), root, comm)
+    return local_arr
+end
+
 """
     scatter(objs::Union{AbstractVector, Nothing}, comm::Comm; root::Integer=0)