Skip to content

Commit 4de1f4b

Browse files
Speed up IGW example with threaded mapreduce
1 parent 632eef3 commit 4de1f4b

File tree

4 files changed

+150
-13
lines changed

4 files changed

+150
-13
lines changed

examples/Manifest.toml

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.8.5"
44
manifest_format = "2.0"
5-
project_hash = "7ec2fabea202c1fbfc0bf694c476ce42ac582098"
5+
project_hash = "1e5a4795567d9d355e43d11e9245fcf683ab6bb8"
66

77
[[deps.AMD]]
88
deps = ["Libdl", "LinearAlgebra", "SparseArrays", "Test"]
@@ -27,6 +27,11 @@ git-tree-sha1 = "cc37d689f599e8df4f464b2fa3870ff7db7492ef"
2727
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2828
version = "3.6.1"
2929

30+
[[deps.ArgCheck]]
31+
git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4"
32+
uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197"
33+
version = "2.3.0"
34+
3035
[[deps.ArgTools]]
3136
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
3237
version = "1.1.1"
@@ -88,9 +93,20 @@ git-tree-sha1 = "6ef8fc1d77b60f41041d59ce61ef9eb41ed97a83"
8893
uuid = "aae01518-5342-5314-be14-df237901396f"
8994
version = "0.17.18"
9095

96+
[[deps.BangBang]]
97+
deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"]
98+
git-tree-sha1 = "54b00d1b93791f8e19e31584bd30f2cb6004614b"
99+
uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
100+
version = "0.3.38"
101+
91102
[[deps.Base64]]
92103
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
93104

105+
[[deps.Baselet]]
106+
git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e"
107+
uuid = "9718e550-a3fa-408a-8086-8db961cd8217"
108+
version = "0.1.1"
109+
94110
[[deps.BitFlags]]
95111
git-tree-sha1 = "43b1a4a8f797c1cddadf60499a8a077d4af2cd2d"
96112
uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35"
@@ -195,7 +211,7 @@ version = "0.4.2"
195211
deps = ["Adapt", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DiffEqBase", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "Rotations", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"]
196212
path = ".."
197213
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
198-
version = "0.10.33"
214+
version = "0.10.36"
199215

200216
[[deps.ClimaCorePlots]]
201217
deps = ["ClimaCore", "RecipesBase", "StaticArrays", "TriplotBase"]
@@ -204,7 +220,7 @@ uuid = "cf7c7e5a-b407-4c48-9047-11a94a308626"
204220
version = "0.2.5"
205221

206222
[[deps.ClimaCoreTempestRemap]]
207-
deps = ["ClimaComms", "ClimaCore", "Dates", "LinearAlgebra", "MPI", "NCDatasets", "PkgVersion", "TempestRemap_jll", "Test"]
223+
deps = ["ClimaComms", "ClimaCore", "Dates", "LinearAlgebra", "NCDatasets", "PkgVersion", "TempestRemap_jll", "Test"]
208224
path = "../lib/ClimaCoreTempestRemap"
209225
uuid = "d934ef94-cdd4-4710-83d6-720549644b70"
210226
version = "0.3.8"
@@ -290,6 +306,11 @@ git-tree-sha1 = "02d2316b7ffceff992f3096ae48c7829a8aa0638"
290306
uuid = "b152e2b5-7a66-4b01-a709-34e65c35f657"
291307
version = "0.1.3"
292308

309+
[[deps.CompositionsBase]]
310+
git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad"
311+
uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b"
312+
version = "0.1.2"
313+
293314
[[deps.ConcurrentUtilities]]
294315
deps = ["Serialization", "Sockets"]
295316
git-tree-sha1 = "b306df2650947e9eb100ec125ff8c65ca2053d30"
@@ -345,6 +366,11 @@ version = "1.0.0"
345366
deps = ["Printf"]
346367
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
347368

