Commit 10972e9

MPI Distributed Graph API (#597)
* Add MPI distributed graph API
* Add neighbor collectives

Co-authored-by: Mosè Giordano <giordano@users.noreply.github.com>
1 parent 623d8f4 commit 10972e9

13 files changed: 643 additions, 5 deletions
Lines changed: 125 additions & 0 deletions (new file)
@@ -0,0 +1,125 @@
using Test
using MPI

MPI.Init()

comm = MPI.COMM_WORLD
size = MPI.Comm_size(comm)
rank = MPI.Comm_rank(comm)

#
# Setup the following communication graph
#
#  +-----+
#  |     |
#  v     v
#  0<-+  3
#  ^  |  ^
#  |  |  |
#  v  |  v
#  1  +--2
#  ^     |
#  |     |
#  +-----+
#
#

if rank == 0
    dest   = Cint[1,3]
    degree = Cint[length(dest)]
elseif rank == 1
    dest   = Cint[0]
    degree = Cint[length(dest)]
elseif rank == 2
    dest   = Cint[3,0,1]
    degree = Cint[length(dest)]
elseif rank == 3
    dest   = Cint[0,2,1]
    degree = Cint[length(dest)]
end

source = Cint[rank]
graph_comm = MPI.Dist_graph_create(comm, source, degree, dest)

#
# Now send the rank across the edges.
#
# Version 1: use allgather primitive
#

send = [rank]
if rank == 0
    recv = [-1, -1, -1]
elseif rank == 1
    recv = [-1, -1, -1]
elseif rank == 2
    recv = [-1]
elseif rank == 3
    recv = [-1, -1]
end

MPI.Neighbor_allgather!(send, recv, graph_comm);

println("rank = $(rank): $(recv)")

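# Added commentary (not part of the committed file): each rank receives one
# value per incoming edge, so the gather above is expected to print
#   rank 0: some ordering of [1, 2, 3]
#   rank 1: some ordering of [0, 2, 3]
#   rank 2: [3]
#   rank 3: some ordering of [0, 2]
# The neighbor order of a communicator built with Dist_graph_create is
# implementation-defined, hence "some ordering".
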
#
# Version 2: use alltoall primitive
#

if rank == 0
    send = [rank, rank]
    recv = [-1, -1, -1]
elseif rank == 1
    send = [rank]
    recv = [-1, -1, -1]
elseif rank == 2
    send = [rank, rank, rank]
    recv = [-1]
elseif rank == 3
    send = [rank, rank, rank]
    recv = [-1, -1]
end

MPI.Neighbor_alltoall!(UBuffer(send,1), UBuffer(recv,1), graph_comm);

println("rank = $(rank): $(recv)")

#
# Now send the rank exactly rank times across the edges.
#
# Rank i receives i+1 values from each adjacent process
if rank == 0
    send = [rank, rank,
            rank, rank, rank, rank]
    send_count = [2, 4]

    recv = [-1, -1, -1]
    recv_count = [1, 1, 1]
elseif rank == 1
    send = [rank]
    send_count = [1]

    recv = [-1, -1, -1, -1, -1, -1]
    recv_count = [2, 2, 2]
elseif rank == 2
    send = [rank, rank, rank, rank,
            rank,
            rank, rank]
    send_count = [4, 1, 2]

    recv = [-1, -1, -1]
    recv_count = [3]
elseif rank == 3
    send = [rank,
            rank, rank, rank,
            rank, rank]
    send_count = [1, 3, 2]

    recv = [-1, -1, -1, -1, -1, -1, -1, -1]
    recv_count = [4, 4]
end

MPI.Neighbor_alltoallv!(VBuffer(send,send_count), VBuffer(recv,recv_count), graph_comm);
println("rank = $(rank): $(recv)")

MPI.Finalize()
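The example builds the topology with `MPI.Dist_graph_create`, where each rank only declares its outgoing edges. The commit also adds `MPI.Dist_graph_create_adjacent` (documented in docs/src/reference/topology.md below), which takes each rank's incoming and outgoing neighbors explicitly. Below is a minimal sketch, not part of this diff, of the same 4-rank graph built that way; the positional `(comm, sources, destinations)` signature is assumed by analogy with the C call `MPI_Dist_graph_create_adjacent`.

```julia
# Sketch only (assumed signature, see note above): the same communication graph
# via MPI.Dist_graph_create_adjacent, where every rank lists the ranks it
# receives from (sources) and the ranks it sends to (destinations).
using MPI
MPI.Init()

comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)

if rank == 0
    sources, destinations = Cint[1, 2, 3], Cint[1, 3]
elseif rank == 1
    sources, destinations = Cint[0, 2, 3], Cint[0]
elseif rank == 2
    sources, destinations = Cint[3],       Cint[3, 0, 1]
elseif rank == 3
    sources, destinations = Cint[0, 2],    Cint[0, 2, 1]
end

graph_comm = MPI.Dist_graph_create_adjacent(comm, sources, destinations)

# As in "Version 1" above: one value arrives per incoming edge.
recv = fill(-1, length(sources))
MPI.Neighbor_allgather!([rank], recv, graph_comm)
println("rank = $(rank): $(recv)")

MPI.Finalize()
```

With the adjacent constructor the neighbor order seen by the neighborhood collectives matches the order of `sources` and `destinations`, which makes the variable-count calls below easier to reason about.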

docs/make.jl

Lines changed: 3 additions & 2 deletions
@@ -12,6 +12,7 @@ EXAMPLES = [
     "Scatterv and Gatherv" => "examples/06-scatterv.md",
     "Active RMA" => "examples/07-rma_active.md",
     "Passive RMA" => "examples/08-rma_passive.md",
+    "Graph Communication" => "examples/09-graph_communication.md",
 ]

 examples_md_dir = joinpath(@__DIR__,"src/examples")
@@ -36,9 +37,9 @@ for (example_title, example_md) in EXAMPLES
     println(mdfile)

     println(mdfile, "```")
-    println(mdfile, "> mpiexecjl -n 3 julia $example_jl")
+    println(mdfile, "> mpiexecjl -n 4 julia $example_jl")
     cd(@__DIR__) do
-        write(mdfile, mpiexec(cmd -> read(`$cmd -n 3 $(Base.julia_cmd()) --project $example_jl`)))
+        write(mdfile, mpiexec(cmd -> read(`$cmd -n 4 $(Base.julia_cmd()) --project $example_jl`)))
     end
     println(mdfile, "```")
 end

docs/src/reference/collective.md

Lines changed: 4 additions & 0 deletions
@@ -25,6 +25,8 @@ MPI.Gatherv!
 MPI.Allgather!
 MPI.Allgather
 MPI.Allgatherv!
+MPI.Neighbor_allgather!
+MPI.Neighbor_allgatherv!
 ```

 ### Scatter
@@ -41,6 +43,8 @@ MPI.Scatterv!
 MPI.Alltoall!
 MPI.Alltoall
 MPI.Alltoallv!
+MPI.Neighbor_alltoall!
+MPI.Neighbor_alltoallv!
 ```

 ## Reduce/Scan

docs/src/reference/topology.md

Lines changed: 4 additions & 0 deletions
@@ -9,4 +9,8 @@ MPI.Cart_rank
 MPI.Cart_shift
 MPI.Cart_sub
 MPI.Cartdim_get
+MPI.Dist_graph_create
+MPI.Dist_graph_create_adjacent
+MPI.Dist_graph_neighbors_count
+MPI.Dist_graph_neighbors!
 ```
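The two query functions added to this reference page are not exercised by the new example. Here is a hypothetical sketch of how they could be used on the `graph_comm` from that example; the tuple return value of `Dist_graph_neighbors_count` and the two-vector in-place form of `Dist_graph_neighbors!` are assumptions modelled on the underlying `MPI_Dist_graph_neighbors_count` / `MPI_Dist_graph_neighbors` calls, not taken from this diff.

```julia
# Hypothetical usage (signatures assumed, see note above).
indegree, outdegree, weighted = MPI.Dist_graph_neighbors_count(graph_comm)

sources      = Vector{Cint}(undef, indegree)   # ranks this process receives from
destinations = Vector{Cint}(undef, outdegree)  # ranks this process sends to
MPI.Dist_graph_neighbors!(graph_comm, sources, destinations)

println("rank = $(rank): in = $(sources), out = $(destinations)")
```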

src/collective.jl

Lines changed: 116 additions & 0 deletions
@@ -839,3 +839,119 @@ Exscan(sendbuf::AbstractArray, op, comm::Comm) =
     Exscan!(sendbuf, similar(sendbuf), op, comm)
 Exscan(object::T, op, comm::Comm) where {T} =
     Exscan!(Ref(object), Ref{T}(), op, comm)[]
+
+"""
+    Neighbor_alltoall!(sendbuf::UBuffer, recvbuf::UBuffer, comm::Comm)
+
+Perform an all-to-all communication along the directed edges of the graph with fixed size messages.
+
+See also [`MPI.Alltoall!`](@ref).
+
+# External links
+$(_doc_external("MPI_Neighbor_alltoall"))
+"""
+function Neighbor_alltoall!(sendbuf::UBuffer, recvbuf::UBuffer, graph_comm::Comm)
+    # int MPI_Neighbor_alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf,
+    #                           int recvcount, MPI_Datatype recvtype, MPI_Comm graph_comm)
+    @mpichk ccall((:MPI_Neighbor_alltoall, libmpi), Cint,
+                  (MPIPtr, Cint, MPI_Datatype, MPIPtr, Cint, MPI_Datatype, MPI_Comm),
+                  sendbuf.data, sendbuf.count, sendbuf.datatype,
+                  recvbuf.data, recvbuf.count, recvbuf.datatype,
+                  graph_comm) v"3.0"
+    return recvbuf.data
+end
+
+Neighbor_alltoall!(sendbuf::InPlace, recvbuf::UBuffer, graph_comm::Comm) =
+    Neighbor_alltoall!(UBuffer(IN_PLACE), recvbuf, graph_comm)
+Neighbor_alltoall!(sendrecvbuf::UBuffer, graph_comm::Comm) =
+    Neighbor_alltoall!(IN_PLACE, sendrecvbuf, graph_comm)
+Neighbor_alltoall(sendbuf::UBuffer, graph_comm::Comm) =
+    Neighbor_alltoall!(sendbuf, similar(sendbuf), graph_comm)
+
+"""
+    Neighbor_alltoallv!(sendbuf::VBuffer, recvbuf::VBuffer, graph_comm::Comm)
+
+Perform an all-to-all communication along the directed edges of the graph with variable size messages.
+
+See also [`MPI.Alltoallv!`](@ref).
+
+# External links
+$(_doc_external("MPI_Neighbor_alltoallv"))
+"""
+function Neighbor_alltoallv!(sendbuf::VBuffer, recvbuf::VBuffer, graph_comm::Comm)
+    # int MPI_Neighbor_alltoallv(const void* sendbuf, const int sendcounts[],
+    #                            const int sdispls[], MPI_Datatype sendtype, void* recvbuf,
+    #                            const int recvcounts[], const int rdispls[],
+    #                            MPI_Datatype recvtype, MPI_Comm comm)
+    @mpichk ccall((:MPI_Neighbor_alltoallv, libmpi), Cint,
+                  (MPIPtr, Ptr{Cint}, Ptr{Cint}, MPI_Datatype,
+                   MPIPtr, Ptr{Cint}, Ptr{Cint}, MPI_Datatype,
+                   MPI_Comm),
+                  sendbuf.data, sendbuf.counts, sendbuf.displs, sendbuf.datatype,
+                  recvbuf.data, recvbuf.counts, recvbuf.displs, recvbuf.datatype,
+                  graph_comm) v"3.0"
+    return recvbuf.data
+end
+
+"""
+    Neighbor_allgather!(sendbuf::Buffer, recvbuf::UBuffer, comm::Comm)
+
+Perform an all-gather communication along the directed edges of the graph.
+
+See also [`MPI.Allgather!`](@ref).
+
+# External links
+$(_doc_external("MPI_Neighbor_allgather"))
+"""
+function Neighbor_allgather!(sendbuf::Buffer, recvbuf::UBuffer, graph_comm::Comm)
+    # int MPI_Neighbor_allgather(const void* sendbuf, int sendcount,
+    #                            MPI_Datatype sendtype, void* recvbuf, int recvcount,
+    #                            MPI_Datatype recvtype, MPI_Comm comm)
+    @mpichk ccall((:MPI_Neighbor_allgather, libmpi), Cint,
+                  (MPIPtr, Cint, MPI_Datatype, MPIPtr, Cint, MPI_Datatype, MPI_Comm),
+                  sendbuf.data, sendbuf.count, sendbuf.datatype,
+                  recvbuf.data, recvbuf.count, recvbuf.datatype, graph_comm) v"3.0"
+
+    return recvbuf.data
+end
+Neighbor_allgather!(sendbuf, recvbuf::UBuffer, graph_comm::Comm) =
+    Neighbor_allgather!(Buffer_send(sendbuf), recvbuf, graph_comm)
+
+Neighbor_allgather!(sendbuf::Union{Ref,AbstractArray}, recvbuf::AbstractArray, graph_comm::Comm) =
+    Neighbor_allgather!(sendbuf, UBuffer(recvbuf, length(sendbuf)), graph_comm)
+
+
+function Neighbor_allgather!(sendrecvbuf::UBuffer, graph_comm::Comm)
+    Neighbor_allgather!(IN_PLACE, sendrecvbuf, graph_comm)
+end
+
+"""
+    Neighbor_allgatherv!(sendbuf::Buffer, recvbuf::VBuffer, comm::Comm)
+
+Perform an all-gather communication along the directed edges of the graph with variable sized data.
+
+See also [`MPI.Allgatherv!`](@ref).
+
+# External links
+$(_doc_external("MPI_Neighbor_allgatherv"))
+"""
+function Neighbor_allgatherv!(sendbuf::Buffer, recvbuf::VBuffer, graph_comm::Comm)
+    # int MPI_Neighbor_allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
+    #                             void *recvbuf, const int recvcounts[], const int displs[],
+    #                             MPI_Datatype recvtype, MPI_Comm comm)
+    @mpichk ccall((:MPI_Neighbor_allgatherv, libmpi), Cint,
+                  (MPIPtr, Cint, MPI_Datatype, MPIPtr, Ptr{Cint}, Ptr{Cint}, MPI_Datatype, MPI_Comm),
+                  sendbuf.data, sendbuf.count, sendbuf.datatype,
+                  recvbuf.data, recvbuf.counts, recvbuf.displs, recvbuf.datatype, graph_comm) v"3.0"
+    return recvbuf.data
+end
+Neighbor_allgatherv!(sendbuf, recvbuf::VBuffer, graph_comm::Comm) =
+    Neighbor_allgatherv!(Buffer_send(sendbuf), recvbuf, graph_comm)
+
+Neighbor_allgatherv!(sendbuf::Union{Ref,AbstractArray}, recvbuf::AbstractArray, graph_comm::Comm) =
+    Neighbor_allgatherv!(sendbuf, VBuffer(recvbuf, length(sendbuf)), graph_comm)
+
+
+function Neighbor_allgatherv!(sendrecvbuf::VBuffer, graph_comm::Comm)
+    Neighbor_allgatherv!(IN_PLACE, sendrecvbuf, graph_comm)
+end
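`Neighbor_allgatherv!` is the one new collective the example file does not exercise. Here is a short sketch against the 4-rank graph above: every rank broadcasts `rank + 1` copies of its own rank to all of its outgoing neighbors, so each receiver sizes the slot for incoming neighbor `src` as `src + 1`. The in-neighbor ordering assumed in the comments is ascending rank order; with `Dist_graph_create` that ordering is implementation-defined, so a robust version would first query it with `MPI.Dist_graph_neighbors!`. This is a sketch, not part of the commit.

```julia
# Sketch only (not in this diff): variable-sized neighborhood gather on the
# example graph. Each rank sends rank + 1 copies of its own rank.
send = fill(rank, rank + 1)

if rank == 0            # in-neighbors assumed ordered as 1, 2, 3
    counts = Cint[2, 3, 4]
elseif rank == 1        # in-neighbors assumed ordered as 0, 2, 3
    counts = Cint[1, 3, 4]
elseif rank == 2        # single in-neighbor: 3
    counts = Cint[4]
elseif rank == 3        # in-neighbors assumed ordered as 0, 2
    counts = Cint[1, 3]
end

recv = fill(-1, sum(counts))
MPI.Neighbor_allgatherv!(send, VBuffer(recv, counts), graph_comm)
println("rank = $(rank): $(recv)")
```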

src/consts/mpich.jl

Lines changed: 2 additions & 2 deletions
@@ -224,8 +224,8 @@ const MPI_Win_errhandler_fn = MPI_Win_errhandler_function
 @const_ref MPI_ARGV_NULL Ptr{Cvoid} C_NULL
 @const_ref MPI_ARGVS_NULL Ptr{Cvoid} C_NULL

-@const_ref MPI_UNWEIGHTED Ptr{Cint} cglobal((:MPI_UNWEIGHTED, libmpi), Ptr{Cint})
-@const_ref MPI_WEIGHTS_EMPTY Ptr{Cint} cglobal((:MPI_WEIGHTS_EMPTY, libmpi), Ptr{Cint})
+@const_ref MPI_UNWEIGHTED Ptr{Cint} unsafe_load(cglobal((:MPI_UNWEIGHTED, libmpi), Ptr{Cint}))
+@const_ref MPI_WEIGHTS_EMPTY Ptr{Cint} unsafe_load(cglobal((:MPI_WEIGHTS_EMPTY, libmpi), Ptr{Cint}))
 @const_ref MPI_BOTTOM Ptr{Cvoid} C_NULL
 @const_ref MPI_IN_PLACE Ptr{Cvoid} -1
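The switch to `unsafe_load` reflects how MPICH exposes these sentinels: `MPI_UNWEIGHTED` is an exported global (declared along the lines of `extern int * const MPI_UNWEIGHTED`), so `cglobal` only yields the address of that global, not the sentinel pointer itself. A two-line illustration of the distinction, for context rather than as additional API:

```julia
# cglobal returns the address of the exported global; unsafe_load reads the
# sentinel pointer value stored there, which is what must be passed to MPI.
addr_of_global = cglobal((:MPI_UNWEIGHTED, libmpi), Ptr{Cint})  # Ptr{Ptr{Cint}}
sentinel       = unsafe_load(addr_of_global)                    # Ptr{Cint}
```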

src/error.jl

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ macro mpichk(expr, min_version=nothing)
         fn = expr.args[2].args[1].value
         if isnothing(dlsym(libmpi_handle, fn; throw_error=false))
             return quote
-                throw(FeatureLevelError($fn, $min_version))
+                throw(FeatureLevelError($(QuoteNode(fn)), $min_version))
             end
         end
     end
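This change matters because `fn` is a `Symbol`: interpolating it directly into the quoted fallback splices it as an identifier, so the generated code would try to look up a variable named after the MPI function, whereas `QuoteNode` keeps it a literal `Symbol` to pass to `FeatureLevelError`. A standalone illustration (the symbol name here is arbitrary):

```julia
fn = :MPI_Neighbor_alltoall

# Direct interpolation splices an identifier reference:
:( throw(FeatureLevelError($fn, v"3.0")) )
# => :(throw(FeatureLevelError(MPI_Neighbor_alltoall, v"3.0")))

# QuoteNode keeps the Symbol as a literal value:
:( throw(FeatureLevelError($(QuoteNode(fn)), v"3.0")) )
# => :(throw(FeatureLevelError(:MPI_Neighbor_alltoall, v"3.0")))
```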
