Skip to content

Commit f46414d

Browse files
authored
Update Reduce example to use custom datatypes and operators (#419)
Based on my JuliaCon talk.
1 parent f794131 commit f46414d

File tree

1 file changed

+42
-6
lines changed

1 file changed

+42
-6
lines changed

docs/examples/03-reduce.jl

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,50 @@
1-
using MPI
1+
# This example shows how to use custom datatypes and reduction operators
2+
# It computes the variance in parallel in a numerically stable way
3+
4+
using MPI, Statistics
5+
26
MPI.Init()
7+
const comm = MPI.COMM_WORLD
8+
const root = 0
39

4-
comm = MPI.COMM_WORLD
5-
root = 0
10+
# Define a custom struct
11+
# This contains the summary statistics (mean, variance, length) of a vector
12+
struct SummaryStat
13+
mean::Float64
14+
var::Float64
15+
n::Float64
16+
end
17+
function SummaryStat(X::AbstractArray)
18+
m = mean(X)
19+
v = varm(X,m, corrected=false)
20+
n = length(X)
21+
SummaryStat(m,v,n)
22+
end
623

7-
r = MPI.Comm_rank(comm)
24+
# Define a custom reduction operator
25+
# this computes the pooled mean, pooled variance and total length
26+
function pool(S1::SummaryStat, S2::SummaryStat)
27+
n = S1.n + S2.n
28+
m = (S1.mean*S1.n + S2.mean*S2.n) / n
29+
v = (S1.n * (S1.var + S1.mean * (S1.mean-m)) +
30+
S2.n * (S2.var + S2.mean * (S2.mean-m)))/n
31+
SummaryStat(m,v,n)
32+
end
833

9-
sr = MPI.Reduce(r, +, root, comm)
34+
X = randn(10,3) .* [1,3,7]'
35+
36+
# Perform a scalar reduction
37+
summ = MPI.Reduce(SummaryStat(X), pool, root, comm)
1038

1139
if MPI.Comm_rank(comm) == root
12-
println("sum of ranks = $sr")
40+
@show summ.var
1341
end
1442

43+
# Perform a vector reduction:
44+
# the reduction operator is applied elementwise
45+
col_summ = MPI.Reduce(mapslices(SummaryStat,X,dims=1), pool, root, comm)
46+
47+
if MPI.Comm_rank(comm) == root
48+
col_var = map(summ -> summ.var, col_summ)
49+
@show col_var
50+
end

0 commit comments

Comments
 (0)