369+
[[deps.DefineSingletons]]
370+
git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c"
371+
uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52"
372+
version = "0.1.2"
373+
348374
[[deps.DelimitedFiles]]
349375
deps = ["Mmap"]
350376
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
@@ -681,6 +707,11 @@ git-tree-sha1 = "5cd07aab533df5170988219191dfad0519391428"
681707
uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9"
682708
version = "0.1.3"
683709

710+
[[deps.InitialValues]]
711+
git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3"
712+
uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c"
713+
version = "0.3.1"
714+
684715
[[deps.IntelOpenMP_jll]]
685716
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
686717
git-tree-sha1 = "0cb9352ef2e01574eeebdb102948a58740dcaf83"
@@ -1028,6 +1059,12 @@ git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102"
10281059
uuid = "442fdcdd-2543-5da2-b0f3-8c86c306513e"
10291060
version = "0.3.2"
10301061

1062+
[[deps.MicroCollections]]
1063+
deps = ["BangBang", "InitialValues", "Setfield"]
1064+
git-tree-sha1 = "629afd7d10dbc6935ec59b32daeb33bc4460a42e"
1065+
uuid = "128add7d-3638-4c79-886c-908ea0c25c34"
1066+
version = "0.1.4"
1067+
10311068
[[deps.MicrosoftMPI_jll]]
10321069
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
10331070
git-tree-sha1 = "a8027af3d1743b3bfae34e54872359fdebb31422"
@@ -1361,6 +1398,12 @@ git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
13611398
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
13621399
version = "1.2.2"
13631400

1401+
[[deps.Referenceables]]
1402+
deps = ["Adapt"]
1403+
git-tree-sha1 = "e681d3bfa49cd46c3c161505caddf20f0e62aaa9"
1404+
uuid = "42d2dcc6-99eb-4e98-b66c-637b7d73030e"
1405+
version = "0.1.2"
1406+
13641407
[[deps.RelocatableFolders]]
13651408
deps = ["SHA", "Scratch"]
13661409
git-tree-sha1 = "90bc7a7c96410424509e4263e277e43250c05691"
@@ -1497,6 +1540,12 @@ git-tree-sha1 = "ef28127915f4229c971eb43f3fc075dd3fe91880"
14971540
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
14981541
version = "2.2.0"
14991542

1543+
[[deps.SplittablesBase]]
1544+
deps = ["Setfield", "Test"]
1545+
git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5"
1546+
uuid = "171d559e-b47b-412a-8079-5efa626c420e"
1547+
version = "0.1.15"
1548+
15001549
[[deps.Static]]
15011550
deps = ["IfElse"]
15021551
git-tree-sha1 = "7f5a513baec6f122401abfc8e9c074fdac54f6c1"
@@ -1575,7 +1624,7 @@ uuid = "6aa5eb33-94cf-58f4-a9d0-e4b2c4fc25ea"
15751624
version = "0.12.2"
15761625

15771626
[[deps.TempestRemap_jll]]
1578-
deps = ["Artifacts", "JLLWrappers", "Libdl", "NetCDF_jll", "OpenBLAS32_jll", "Pkg"]
1627+
deps = ["Artifacts", "HDF5_jll", "JLLWrappers", "Libdl", "NetCDF_jll", "OpenBLAS32_jll", "Pkg"]
15791628
git-tree-sha1 = "88c3818a492ad1a94b1aa440b01eab5d133209ff"
15801629
uuid = "8573a8c5-1df0-515e-a024-abad257ee284"
15811630
version = "2.1.6+1"
@@ -1602,6 +1651,12 @@ git-tree-sha1 = "c97f60dd4f2331e1a495527f80d242501d2f9865"
16021651
uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5"
16031652
version = "0.5.1"
16041653

