Skip to content

Commit 6cb0fb0

Browse files
authored
Add MPI.Dist_graph_neighbors (#707)
This patch adds `MPI.Dist_graph_neighbors` as a convenience wrapper around `MPI.Dist_graph_neighbors_count` and `MPI.Dist_graph_neighbors!` that allocates the required result vectors automatically.
1 parent 7521b2f commit 6cb0fb0

File tree

3 files changed

+84
-0
lines changed

3 files changed

+84
-0
lines changed

docs/src/reference/topology.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ MPI.Dist_graph_create
1313
MPI.Dist_graph_create_adjacent
1414
MPI.Dist_graph_neighbors_count
1515
MPI.Dist_graph_neighbors!
16+
MPI.Dist_graph_neighbors
1617
```

src/topology.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,3 +377,29 @@ function Dist_graph_neighbors!(graph_comm::Comm, sources::Vector{Cint}, destinat
377377
destination_weights = Array{Cint}(undef,0)
378378
Dist_graph_neighbors!(graph_comm, sources::Vector{Cint}, source_weights, destinations::Vector{Cint}, destination_weights)
379379
end
380+
381+
"""
382+
Dist_graph_neighbors(graph_comm::Comm)
383+
384+
Return `(sources, source_weights, destinations, destination_weights)` of the graph
385+
communicator `graph_comm`. For unweighted graphs `source_weights` and `destination_weights`
386+
are `nothing`.
387+
388+
This function is a wrapper around [`MPI.Dist_graph_neighbors_count`](@ref) and
389+
[`MPI.Dist_graph_neighbors!`](@ref) that automatically handles the allocation of the result
390+
vectors.
391+
"""
392+
function Dist_graph_neighbors(graph_comm::Comm)
393+
indegree, outdegree, weighted = Dist_graph_neighbors_count(graph_comm)
394+
sources = Vector{Cint}(undef, indegree)
395+
destinations = Vector{Cint}(undef, outdegree)
396+
if weighted
397+
source_weights = Vector{Cint}(undef, indegree)
398+
destination_weights = Vector{Cint}(undef, outdegree)
399+
Dist_graph_neighbors!(graph_comm, sources, source_weights, destinations, destination_weights)
400+
return sources, source_weights, destinations, destination_weights
401+
else
402+
Dist_graph_neighbors!(graph_comm, sources, destinations)
403+
return sources, nothing, destinations, nothing
404+
end
405+
end

test/test_neighbor_comm.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
using Test
2+
using MPI
3+
4+
MPI.Init()
5+
6+
const comm = MPI.COMM_WORLD
7+
const rank = MPI.Comm_rank(comm)
8+
const comm_size = MPI.Comm_size(comm)
9+
10+
# Generate a ring graph with optional weights: 0 -> 1 -> 2 -> ... -> comm_size -> 0
11+
const prev_rank = (rank + comm_size - 1) % comm_size
12+
const next_rank = (rank + 1) % comm_size
13+
function ring_graph(; weighted)
14+
sources = Cint[rank]
15+
degrees = Cint.(length.(sources))
16+
destinations = Cint[next_rank]
17+
weights = weighted ? Cint[rank + comm_size] : MPI.UNWEIGHTED
18+
return MPI.Dist_graph_create(comm, sources, degrees, destinations; weights=weights)
19+
end
20+
21+
# Unweighted graph
22+
let
23+
graph_comm = ring_graph(; weighted=false)
24+
indeg, outdeg, weighted = MPI.Dist_graph_neighbors_count(graph_comm)
25+
@test indeg == outdeg == 1
26+
@test !weighted
27+
src = Vector{Cint}(undef, indeg)
28+
dst = Vector{Cint}(undef, outdeg)
29+
MPI.Dist_graph_neighbors!(graph_comm, src, dst)
30+
src2, srcw2, dst2, dstw2 = MPI.Dist_graph_neighbors(graph_comm)
31+
@test src == src2 == Cint[prev_rank]
32+
@test dst == dst2 == Cint[next_rank]
33+
@test srcw2 === dstw2 === nothing
34+
end
35+
36+
# Weighted graph
37+
let
38+
graph_comm = ring_graph(; weighted=true)
39+
indeg, outdeg, weighted = MPI.Dist_graph_neighbors_count(graph_comm)
40+
@test indeg == outdeg == 1
41+
@test weighted
42+
src = Vector{Cint}(undef, indeg)
43+
dst = Vector{Cint}(undef, outdeg)
44+
MPI.Dist_graph_neighbors!(graph_comm, src, dst)
45+
src2 = Vector{Cint}(undef, indeg)
46+
srcw2 = Vector{Cint}(undef, indeg)
47+
dst2 = Vector{Cint}(undef, outdeg)
48+
dstw2 = Vector{Cint}(undef, outdeg)
49+
MPI.Dist_graph_neighbors!(graph_comm, src2, srcw2, dst2, dstw2)
50+
src3, srcw3, dst3, dstw3 = MPI.Dist_graph_neighbors(graph_comm)
51+
@test src == src2 == src3 == Cint[prev_rank]
52+
@test dst == dst2 == dst3 == Cint[next_rank]
53+
@test srcw2 == srcw3 == Cint[prev_rank + comm_size]
54+
end
55+
56+
MPI.Finalize()
57+
@test MPI.Finalized()

0 commit comments

Comments
 (0)