Skip to content

Commit 3f666f4

Browse files
Move in-kernel specific adapt functions to cuda ext
1 parent 70b7c71 commit 3f666f4

File tree

8 files changed

+50
-35
lines changed

8 files changed

+50
-35
lines changed

ext/ClimaCoreCUDAExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import ClimaCore.RecursiveApply:
2020
import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh
2121
import ClimaCore.DataLayouts: UniversalSize
2222

23+
include(joinpath("cuda", "adapt.jl"))
2324
include(joinpath("cuda", "cuda_utils.jl"))
2425
include(joinpath("cuda", "data_layouts.jl"))
2526
include(joinpath("cuda", "fields.jl"))

ext/cuda/adapt.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import CUDA, Adapt
2+
import ClimaCore
3+
import ClimaCore: Grids, Spaces, Topologies, Devices
4+
5+
Adapt.adapt_structure(
6+
to::CUDA.KernelAdaptor,
7+
grid::Grids.ExtrudedFiniteDifferenceGrid,
8+
) = Grids.DeviceExtrudedFiniteDifferenceGrid(
9+
Adapt.adapt(to, Grids.vertical_topology(grid)),
10+
Adapt.adapt(to, grid.horizontal_grid.quadrature_style),
11+
Adapt.adapt(to, grid.global_geometry),
12+
Adapt.adapt(to, grid.center_local_geometry),
13+
Adapt.adapt(to, grid.face_local_geometry),
14+
)
15+
16+
Adapt.adapt_structure(
17+
to::CUDA.KernelAdaptor,
18+
grid::Grids.FiniteDifferenceGrid,
19+
) = Grids.DeviceFiniteDifferenceGrid(
20+
Adapt.adapt(to, grid.topology),
21+
Adapt.adapt(to, grid.global_geometry),
22+
Adapt.adapt(to, grid.center_local_geometry),
23+
Adapt.adapt(to, grid.face_local_geometry),
24+
)
25+
26+
Adapt.adapt_structure(
27+
to::CUDA.KernelAdaptor,
28+
grid::Grids.SpectralElementGrid2D,
29+
) = Grids.DeviceSpectralElementGrid2D(
30+
Adapt.adapt(to, grid.quadrature_style),
31+
Adapt.adapt(to, grid.global_geometry),
32+
Adapt.adapt(to, grid.local_geometry),
33+
)
34+
35+
Adapt.adapt_structure(to::CUDA.KernelAdaptor, space::Spaces.PointSpace) =
36+
Spaces.PointSpace(
37+
ClimaCore.DeviceSideContext(),
38+
Adapt.adapt(to, Spaces.local_geometry_data(space)),
39+
)
40+
41+
Adapt.adapt_structure(
42+
to::CUDA.KernelAdaptor,
43+
topology::Topologies.IntervalTopology,
44+
) = Topologies.DeviceIntervalTopology(topology.boundaries)

src/Grids/extruded.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -155,15 +155,6 @@ local_geometry_type(
155155
::Type{DeviceExtrudedFiniteDifferenceGrid{VT, Q, GG, CLG, FLG}},
156156
) where {VT, Q, GG, CLG, FLG} = eltype(CLG) # calls eltype from DataLayouts
157157

158-
Adapt.adapt_structure(to, grid::ExtrudedFiniteDifferenceGrid) =
159-
DeviceExtrudedFiniteDifferenceGrid(
160-
Adapt.adapt(to, vertical_topology(grid)),
161-
Adapt.adapt(to, grid.horizontal_grid.quadrature_style),
162-
Adapt.adapt(to, grid.global_geometry),
163-
Adapt.adapt(to, grid.center_local_geometry),
164-
Adapt.adapt(to, grid.face_local_geometry),
165-
)
166-
167158
quadrature_style(grid::DeviceExtrudedFiniteDifferenceGrid) =
168159
grid.quadrature_style
169160
vertical_topology(grid::DeviceExtrudedFiniteDifferenceGrid) =

