Skip to content

Commit 507ed4e

Browse files
Implement a columnwise shmem operator (#2328)
1 parent ff161f1 commit 507ed4e

File tree

13 files changed

+977
-9
lines changed

13 files changed

+977
-9
lines changed

.buildkite/Manifest.toml

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.10.9"
44
manifest_format = "2.0"
5-
project_hash = "6ab89829ea190189b0319a6f8e22b3515e5283c2"
5+
project_hash = "45a11f30c749324ab2ca9eb06366eb279b21cfa8"
66

77
[[deps.ADTypes]]
88
git-tree-sha1 = "e2478490447631aedba0823d4d7a80b2cc8cdb32"
@@ -331,7 +331,7 @@ weakdeps = ["CUDA", "MPI"]
331331
deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LazyBroadcast", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "StaticArrays", "Statistics", "UnrolledUtilities"]
332332
path = ".."
333333
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
334-
version = "0.14.29"
334+
version = "0.14.33"
335335
weakdeps = ["CUDA", "Krylov"]
336336

337337
[deps.ClimaCore.extensions]
@@ -356,6 +356,12 @@ path = "../lib/ClimaCoreVTK"
356356
uuid = "c8b6d40d-e815-466f-95ae-c48aefa668fa"
357357
version = "0.7.6"
358358

359+
[[deps.ClimaParams]]
360+
deps = ["TOML"]
361+
git-tree-sha1 = "acf6c80c7ad59fe9dac9cc49625d52f4b8e1f4b7"
362+
uuid = "5c42b081-d73a-476f-9059-fd94b934656c"
363+
version = "0.10.30"
364+
359365
[[deps.ClimaTimeSteppers]]
360366
deps = ["ClimaComms", "Colors", "DataStructures", "DiffEqBase", "KernelAbstractions", "Krylov", "LinearAlgebra", "LinearOperators", "NVTX", "SciMLBase", "StaticArrays"]
361367
git-tree-sha1 = "e719705cf15fec895abcb547946131ffe83de4d7"
@@ -1516,6 +1522,11 @@ version = "400.902.209+0"
15161522
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
15171523
version = "1.2.0"
15181524

1525+
[[deps.NullBroadcasts]]
1526+
git-tree-sha1 = "343c7bb67d0a29ea5d7d2b3e945afe81e2862337"
1527+
uuid = "0d71be07-595a-4f89-9529-4065a4ab43a6"
1528+
version = "0.1.0"
1529+
15191530
[[deps.OffsetArrays]]
15201531
git-tree-sha1 = "a414039192a155fb38c4599a60110f0018c6ec82"
15211532
uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
@@ -2186,6 +2197,16 @@ git-tree-sha1 = "43044b737fa70bc12f6105061d3da38f881a3e3c"
21862197
uuid = "b718987f-49a8-5099-9789-dcd902bef87d"
21872198
version = "1.0.2"
21882199

2200+
[[deps.Thermodynamics]]
2201+
deps = ["DocStringExtensions", "KernelAbstractions", "Random", "RootSolvers"]
2202+
git-tree-sha1 = "efe74e0344fd7fb68b831316055290d80a62d9c1"
2203+
uuid = "b60c26fb-14c3-4610-9d3e-2d17fe7ff00c"
2204+
version = "0.12.11"
2205+
weakdeps = ["ClimaParams"]
2206+
2207+
[deps.Thermodynamics.extensions]
2208+
CreateParametersExt = "ClimaParams"
2209+
21892210
[[deps.ThreadingUtilities]]
21902211
deps = ["ManualMemory"]
21912212
git-tree-sha1 = "eda08f7e9818eb53661b3deb74e3159460dfbc27"

.buildkite/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884"
1212
ClimaCorePlots = "cf7c7e5a-b407-4c48-9047-11a94a308626"
1313
ClimaCoreTempestRemap = "d934ef94-cdd4-4710-83d6-720549644b70"
1414
ClimaCoreVTK = "c8b6d40d-e815-466f-95ae-c48aefa668fa"
15+
ClimaParams = "5c42b081-d73a-476f-9059-fd94b934656c"
1516
ClimaTimeSteppers = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
1617
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
1718
CountFlops = "1db9610d-79e1-487a-8d40-77f3295c7593"
@@ -32,6 +33,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
3233
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
3334
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
3435
NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f"
36+
NullBroadcasts = "0d71be07-595a-4f89-9529-4065a4ab43a6"
3537
OrdinaryDiffEqSSPRK = "669c94d9-1f4b-4b64-b377-1aa079aa2388"
3638
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
3739
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
@@ -52,6 +54,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
5254
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
5355
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
5456
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
57+
Thermodynamics = "b60c26fb-14c3-4610-9d3e-2d17fe7ff00c"
5558
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
5659

5760
[compat]

.buildkite/pipeline.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,19 @@ steps:
624624
agents:
625625
slurm_gpus: 1
626626

627+
- label: "Unit: gpu columnwise"
628+
key: unit_gpu_columnwise
629+
command:
630+
- "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/finitedifference/unit_columnwise.jl"
631+
env:
632+
CLIMACOMMS_DEVICE: "CUDA"
633+
agents:
634+
slurm_gpus: 1
635+
636+
- label: "Unit: columnwise"
637+
key: unit_columnwise
638+
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/finitedifference/unit_columnwise.jl"
639+
627640
- label: "Unit: column"
628641
key: unit_column
629642
command:

ext/ClimaCoreCUDAExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ include(joinpath("cuda", "limiters.jl"))
3535
include(joinpath("cuda", "operators_sem_shmem.jl"))
3636
include(joinpath("cuda", "operators_fd_shmem_common.jl"))
3737
include(joinpath("cuda", "operators_fd_shmem.jl"))
38+
include(joinpath("cuda", "operators_columnwise.jl"))
3839
include(joinpath("cuda", "matrix_fields_single_field_solve.jl"))
3940
include(joinpath("cuda", "matrix_fields_multiple_field_solve.jl"))
4041
include(joinpath("cuda", "operators_spectral_element.jl"))

ext/cuda/data_layouts.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,43 @@ import CUDA
1515
parent_array_type(::Type{<:CUDA.CuArray{T, N, B} where {N}}) where {T, B} =
1616
CUDA.CuArray{T, N, B} where {N}
1717

18+
# On-device counterpart of the `CuArray` method: lets lazy broadcast objects be
# used inside kernels. The element type `T` and the third type parameter `A`
# are kept; the dimensionality `N` is left free.
function parent_array_type(
    ::Type{<:CUDA.CuDeviceArray{T, N, A} where {N}},
) where {T, A}
    return CUDA.CuDeviceArray{T, N, A} where {N}
end
22+
1823
# Ensure that both parent array types have the same memory buffer type.
1924
promote_parent_array_type(
2025
::Type{CUDA.CuArray{T1, N, B} where {N}},
2126
::Type{CUDA.CuArray{T2, N, B} where {N}},
2227
) where {T1, T2, B} = CUDA.CuArray{promote_type(T1, T2), N, B} where {N}
2328

29+
# The methods below allow lazy broadcast objects to be used on-device, where
# the parent array types are `CUDA.CuDeviceArray`s. In every case the element
# types are combined with `promote_type` and the dimensionality `N` is left
# free in the result.

# Both parent types share the same third type parameter `B`: keep it.
promote_parent_array_type(
    ::Type{CUDA.CuDeviceArray{T1, N, B} where {N}},
    ::Type{CUDA.CuDeviceArray{T2, N, B} where {N}},
) where {T1, T2, B} = CUDA.CuDeviceArray{promote_type(T1, T2), N, B} where {N}

# The third type parameters differ (`B1` vs `B2`): leave it free in the result.
promote_parent_array_type(
    ::Type{CUDA.CuDeviceArray{T1, N, B1} where {N}},
    ::Type{CUDA.CuDeviceArray{T2, N, B2} where {N}},
) where {T1, T2, B1, B2} =
    CUDA.CuDeviceArray{promote_type(T1, T2), N, B} where {N, B}

# Only the element type of the first parent type is fixed: the third type
# parameter is again left free in the result.
promote_parent_array_type(
    ::Type{CUDA.CuDeviceArray{T1}},
    ::Type{CUDA.CuDeviceArray{T2, N, B2} where {N}},
) where {T1, T2, B2} =
    CUDA.CuDeviceArray{promote_type(T1, T2), N, B} where {N, B}

# Mirror image of the previous method: only the element type of the second
# parent type is fixed.
promote_parent_array_type(
    ::Type{CUDA.CuDeviceArray{T1, N, B1} where {N}},
    ::Type{CUDA.CuDeviceArray{T2}},
) where {T1, T2, B1} =
    CUDA.CuDeviceArray{promote_type(T1, T2), N, B} where {N, B}
54+
2455
# Make `similar` accept our special `UnionAll` parent array type for CuArray.
2556
Base.similar(
2657
::Type{CUDA.CuArray{T, N′, B} where {N′}},

ext/cuda/operators_columnwise.jl

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import ClimaCore.Operators:
2+
columnwise!,
3+
device_sync_threads,
4+
columnwise_kernel!,
5+
universal_index_columnwise,
6+
local_mem
7+
8+
# CUDA method of `device_sync_threads`: forwards to `CUDA.sync_threads()`.
function device_sync_threads(device::ClimaComms.CUDADevice)
    return CUDA.sync_threads()
end
9+
10+
# CUDA method of `local_mem`: returns `CUDA.CuStaticSharedArray(T, dims)`, with
# the element type `T` and the dimensions `dims` both taken from the argument
# types (`Type{T}` and `Val{dims}`).
function local_mem(
    device::ClimaComms.CUDADevice,
    ::Type{T},
    ::Val{dims},
) where {T, dims}
    return CUDA.CuStaticSharedArray(T, dims)
end
15+
16+
# CUDA method of `columnwise!`: compiles `columnwise_kernel!` for the given
# arguments and launches it.
#
# Launch configuration:
#   - threads per block: `ᶠNv`, the number of levels of the face space derived
#     from `axes(ᶜY)`;
#   - blocks: `(Nh, Ni * Nj)`, taken from the universal size of the center
#     coordinate field.
#
# All positional arguments are forwarded to the kernel unchanged, followed by
# `nothing` and the two `Val` flags (`localmem_lg`, `localmem_state`, both
# defaulting to `true`).
#
# NOTE(review): using `ᶠNv` threads per block presumably requires `ᶠNv` to stay
# within the device's per-block thread limit — confirm callers guarantee this.
function columnwise!(
    device::ClimaComms.CUDADevice,
    ᶜf::ᶜF,
    ᶠf::ᶠF,
    ᶜYₜ::Fields.Field,
    ᶠYₜ::Fields.Field,
    ᶜY::Fields.Field,
    ᶠY::Fields.Field,
    p,
    t,
    ::Val{localmem_lg} = Val(true),
    ::Val{localmem_state} = Val(true),
) where {ᶜF, ᶠF, localmem_lg, localmem_state}
    ᶜspace = axes(ᶜY)
    ᶠspace = Spaces.face_space(ᶜspace)
    ᶠNv = Spaces.nlevels(ᶠspace)
    ᶜcf = Fields.coordinate_field(ᶜspace)
    us = DataLayouts.UniversalSize(Fields.field_values(ᶜcf))
    (Ni, Nj, _, _, Nh) = DataLayouts.universal_size(us)
    # Build the kernel argument list once so that compilation and launch can
    # never get out of sync.
    args = (
        device,
        ᶜf,
        ᶠf,
        ᶜYₜ,
        ᶠYₜ,
        ᶜY,
        ᶠY,
        p,
        t,
        nothing,
        Val(localmem_lg),
        Val(localmem_state),
    )
    kernel = CUDA.@cuda(
        always_inline = true,
        launch = false,
        columnwise_kernel!(args...)
    )
    threads = (ᶠNv,)
    blocks = (Nh, Ni * Nj)
    return kernel(args...; threads, blocks)
end
73+
74+
# CUDA method of `universal_index_columnwise`: builds the `CartesianIndex`
# `(i, j, 1, v, h)` for the calling thread.
#   - `v` is the first component of `CUDA.threadIdx()`;
#   - `h` is the first component of `CUDA.blockIdx()`;
#   - `(i, j)` is obtained by unpacking the second component of
#     `CUDA.blockIdx()` as a linear index into an `(Ni, Nj)` grid.
# If that linear index exceeds `Ni * Nj`, the sentinel index
# `(-1, -1, 1, -1, -1)` is returned instead. `UI` is not used by this method.
@inline function universal_index_columnwise(
    device::ClimaComms.CUDADevice,
    UI,
    us,
)
    (level,) = CUDA.threadIdx()
    (hidx, linear_ij) = CUDA.blockIdx()
    (Ni, Nj, _, _, _) = DataLayouts.universal_size(us)
    if linear_ij > Ni * Nj
        return CartesianIndex((-1, -1, 1, -1, -1))
    end
    @inbounds (i, j) = CartesianIndices((Ni, Nj))[linear_ij].I
    return CartesianIndex((i, j, 1, level, hidx))
end

ext/cuda/operators_finite_difference.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@ import ClimaCore.Operators: LeftBoundaryWindow, RightBoundaryWindow, Interior
1111

1212
struct CUDAColumnStencilStyle <: AbstractStencilStyle end
1313
struct CUDAWithShmemColumnStencilStyle <: AbstractStencilStyle end
14-
AbstractStencilStyle(bc, ::ClimaComms.CUDADevice) =
15-
Operators.any_fd_shmem_supported(bc) ? CUDAWithShmemColumnStencilStyle :
16-
CUDAColumnStencilStyle
14+
15+
# On a CUDA device the stencil style is always `CUDAColumnStencilStyle`; the
# broadcast object `bc` is not inspected.
function AbstractStencilStyle(bc, ::ClimaComms.CUDADevice)
    return CUDAColumnStencilStyle
end
1716

1817
Base.Broadcast.BroadcastStyle(
1918
x::Operators.ColumnStencilStyle,
@@ -150,10 +149,7 @@ end
150149

151150
function copyto_stencil_kernel_shmem!(
152151
out,
153-
bc′::Union{
154-
StencilBroadcasted{CUDAWithShmemColumnStencilStyle},
155-
Broadcasted{CUDAWithShmemColumnStencilStyle},
156-
},
152+
bc′::Union{StencilBroadcasted, Broadcasted},
157153
space,
158154
bds,
159155
us,

src/DataLayouts/struct.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,19 @@ promote_parent_array_type(
116116
::Type{Array{T1}},
117117
::Type{MArray{S, T2}},
118118
) where {S, T1, T2} = MArray{S, promote_type(T1, T2)}
119+
# Promotion of two `MArray` parent array types: the size parameter is dropped
# (left free as `S` in the result, since sizes are never actually used!) and
# the element types are combined with `promote_type`.

# Both inputs carry a concrete size parameter.
function promote_parent_array_type(
    ::Type{MArray{S1, T1}},
    ::Type{MArray{S2, T2}},
) where {S1, T1, S2, T2}
    return MArray{S, promote_type(T1, T2)} where {S}
end

# The first input already has its size parameter left free.
function promote_parent_array_type(
    ::Type{MArray{S1, T1} where {S1}},
    ::Type{MArray{S2, T2}},
) where {T1, S2, T2}
    return MArray{S, promote_type(T1, T2)} where {S}
end

# The second input already has its size parameter left free.
function promote_parent_array_type(
    ::Type{MArray{S1, T1}},
    ::Type{MArray{S2, T2} where {S2}},
) where {S1, T1, T2}
    return MArray{S, promote_type(T1, T2)} where {S}
end
119132

120133
"""
121134
StructArrays.bypass_constructor(T, args)

src/Fields/Fields.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,9 @@ local_geometry_field(space::AbstractSpace) =
369369
Field(Spaces.local_geometry_data(space), space)
370370
local_geometry_field(field::Field) = local_geometry_field(axes(field))
371371

372+
# Broadcasted objects: forward `axes(bc)` to `local_geometry_field`, in the same
# way the `Field` method forwards `axes(field)`.
local_geometry_field(bc::Base.Broadcast.Broadcasted) =
    local_geometry_field(axes(bc))
374+
372375
"""
373376
Δz_field(field::Field)
374377
Δz_field(space::AbstractSpace)

src/Grids/extruded.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ struct DeviceExtrudedFiniteDifferenceGrid{VT, Q, GG, CLG, FLG} <:
157157
face_local_geometry::FLG
158158
end
159159

160+
# `ClimaComms.device` method for `DeviceExtrudedFiniteDifferenceGrid`, added so
# that `device` can be called on-device: the result is the device of the
# grid's vertical topology.
function ClimaComms.device(grid::DeviceExtrudedFiniteDifferenceGrid)
    return ClimaComms.device(vertical_topology(grid))
end
163+
160164
local_geometry_type(
161165
::Type{DeviceExtrudedFiniteDifferenceGrid{VT, Q, GG, CLG, FLG}},
162166
) where {VT, Q, GG, CLG, FLG} = eltype(CLG) # calls eltype from DataLayouts

0 commit comments

Comments
 (0)