
Commit d026536

Keluaa authored and vchuravy committed
Added IReduce! and IAllreduce!
1 parent 71acbb7 commit d026536

File tree: 3 files changed (+126, -0 lines)

src/collective.jl

Lines changed: 90 additions & 0 deletions
@@ -716,6 +716,57 @@ function Reduce(object::T, op, root::Integer, comm::Comm) where {T}
     end
 end
 
+## IReduce
+
+"""
+    IReduce!(sendbuf, recvbuf, op, comm::Comm[, req::AbstractRequest = Request()]; root::Integer=0)
+    IReduce!(sendrecvbuf, op, comm::Comm[, req::AbstractRequest = Request()]; root::Integer=0)
+
+Starts a nonblocking elementwise reduction using the operator `op` on the buffer `sendbuf` and
+stores the result in `recvbuf` on the process of rank `root`.
+
+On non-root processes `recvbuf` is ignored, and can be `nothing`.
+
+To perform the reduction in place, provide a single buffer `sendrecvbuf`.
+
+Returns the [`AbstractRequest`](@ref) object for the nonblocking reduction.
+
+# See also
+- [`Reduce!`](@ref) the equivalent blocking operation.
+- [`IAllreduce!`](@ref) to send reduction to all ranks.
+- [`Op`](@ref) for details on reduction operators.
+
+# External links
+$(_doc_external("MPI_Ireduce"))
+"""
+IReduce!(sendrecvbuf, op, comm::Comm, req::AbstractRequest=Request(); root::Integer=Cint(0)) =
+    IReduce!(sendrecvbuf, op, root, comm, req)
+IReduce!(sendbuf, recvbuf, op, comm::Comm, req::AbstractRequest=Request(); root::Integer=Cint(0)) =
+    IReduce!(sendbuf, recvbuf, op, root, comm, req)
+
+function IReduce!(rbuf::RBuffer, op::Union{Op,MPI_Op}, root::Integer, comm::Comm, req::AbstractRequest=Request())
+    # int MPI_Ireduce(const void* sendbuf, void* recvbuf, int count,
+    #                 MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm,
+    #                 MPI_Request* req)
+    API.MPI_Ireduce(rbuf.senddata, rbuf.recvdata, rbuf.count, rbuf.datatype, op, root, comm, req)
+    setbuffer!(req, rbuf)
+    return req
+end
+
+IReduce!(rbuf::RBuffer, op, root::Integer, comm::Comm, req::AbstractRequest=Request()) =
+    IReduce!(rbuf, Op(op, eltype(rbuf)), root, comm, req)
+IReduce!(sendbuf, recvbuf, op, root::Integer, comm::Comm, req::AbstractRequest=Request()) =
+    IReduce!(RBuffer(sendbuf, recvbuf), op, root, comm, req)
+
+# inplace
+function IReduce!(buf, op, root::Integer, comm::Comm, req::AbstractRequest=Request())
+    if Comm_rank(comm) == root
+        IReduce!(IN_PLACE, buf, op, root, comm, req)
+    else
+        IReduce!(buf, nothing, op, root, comm, req)
+    end
+end
+
 ## Allreduce
 
 # mutating
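
For context, a minimal usage sketch of the nonblocking reduction added above — illustrative only, not part of this diff — assuming an MPI.jl program launched with mpiexec and using the keyword `root` signature from the hunk:

    using MPI
    MPI.Init()

    comm = MPI.COMM_WORLD
    rank = MPI.Comm_rank(comm)
    root = 0

    send = fill(rank + 1, 4)                        # each rank contributes its own values
    recv = rank == root ? similar(send) : nothing   # recvbuf is only needed on the root

    req = MPI.IReduce!(send, recv, +, comm; root=root)
    # ...unrelated work can overlap with the reduction here...
    MPI.Wait(req)

    rank == root && println("elementwise sums: ", recv)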
@@ -775,6 +826,45 @@ Allreduce(sendbuf::AbstractArray, op, comm::Comm) =
 Allreduce(obj::T, op, comm::Comm) where {T} =
     Allreduce!(Ref(obj), Ref{T}(), op, comm)[]
 
+## IAllreduce
+
+"""
+    IAllreduce!(sendbuf, recvbuf, op, comm::Comm[, req::AbstractRequest = Request()])
+    IAllreduce!(sendrecvbuf, op, comm::Comm[, req::AbstractRequest = Request()])
+
+Starts a nonblocking elementwise reduction using the operator `op` on the buffer `sendbuf`, storing
+the result in the `recvbuf` of all processes in the group.
+
+If only one `sendrecvbuf` buffer is provided, then the operation is performed in-place.
+
+Returns the [`AbstractRequest`](@ref) object for the nonblocking reduction.
+
+# See also
+- [`Allreduce!`](@ref) the equivalent blocking operation.
+- [`IReduce!`](@ref) to send reduction to a single rank.
+- [`Op`](@ref) for details on reduction operators.
+
+# External links
+$(_doc_external("MPI_Iallreduce"))
+"""
+function IAllreduce!(rbuf::RBuffer, op::Union{Op, MPI_Op}, comm::Comm, req::AbstractRequest=Request())
+    @assert isnull(req)
+    # int MPI_Iallreduce(const void* sendbuf, void* recvbuf, int count,
+    #                    MPI_Datatype datatype, MPI_Op op, MPI_Comm comm,
+    #                    MPI_Request* req)
+    API.MPI_Iallreduce(rbuf.senddata, rbuf.recvdata, rbuf.count, rbuf.datatype, op, comm, req)
+    setbuffer!(req, rbuf)
+    return req
+end
+IAllreduce!(rbuf::RBuffer, op, comm::Comm, req::AbstractRequest=Request()) =
+    IAllreduce!(rbuf, Op(op, eltype(rbuf)), comm, req)
+IAllreduce!(sendbuf, recvbuf, op, comm::Comm, req::AbstractRequest=Request()) =
+    IAllreduce!(RBuffer(sendbuf, recvbuf), op, comm, req)
+
+# inplace
+IAllreduce!(rbuf, op, comm::Comm, req::AbstractRequest=Request()) =
+    IAllreduce!(IN_PLACE, rbuf, op, comm, req)
+
 ## Scan
 
 # mutating
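
Likewise, a minimal sketch of the single-buffer (in-place) IAllreduce! form added above — again illustrative, not part of this diff, under the same assumptions about the MPI launch:

    using MPI
    MPI.Init()

    comm = MPI.COMM_WORLD
    vals = ones(Float64, 8)               # every rank starts with the same data in this sketch

    req = MPI.IAllreduce!(vals, +, comm)  # single-buffer call performs the reduction in place
    MPI.Wait(req)

    # after completion every rank holds the elementwise sum over all ranks
    @assert all(vals .== MPI.Comm_size(comm))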

test/test_allreduce.jl

Lines changed: 13 additions & 0 deletions
@@ -43,6 +43,19 @@ for T = [Int]
             vals = MPI.Allreduce(send_arr, op, MPI.COMM_WORLD)
             @test vals isa ArrayType{T}
             @test vals == comm_size .* send_arr
+
+            # Nonblocking
+            recv_arr = ArrayType{T}(undef, size(send_arr))
+            req = MPI.IAllreduce!(send_arr, recv_arr, op, MPI.COMM_WORLD)
+            MPI.Wait(req)
+            @test recv_arr == comm_size .* send_arr
+
+            # Nonblocking (IN_PLACE)
+            recv_arr = copy(send_arr)
+            synchronize()
+            req = MPI.IAllreduce!(recv_arr, op, MPI.COMM_WORLD)
+            MPI.Wait(req)
+            @test recv_arr == comm_size .* send_arr
         end
     end
 end

test/test_reduce.jl

Lines changed: 23 additions & 0 deletions
@@ -111,6 +111,22 @@ for T = [Int]
                 @test recv_arr isa ArrayType{T}
                 @test recv_arr == sz .* view(send_arr, 2:3)
             end
+
+            # Nonblocking
+            recv_arr = ArrayType{T}(undef, size(send_arr))
+            req = MPI.IReduce!(send_arr, recv_arr, op, MPI.COMM_WORLD; root=root)
+            MPI.Wait(req)
+            if isroot
+                @test recv_arr == sz .* send_arr
+            end
+
+            # Nonblocking (IN_PLACE)
+            recv_arr = copy(send_arr)
+            req = MPI.IReduce!(recv_arr, op, MPI.COMM_WORLD; root=root)
+            MPI.Wait(req)
+            if isroot
+                @test recv_arr == sz .* send_arr
+            end
         end
     end
 end
@@ -127,6 +143,13 @@ if can_do_closures
         @test result === nothing
     end
 
+    recv_arr = isroot ? zeros(eltype(send_arr), size(send_arr)) : nothing
+    req = MPI.IReduce!(send_arr, recv_arr, +, MPI.COMM_WORLD; root=root)
+    MPI.Wait(req)
+    if rank == root
+        @test recv_arr ≈ [Double64(sz*i)/10 for i = 1:10] rtol=sz*eps(Double64)
+    end
+
     MPI.Barrier( MPI.COMM_WORLD )
 end
 