src/Grids/finitedifference.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -190,14 +190,6 @@ local_geometry_type(
190190
::Type{DeviceFiniteDifferenceGrid{T, GG, CLG, FLG}},
191191
) where {T, GG, CLG, FLG} = eltype(CLG) # calls eltype from DataLayouts
192192

193-
Adapt.adapt_structure(to, grid::FiniteDifferenceGrid) =
194-
DeviceFiniteDifferenceGrid(
195-
Adapt.adapt(to, grid.topology),
196-
Adapt.adapt(to, grid.global_geometry),
197-
Adapt.adapt(to, grid.center_local_geometry),
198-
Adapt.adapt(to, grid.face_local_geometry),
199-
)
200-
201193
topology(grid::DeviceFiniteDifferenceGrid) = grid.topology
202194
vertical_topology(grid::DeviceFiniteDifferenceGrid) = grid.topology
203195

src/Grids/spectralelement.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -597,13 +597,6 @@ end
597597
ClimaComms.context(grid::DeviceSpectralElementGrid2D) = DeviceSideContext()
598598
ClimaComms.device(grid::DeviceSpectralElementGrid2D) = DeviceSideDevice()
599599

600-
Adapt.adapt_structure(to, grid::SpectralElementGrid2D) =
601-
DeviceSpectralElementGrid2D(
602-
Adapt.adapt(to, grid.quadrature_style),
603-
Adapt.adapt(to, grid.global_geometry),
604-
Adapt.adapt(to, grid.local_geometry),
605-
)
606-
607600
## aliases
608601
const RectilinearSpectralElementGrid2D =
609602
SpectralElementGrid2D{<:Topologies.RectilinearTopology2D}

src/Spaces/pointspace.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,6 @@ function PointSpace(
3838
return PointSpace(context, Adapt.adapt(ArrayType, local_geometry_data))
3939
end
4040

41-
42-
Adapt.adapt_structure(to, space::PointSpace) =
43-
PointSpace(DeviceSideContext(), Adapt.adapt(to, local_geometry_data(space)))
44-
4541
function PointSpace(
4642
context::ClimaComms.AbstractCommsContext,
4743
coord::Geometry.Abstract1DPoint{FT},

src/Topologies/interval.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ end
2020
struct DeviceIntervalTopology{B} <: AbstractIntervalTopology
2121
boundaries::B
2222
end
23-
Adapt.adapt_structure(to, topology::IntervalTopology) =
24-
DeviceIntervalTopology(topology.boundaries)
2523

2624
ClimaComms.context(topology::DeviceIntervalTopology) = DeviceSideContext()
2725
ClimaComms.device(topology::DeviceIntervalTopology) = DeviceSideDevice()

test/Spaces/unit_spaces.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,12 @@ end
197197
@test length(Spaces.unique_nodes(hspace)) == 4
198198
@test length(Spaces.all_nodes(hspace)) == 4
199199

200-
if on_gpu
201-
adapted_space = adapt(c_space)(c_space)
200+
@static if on_gpu
201+
adapted_space = adapt(CUDA.KernelAdaptor(), c_space)
202202
@test ClimaComms.context(adapted_space) == DeviceSideContext()
203203
@test ClimaComms.device(adapted_space) == DeviceSideDevice()
204204

205-
adapted_hspace = adapt(hspace)(hspace)
205+
adapted_hspace = adapt(CUDA.KernelAdaptor(), hspace)
206206
@test ClimaComms.context(adapted_hspace) == DeviceSideContext()
207207
@test ClimaComms.device(adapted_hspace) == DeviceSideDevice()
208208
end
@@ -244,8 +244,8 @@ end
244244
local_geometry_slab = slab(Spaces.local_geometry_data(space), 1)
245245
dss_weights_slab = slab(Spaces.local_dss_weights(space), 1)
246246

247-
if on_gpu
248-
adapted_space = adapt(space)(space)
247+
@static if on_gpu
248+
adapted_space = adapt(CUDA.KernelAdaptor(), space)
249249
@test ClimaComms.context(adapted_space) == DeviceSideContext()
250250
@test ClimaComms.device(adapted_space) == DeviceSideDevice()
251251
end

0 commit comments

Comments
 (0)