1654+
[[deps.ThreadsX]]
1655+
deps = ["ArgCheck", "BangBang", "ConstructionBase", "InitialValues", "MicroCollections", "Referenceables", "Setfield", "SplittablesBase", "Transducers"]
1656+
git-tree-sha1 = "34e6bcf36b9ed5d56489600cf9f3c16843fa2aa2"
1657+
uuid = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
1658+
version = "0.1.11"
1659+
16051660
[[deps.TimerOutputs]]
16061661
deps = ["ExprTools", "Printf"]
16071662
git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7"
@@ -1614,6 +1669,12 @@ git-tree-sha1 = "9a6ae7ed916312b41236fcef7e0af564ef934769"
16141669
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
16151670
version = "0.9.13"
16161671

1672+
[[deps.Transducers]]
1673+
deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"]
1674+
git-tree-sha1 = "25358a5f2384c490e98abd565ed321ffae2cbb37"
1675+
uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999"
1676+
version = "0.4.76"
1677+
16171678
[[deps.TriangularSolve]]
16181679
deps = ["CloseOpenIntervals", "IfElse", "LayoutPointers", "LinearAlgebra", "LoopVectorization", "Polyester", "Static", "VectorizationBase"]
16191680
git-tree-sha1 = "31eedbc0b6d07c08a700e26d31298ac27ef330eb"

examples/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2828
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2929
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
3030
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
31+
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
3132
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3233

3334
[compat]

