Skip to content

Commit cbb347b

Browse files
Merge #1228 #1263
1228: Fix dss_transform ambiguity aqua test r=charleskawczynski a=charleskawczynski This PR fixes a method ambiguity for `dss_transform`-- perhaps not too important, but 🤷🏻. 1263: Speed up inertial gravity wave examples r=charleskawczynski a=charleskawczynski CI has been slow lately, and failures due to unknown reasons (e.g., the recent [mac os timeout](https://github.com/CliMA/ClimaCore.jl/actions/runs/5071735121/jobs/9108493197)) is painful when we have long chains of serial tasks. I've been working on this branch for a while now, mostly while I'm waiting for CI to pass and I think it's finally ready. The main goal is to speed up the inertial gravity wave examples, which are very slow. This PR attempts to speed up the inertial gravity wave examples by accelerating the Bretherton computations using a threaded mapreduce. Supersedes #849. Co-authored-by: Charles Kawczynski <kawczynski.charles@gmail.com>
3 parents 3e32718 + 6a3672a + eed2e75 commit cbb347b

File tree

7 files changed

+158
-14
lines changed

7 files changed

+158
-14
lines changed

.buildkite/pipeline.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,7 @@ steps:
780780
TEST_NAME: "plane/inertial_gravity_wave"
781781
agents:
782782
slurm_cpus_per_task: 8
783+
slurm_mem: 20GB
783784

784785
- label: ":computer: stretched 2D plane inertial gravity wave"
785786
key: "cpu_stretch_inertial_gravity_wave"
@@ -792,6 +793,7 @@ steps:
792793
Z_STRETCH: "true"
793794
agents:
794795
slurm_cpus_per_task: 8
796+
slurm_mem: 20GB
795797

796798
- group: "Performance"
797799
steps:

examples/Manifest.toml

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.8.5"
44
manifest_format = "2.0"
5-
project_hash = "7ec2fabea202c1fbfc0bf694c476ce42ac582098"
5+
project_hash = "1e5a4795567d9d355e43d11e9245fcf683ab6bb8"
66

77
[[deps.AMD]]
88
deps = ["Libdl", "LinearAlgebra", "SparseArrays", "Test"]
@@ -27,6 +27,11 @@ git-tree-sha1 = "cc37d689f599e8df4f464b2fa3870ff7db7492ef"
2727
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2828
version = "3.6.1"
2929

30+
[[deps.ArgCheck]]
31+
git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4"
32+
uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197"
33+
version = "2.3.0"
34+
3035
[[deps.ArgTools]]
3136
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
3237
version = "1.1.1"
@@ -88,9 +93,20 @@ git-tree-sha1 = "6ef8fc1d77b60f41041d59ce61ef9eb41ed97a83"
8893
uuid = "aae01518-5342-5314-be14-df237901396f"
8994
version = "0.17.18"
9095

96+
[[deps.BangBang]]
97+
deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"]
98+
git-tree-sha1 = "54b00d1b93791f8e19e31584bd30f2cb6004614b"
99+
uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
100+
version = "0.3.38"
101+
91102
[[deps.Base64]]
92103
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
93104

105+
[[deps.Baselet]]
106+
git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e"
107+
uuid = "9718e550-a3fa-408a-8086-8db961cd8217"
108+
version = "0.1.1"
109+
94110
[[deps.BitFlags]]
95111
git-tree-sha1 = "43b1a4a8f797c1cddadf60499a8a077d4af2cd2d"
96112
uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35"
@@ -195,7 +211,7 @@ version = "0.4.2"
195211
deps = ["Adapt", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DiffEqBase", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "Rotations", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"]
196212
path = ".."
197213
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
198-
version = "0.10.33"
214+
version = "0.10.36"
199215

200216
[[deps.ClimaCorePlots]]
201217
deps = ["ClimaCore", "RecipesBase", "StaticArrays", "TriplotBase"]
@@ -204,7 +220,7 @@ uuid = "cf7c7e5a-b407-4c48-9047-11a94a308626"
204220
version = "0.2.5"
205221

206222
[[deps.ClimaCoreTempestRemap]]
207-
deps = ["ClimaComms", "ClimaCore", "Dates", "LinearAlgebra", "MPI", "NCDatasets", "PkgVersion", "TempestRemap_jll", "Test"]
223+
deps = ["ClimaComms", "ClimaCore", "Dates", "LinearAlgebra", "NCDatasets", "PkgVersion", "TempestRemap_jll", "Test"]
208224
path = "../lib/ClimaCoreTempestRemap"
209225
uuid = "d934ef94-cdd4-4710-83d6-720549644b70"
210226
version = "0.3.8"
@@ -290,6 +306,11 @@ git-tree-sha1 = "02d2316b7ffceff992f3096ae48c7829a8aa0638"
290306
uuid = "b152e2b5-7a66-4b01-a709-34e65c35f657"
291307
version = "0.1.3"
292308

309+
[[deps.CompositionsBase]]
310+
git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad"
311+
uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b"
312+
version = "0.1.2"
313+
293314
[[deps.ConcurrentUtilities]]
294315
deps = ["Serialization", "Sockets"]
295316
git-tree-sha1 = "b306df2650947e9eb100ec125ff8c65ca2053d30"
@@ -345,6 +366,11 @@ version = "1.0.0"
345366
deps = ["Printf"]
346367
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
347368

369+
[[deps.DefineSingletons]]
370+
git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c"
371+
uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52"
372+
version = "0.1.2"
373+
348374
[[deps.DelimitedFiles]]
349375
deps = ["Mmap"]
350376
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
@@ -681,6 +707,11 @@ git-tree-sha1 = "5cd07aab533df5170988219191dfad0519391428"
681707
uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9"
682708
version = "0.1.3"
683709

710+
[[deps.InitialValues]]
711+
git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3"
712+
uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c"
713+
version = "0.3.1"
714+
684715
[[deps.IntelOpenMP_jll]]
685716
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
686717
git-tree-sha1 = "0cb9352ef2e01574eeebdb102948a58740dcaf83"
@@ -1028,6 +1059,12 @@ git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102"
10281059
uuid = "442fdcdd-2543-5da2-b0f3-8c86c306513e"
10291060
version = "0.3.2"
10301061

1062+
[[deps.MicroCollections]]
1063+
deps = ["BangBang", "InitialValues", "Setfield"]
1064+
git-tree-sha1 = "629afd7d10dbc6935ec59b32daeb33bc4460a42e"
1065+
uuid = "128add7d-3638-4c79-886c-908ea0c25c34"
1066+
version = "0.1.4"
1067+
10311068
[[deps.MicrosoftMPI_jll]]
10321069
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
10331070
git-tree-sha1 = "a8027af3d1743b3bfae34e54872359fdebb31422"
@@ -1361,6 +1398,12 @@ git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
13611398
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
13621399
version = "1.2.2"
13631400

1401+
[[deps.Referenceables]]
1402+
deps = ["Adapt"]
1403+
git-tree-sha1 = "e681d3bfa49cd46c3c161505caddf20f0e62aaa9"
1404+
uuid = "42d2dcc6-99eb-4e98-b66c-637b7d73030e"
1405+
version = "0.1.2"
1406+
13641407
[[deps.RelocatableFolders]]
13651408
deps = ["SHA", "Scratch"]
13661409
git-tree-sha1 = "90bc7a7c96410424509e4263e277e43250c05691"
@@ -1497,6 +1540,12 @@ git-tree-sha1 = "ef28127915f4229c971eb43f3fc075dd3fe91880"
14971540
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
14981541
version = "2.2.0"
14991542

1543+
[[deps.SplittablesBase]]
1544+
deps = ["Setfield", "Test"]
1545+
git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5"
1546+
uuid = "171d559e-b47b-412a-8079-5efa626c420e"
1547+
version = "0.1.15"
1548+
15001549
[[deps.Static]]
15011550
deps = ["IfElse"]
15021551
git-tree-sha1 = "7f5a513baec6f122401abfc8e9c074fdac54f6c1"
@@ -1575,7 +1624,7 @@ uuid = "6aa5eb33-94cf-58f4-a9d0-e4b2c4fc25ea"
15751624
version = "0.12.2"
15761625

15771626
[[deps.TempestRemap_jll]]
1578-
deps = ["Artifacts", "JLLWrappers", "Libdl", "NetCDF_jll", "OpenBLAS32_jll", "Pkg"]
1627+
deps = ["Artifacts", "HDF5_jll", "JLLWrappers", "Libdl", "NetCDF_jll", "OpenBLAS32_jll", "Pkg"]
15791628
git-tree-sha1 = "88c3818a492ad1a94b1aa440b01eab5d133209ff"
15801629
uuid = "8573a8c5-1df0-515e-a024-abad257ee284"
15811630
version = "2.1.6+1"
@@ -1602,6 +1651,12 @@ git-tree-sha1 = "c97f60dd4f2331e1a495527f80d242501d2f9865"
16021651
uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5"
16031652
version = "0.5.1"
16041653

1654+
[[deps.ThreadsX]]
1655+
deps = ["ArgCheck", "BangBang", "ConstructionBase", "InitialValues", "MicroCollections", "Referenceables", "Setfield", "SplittablesBase", "Transducers"]
1656+
git-tree-sha1 = "34e6bcf36b9ed5d56489600cf9f3c16843fa2aa2"
1657+
uuid = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
1658+
version = "0.1.11"
1659+
16051660
[[deps.TimerOutputs]]
16061661
deps = ["ExprTools", "Printf"]
16071662
git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7"
@@ -1614,6 +1669,12 @@ git-tree-sha1 = "9a6ae7ed916312b41236fcef7e0af564ef934769"
16141669
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
16151670
version = "0.9.13"
16161671

1672+
[[deps.Transducers]]
1673+
deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"]
1674+
git-tree-sha1 = "25358a5f2384c490e98abd565ed321ffae2cbb37"
1675+
uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999"
1676+
version = "0.4.76"
1677+
16171678
[[deps.TriangularSolve]]
16181679
deps = ["CloseOpenIntervals", "IfElse", "LayoutPointers", "LinearAlgebra", "LoopVectorization", "Polyester", "Static", "VectorizationBase"]
16191680
git-tree-sha1 = "31eedbc0b6d07c08a700e26d31298ac27ef330eb"

examples/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2828
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2929
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
3030
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
31+
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
3132
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3233

3334
[compat]

examples/hybrid/plane/inertial_gravity_wave.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,11 @@ function linear_solution_cache(ᶜlocal_geometry, ᶠlocal_geometry)
282282
@time "ρfb_init_coefs!" IGWU.ρfb_init_coefs!(FT, ρfb_init_array_params)
283283
(; ρfb_init_array, ᶜρb_init_xz, unit_integral) = ρfb_init_array_params
284284
max_ikx, max_ikz = (size(ρfb_init_array) .- 1) 2
285+
286+
get_xz(lg) = (; x = lg.coordinates.x, z = lg.coordinates.z)
287+
ᶠxz = get_xz.(ᶠlocal_geometry)
288+
ᶜxz = get_xz.(ᶜlocal_geometry)
289+
285290
ᶜp₀ = @. p₀(ᶜz)
286291
return (;
287292
# globals
@@ -305,6 +310,8 @@ function linear_solution_cache(ᶜlocal_geometry, ᶠlocal_geometry)
305310
ᶠx = ᶠlocal_geometry.coordinates.x,
306311
ᶜz,
307312
ᶠz,
313+
ᶜxz,
314+
ᶠxz,
308315

309316
# background state
310317
ᶜp₀,

examples/hybrid/plane/intertial_gravity_wave_utils.jl

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,83 @@ function ρfb_init_coefs!(::Type{FT}, params) where {FT}
3535
end
3636

3737
function Bretherton_transforms!(lin_cache, t, ::Type{FT}) where {FT}
38-
# Bretherton_transforms_partial_sums! is fastest because
39-
# we can multithread across
40-
# `Iterators.product((-max_ikx):max_ikx, (-max_ikz):max_ikz)`
41-
# and apply sums for center and face fields. Using mapreduce requires
42-
# two calls and, as a result in ~20 slower.
43-
44-
Bretherton_transforms_original!(lin_cache, t, FT)
45-
# Bretherton_transforms_partial_sums!(lin_cache, t, FT)
46-
# Bretherton_transforms_threaded_mapreduce!(lin_cache, t, FT)
38+
# Bretherton_transforms! is the most computationally
39+
# expensive part of this example and was therefore
40+
# optimized a bit.
41+
# Bretherton_transforms_original!(lin_cache, t, FT)
42+
Bretherton_transforms_threaded_mapreduce!(lin_cache, t, FT)
43+
end
44+
45+
import ThreadsX
46+
function Bretherton_transforms_threaded_mapreduce!(
47+
lin_cache,
48+
t,
49+
::Type{FT},
50+
) where {FT}
51+
# @info "Computing Bretherton_transforms! (threaded mapreduce)..."
52+
(; ᶠwb) = lin_cache
53+
(; ᶜx, ᶠx, ᶜz, ᶠz, ᶠxz, ᶜxz) = lin_cache
54+
(; ᶜρb_init_xz, ρfb_init_array, unit_integral, x_max, z_max) = lin_cache
55+
(; max_ikx, max_ikz, u₀) = lin_cache
56+
combine(ᶜpb, ᶜρb, ᶜub, ᶜvb) = (; ᶜpb, ᶜρb, ᶜub, ᶜvb)
57+
ᶜbretherton_fields =
58+
combine.(lin_cache.ᶜpb, lin_cache.ᶜρb, lin_cache.ᶜub, lin_cache.ᶜvb)
59+
60+
# TODO: could we, and is it advantageous to, combine
61+
# this into a single mapreduce call?
62+
ip = Iterators.product((-max_ikx):max_ikx, (-max_ikz):max_ikz)
63+
bc_add = (a, b) -> a .+ b
64+
ᶠwb .= ThreadsX.mapreduce(
65+
bc_add,
66+
ip;
67+
init = zeros(FT, axes(ᶠwb)),
68+
) do (ikx, ikz)
69+
(; pfb, ρfb, ufb, vfb, wfb) =
70+
Bretherton_transform_coeffs(lin_cache, ikx, ikz, t, FT)
71+
72+
# Fourier coefficient of ᶜρb_init (for current kx and kz)
73+
kx::FT = 2 * π / x_max * ikx
74+
kz::FT = 2 * π / (2 * z_max) * ikz
75+
76+
# Fourier factors, shifted by u₀ * t along the x-axis
77+
map(ᶠxz) do nt
78+
real(wfb * exp(im * (kx * (nt.x - u₀ * t) + kz * nt.z)))
79+
end
80+
end
81+
82+
bc_add = (a, b) -> a .+ b
83+
zeroᶜbretherton_fields =
84+
zeros(eltype(ᶜbretherton_fields), axes(ᶜbretherton_fields))
85+
ᶜbretherton_fields .= ThreadsX.mapreduce(
86+
bc_add,
87+
ip;
88+
init = zeroᶜbretherton_fields,
89+
) do (ikx, ikz)
90+
(; pfb, ρfb, ufb, vfb, wfb) =
91+
Bretherton_transform_coeffs(lin_cache, ikx, ikz, t, FT)
92+
93+
# Fourier coefficient of ᶜρb_init (for current kx and kz)
94+
kx::FT = 2 * π / x_max * ikx
95+
kz::FT = 2 * π / (2 * z_max) * ikz
96+
97+
# Fourier factors, shifted by u₀ * t along the x-axis
98+
map(ᶜxz) do nt
99+
ᶜpb::FT = real(pfb * exp(im * (kx * (nt.x - u₀ * t) + kz * nt.z)))
100+
ᶜρb::FT =
101+
real(ρfb * exp(im * (kx * (nt.x - u₀ * t) + kz * nt.z)))
102+
ᶜub::FT =
103+
real(ufb * exp(im * (kx * (nt.x - u₀ * t) + kz * nt.z)))
104+
ᶜvb::FT =
105+
real(vfb * exp(im * (kx * (nt.x - u₀ * t) + kz * nt.z)))
106+
(; ᶜpb, ᶜρb, ᶜub, ᶜvb)
107+
end
108+
end
109+
110+
lin_cache.ᶜpb .= ᶜbretherton_fields.ᶜpb
111+
lin_cache.ᶜρb .= ᶜbretherton_fields.ᶜρb
112+
lin_cache.ᶜub .= ᶜbretherton_fields.ᶜub
113+
lin_cache.ᶜvb .= ᶜbretherton_fields.ᶜvb
114+
return nothing
47115
end
48116

49117
function Bretherton_transforms_original!(lin_cache, t, ::Type{FT}) where {FT}

src/Spaces/dss_transform.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ end
6868
local_geometry::Geometry.LocalGeometry,
6969
weight,
7070
) where {T, N} = arg * weight
71+
@inline dss_transform(
72+
arg::Geometry.AxisTensor{T, N, <:Tuple{}},
73+
local_geometry::Geometry.LocalGeometry,
74+
weight,
75+
) where {T, N} = arg * weight
7176
@inline dss_transform(
7277
arg::Geometry.LocalVector,
7378
local_geometry::Geometry.LocalGeometry,

test/aqua.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ using Aqua
2020
# then please lower the limit based on the new number of ambiguities.
2121
# We're trying to drive this number down to zero to reduce latency.
2222
@info "Number of method ambiguities: $(length(ambs))"
23-
@test length(ambs) 16
23+
@test length(ambs) 15
2424

2525
# returns a vector of all unbound args
2626
# ua = Aqua.detect_unbound_args_recursively(ClimaCore)

0 commit comments

Comments
 (0)