examples/hybrid/plane/inertial_gravity_wave.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,11 @@ function linear_solution_cache(ᶜlocal_geometry, ᶠlocal_geometry)
282282
@time "ρfb_init_coefs!" IGWU.ρfb_init_coefs!(FT, ρfb_init_array_params)
283283
(; ρfb_init_array, ᶜρb_init_xz, unit_integral) = ρfb_init_array_params
284284
max_ikx, max_ikz = (size(ρfb_init_array) .- 1) .÷ 2
285+
286+
get_xz(lg) = (; x = lg.coordinates.x, z = lg.coordinates.z)
287+
ᶠxz = get_xz.(ᶠlocal_geometry)
288+
ᶜxz = get_xz.(ᶜlocal_geometry)
289+
285290
ᶜp₀ = @. p₀(ᶜz)
286291
return (;
287292
# globals
@@ -305,6 +310,8 @@ function linear_solution_cache(ᶜlocal_geometry, ᶠlocal_geometry)
305310
ᶠx = ᶠlocal_geometry.coordinates.x,
306311
ᶜz,
307312
ᶠz,
313+
ᶜxz,
314+
ᶠxz,
308315

309316
# background state
310317
ᶜp₀,

examples/hybrid/plane/intertial_gravity_wave_utils.jl

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,83 @@ function ρfb_init_coefs!(::Type{FT}, params) where {FT}
3535
end
3636

3737
function Bretherton_transforms!(lin_cache, t, ::Type{FT}) where {FT}
38-
# Bretherton_transforms_partial_sums! is fastest because
39-
# we can multithread across
40-
# `Iterators.product((-max_ikx):max_ikx, (-max_ikz):max_ikz)`
41-
# and apply sums for center and face fields. Using mapreduce requires
42-
# two calls and, as a result in ~20 slower.
43-
44-
Bretherton_transforms_original!(lin_cache, t, FT)
45-
# Bretherton_transforms_partial_sums!(lin_cache, t, FT)
46-
# Bretherton_transforms_threaded_mapreduce!(lin_cache, t, FT)
38+
# Bretherton_transforms! is the most computationally
39+
# expensive part of this example and was therefore
40+
# optimized a bit.
41+
# Bretherton_transforms_original!(lin_cache, t, FT)
42+
Bretherton_transforms_threaded_mapreduce!(lin_cache, t, FT)
43+
end
44+
45+
import ThreadsX
46+
function Bretherton_transforms_threaded_mapreduce!(
47+
lin_cache,
48+
t,
49+
::Type{FT},
50+
) where {FT}
51+
# @info "Computing Bretherton_transforms! (threaded mapreduce)..."
52+
(; ᶠwb) = lin_cache
53+
(; ᶜx, ᶠx, ᶜz, ᶠz, ᶠxz, ᶜxz) = lin_cache
54+
(; ᶜρb_init_xz, ρfb_init_array, unit_integral, x_max, z_max) = lin_cache
55+
(; max_ikx, max_ikz, u₀) = lin_cache
56+
combine(ᶜpb, ᶜρb, ᶜub, ᶜvb) = (; ᶜpb, ᶜρb, ᶜub, ᶜvb)
57+
ᶜbretherton_fields =
58+
combine.(lin_cache.ᶜpb, lin_cache.ᶜρb, lin_cache.ᶜub, lin_cache.ᶜvb)
59+
60+
# TODO: could we, and is it advantageous to, combine
61+
# this into a single mapreduce call?
62+
ip = Iterators.product((-max_ikx):max_ikx, (-max_ikz):max_ikz)
63+
bc_add = (a, b) -> a .+ b
64+
ᶠwb .= ThreadsX.mapreduce(
65+
bc_add,
66+
ip;
67+
init = zeros(FT, axes(ᶠwb)),
68+
) do (ikx, ikz)
69+
(; pfb, ρfb, ufb, vfb, wfb) =
70+
Bretherton_transform_coeffs(lin_cache, ikx, ikz, t, FT)
71+
72+
# Wavenumbers kx and kz of the current Fourier mode of ᶜρb_init
73+
kx::FT = 2 * π / x_max * ikx
74+
kz::FT = 2 * π / (2 * z_max) * ikz
75+
76+
# Fourier factors, shifted by u₀ * t along the x-axis
77+
map(ᶠxz) do nt
78+
real(wfb * exp(im * (kx * (nt.x - u₀ * t) + kz * nt.z)))
79+
end
80+
end
81+
82+
bc_add = (a, b) -> a .+ b
83+
zeroᶜbretherton_fields =
84+
zeros(eltype(ᶜbretherton_fields), axes(ᶜbretherton_fields))
85+
ᶜbretherton_fields .= ThreadsX.mapreduce(
86+
bc_add,
87+
ip;
88+
init = zeroᶜbretherton_fields,
89+
) do (ikx, ikz)
90+
(; pfb, ρfb, ufb, vfb, wfb) =
91+
Bretherton_transform_coeffs(lin_cache, ikx, ikz, t, FT)
92+
93+
# Wavenumbers kx and kz of the current Fourier mode of ᶜρb_init
94+
kx::FT = 2 * π / x_max * ikx
95+
kz::FT = 2 * π / (2 * z_max) * ikz
96+
97+
# Fourier factors, shifted by u₀ * t along the x-axis
98+
map(ᶜxz) do nt
99+
ᶜpb::FT = real(pfb * exp(im * (kx * (nt.x - u₀ * t) + kz * nt.z)))
100+
ᶜρb::FT =
101+
real(ρfb * exp(im * (kx * (nt.x - u₀ * t) + kz * nt.z)))
102+
ᶜub::FT =
103+
real(ufb * exp(im * (kx * (nt.x - u₀ * t) + kz * nt.z)))
104+
ᶜvb::FT =
105+
real(vfb * exp(im * (kx * (nt.x - u₀ * t) + kz * nt.z)))
106+
(; ᶜpb, ᶜρb, ᶜub, ᶜvb)
107+
end
108+
end
109+
110+
lin_cache.ᶜpb .= ᶜbretherton_fields.ᶜpb
111+
lin_cache.ᶜρb .= ᶜbretherton_fields.ᶜρb
112+
lin_cache.ᶜub .= ᶜbretherton_fields.ᶜub
113+
lin_cache.ᶜvb .= ᶜbretherton_fields.ᶜvb
114+
return nothing
47115
end
48116

49117
function Bretherton_transforms_original!(lin_cache, t, ::Type{FT}) where {FT}

0 commit comments

Comments
 (0)