Skip to content

Commit a83ceb7

Browse files
Merge pull request #2114 from CliMA/ck/adapt_cpu_gpu
Define cpu<->gpu conversions
2 parents 93f58c7 + 6ff39ad commit a83ceb7

File tree

7 files changed

+168
-8
lines changed

7 files changed

+168
-8
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ ClimaCoreCUDAExt = "CUDA"
3737
KrylovExt = "Krylov"
3838

3939
[compat]
40-
Adapt = "3, 4"
40+
Adapt = "3.2.0, 4"
4141
Aqua = "0.8"
4242
ArgParse = "1"
4343
AssociatedLegendrePolynomials = "1"
@@ -68,7 +68,7 @@ NVTX = "0.3"
6868
PkgVersion = "0.1, 0.2, 0.3"
6969
PrettyTables = "2"
7070
Random = "1"
71-
RecursiveArrayTools = "3.1"
71+
RecursiveArrayTools = "3.2"
7272
RootSolvers = "0.3, 0.4"
7373
SafeTestsets = "0.1"
7474
SparseArrays = "1"

src/Fields/fieldvector.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ struct FieldVector{T, M} <: BlockArrays.AbstractBlockVector{T}
2323
end
2424
FieldVector{T}(values::M) where {T, M} = FieldVector{T, M}(values)
2525

26+
function Adapt.adapt_structure(to, fv::FieldVector)
27+
pn = propertynames(fv)
28+
vals = map(key -> Adapt.adapt(to, getproperty(fv, key)), pn)
29+
return FieldVector(; NamedTuple{pn}(vals)...)
30+
end
2631

2732
"""
2833
Fields.ScalarWrapper(val) <: AbstractArray{T,0}

src/Grids/extruded.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ mutable struct ExtrudedFiniteDifferenceGrid{
3838
face_local_geometry::FLG
3939
end
4040

41+
Adapt.@adapt_structure ExtrudedFiniteDifferenceGrid
42+
4143
local_geometry_type(
4244
::Type{ExtrudedFiniteDifferenceGrid{H, V, A, GG, CLG, FLG}},
4345
) where {H, V, A, GG, CLG, FLG} = eltype(CLG) # calls eltype from DataLayouts

src/Grids/finitedifference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ mutable struct FiniteDifferenceGrid{
4040
center_local_geometry::CLG
4141
face_local_geometry::FLG
4242
end
43-
43+
Adapt.@adapt_structure FiniteDifferenceGrid
4444

4545
function FiniteDifferenceGrid(topology::Topologies.IntervalTopology)
4646
get!(Cache.OBJECT_CACHE, (FiniteDifferenceGrid, topology)) do

src/Grids/spectralelement.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ mutable struct SpectralElementGrid1D{
2121
dss_weights::D
2222
end
2323

24+
Adapt.@adapt_structure SpectralElementGrid1D
25+
2426
local_geometry_type(
2527
::Type{SpectralElementGrid1D{T, Q, GG, LG}},
2628
) where {T, Q, GG, LG} = eltype(LG) # calls eltype from DataLayouts
@@ -140,6 +142,8 @@ mutable struct SpectralElementGrid2D{
140142
enable_bubble::Bool
141143
end
142144

145+
Adapt.@adapt_structure SpectralElementGrid2D
146+
143147
local_geometry_type(
144148
::Type{SpectralElementGrid2D{T, Q, GG, LG, D, IS, BS}},
145149
) where {T, Q, GG, LG, D, IS, BS} = eltype(LG) # calls eltype from DataLayouts

test/Fields/unit_field.jl

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#=
22
julia --check-bounds=yes --project
3-
julia --project
3+
julia --project=.buildkite
44
using Revise; include(joinpath("test", "Fields", "unit_field.jl"))
55
=#
66
using Test
@@ -702,6 +702,155 @@ end
702702
nothing
703703
end
704704

705+
using ClimaCore.CommonSpaces
706+
using ClimaCore.Grids
707+
using Adapt
708+
709+
function test_adapt(cpu_space_in)
710+
test_adapt_space(cpu_space_in)
711+
cpu_f_in = Fields.Field(Float64, cpu_space_in)
712+
cpu_f_out = Adapt.adapt(Array, cpu_f_in)
713+
@test parent(Spaces.local_geometry_data(axes(cpu_f_out))) isa Array
714+
@test parent(Fields.field_values(cpu_f_out)) isa Array
715+
716+
@static if ClimaComms.device() isa ClimaComms.CUDADevice
717+
# cpu -> gpu
718+
gpu_f_out = Adapt.adapt(CUDA.CuArray, cpu_f_in)
719+
@test parent(Fields.field_values(gpu_f_out)) isa CUDA.CuArray
720+
# gpu -> gpu
721+
cpu_f_out = Adapt.adapt(Array, gpu_f_out)
722+
@test parent(Fields.field_values(cpu_f_out)) isa Array
723+
end
724+
end
725+
726+
function test_adapt_fieldvector(fv_in)
727+
cpu_fv_out = Adapt.adapt(Array, fv_in)
728+
@test parent(Spaces.local_geometry_data(axes(cpu_fv_out.c))) isa Array
729+
@test parent(Spaces.local_geometry_data(axes(cpu_fv_out.f))) isa Array
730+
@test parent(Fields.field_values(cpu_fv_out.c)) isa Array
731+
@test parent(Fields.field_values(cpu_fv_out.f)) isa Array
732+
733+
@static if ClimaComms.device() isa ClimaComms.CUDADevice
734+
# cpu -> gpu
735+
gpu_fv_out = Adapt.adapt(CUDA.CuArray, cpu_fv_out)
736+
@test parent(Fields.field_values(gpu_fv_out.c)) isa CUDA.CuArray
737+
@test parent(Fields.field_values(gpu_fv_out.f)) isa CUDA.CuArray
738+
# gpu -> gpu
739+
cpu_fv_out = Adapt.adapt(Array, gpu_fv_out)
740+
@test parent(Fields.field_values(cpu_fv_out.c)) isa Array
741+
@test parent(Fields.field_values(cpu_fv_out.f)) isa Array
742+
end
743+
end
744+
745+
function test_adapt_space(cpu_space_in)
746+
# cpu -> cpu
747+
cpu_space_out = Adapt.adapt(Array, cpu_space_in)
748+
@test parent(Spaces.local_geometry_data(cpu_space_out)) isa Array
749+
750+
@static if ClimaComms.device() isa ClimaComms.CUDADevice
751+
# cpu -> gpu
752+
gpu_space_out = Adapt.adapt(CUDA.CuArray, cpu_space_in)
753+
@test parent(Spaces.local_geometry_data(gpu_space_out)) isa CUDA.CuArray
754+
# gpu -> gpu
755+
cpu_space_out = Adapt.adapt(Array, gpu_space_out)
756+
@test parent(Spaces.local_geometry_data(cpu_space_out)) isa Array
757+
end
758+
end
759+
760+
@testset "Test Adapt" begin
761+
space = ExtrudedCubedSphereSpace(;
762+
device = ClimaComms.CPUSingleThreaded(),
763+
z_elem = 10,
764+
z_min = 0,
765+
z_max = 1,
766+
radius = 10,
767+
h_elem = 10,
768+
n_quad_points = 4,
769+
staggering = Grids.CellCenter(),
770+
)
771+
test_adapt(space)
772+
773+
space = CubedSphereSpace(;
774+
device = ClimaComms.CPUSingleThreaded(),
775+
radius = 10,
776+
n_quad_points = 4,
777+
h_elem = 10,
778+
)
779+
test_adapt(space)
780+
781+
space = ColumnSpace(;
782+
device = ClimaComms.CPUSingleThreaded(),
783+
z_elem = 10,
784+
z_min = 0,
785+
z_max = 10,
786+
staggering = CellCenter(),
787+
)
788+
test_adapt(space)
789+
790+
space = Box3DSpace(;
791+
device = ClimaComms.CPUSingleThreaded(),
792+
z_elem = 10,
793+
x_min = 0,
794+
x_max = 1,
795+
y_min = 0,
796+
y_max = 1,
797+
z_min = 0,
798+
z_max = 10,
799+
periodic_x = false,
800+
periodic_y = false,
801+
n_quad_points = 4,
802+
x_elem = 3,
803+
y_elem = 4,
804+
staggering = CellCenter(),
805+
)
806+
test_adapt(space)
807+
808+
space = SliceXZSpace(;
809+
device = ClimaComms.CPUSingleThreaded(),
810+
z_elem = 10,
811+
x_min = 0,
812+
x_max = 1,
813+
z_min = 0,
814+
z_max = 1,
815+
periodic_x = false,
816+
n_quad_points = 4,
817+
x_elem = 4,
818+
staggering = CellCenter(),
819+
)
820+
test_adapt(space)
821+
822+
space = RectangleXYSpace(;
823+
device = ClimaComms.CPUSingleThreaded(),
824+
x_min = 0,
825+
x_max = 1,
826+
y_min = 0,
827+
y_max = 1,
828+
periodic_x = false,
829+
periodic_y = false,
830+
n_quad_points = 4,
831+
x_elem = 3,
832+
y_elem = 4,
833+
)
834+
test_adapt(space)
835+
836+
# FieldVector
837+
cspace = ExtrudedCubedSphereSpace(;
838+
device = ClimaComms.CPUSingleThreaded(),
839+
z_elem = 10,
840+
z_min = 0,
841+
z_max = 1,
842+
radius = 10,
843+
h_elem = 10,
844+
n_quad_points = 4,
845+
staggering = Grids.CellCenter(),
846+
)
847+
fspace = Spaces.face_space(cspace)
848+
c = Fields.zeros(cspace)
849+
f = Fields.zeros(fspace)
850+
fv = Fields.FieldVector(; c, f)
851+
test_adapt_fieldvector(fv)
852+
end
853+
705854
@testset "Memoization of spaces" begin
706855
space1 = spectral_space_2D()
707856
space2 = spectral_space_2D()

test/Spaces/unit_spaces.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using Revise; include(joinpath("test", "Spaces", "unit_spaces.jl"))
55
using Test
66
using ClimaComms
77
using StaticArrays, IntervalSets, LinearAlgebra
8-
using Adapt
8+
import Adapt
99
ClimaComms.@import_required_backends
1010

1111
import ClimaCore:
@@ -198,11 +198,11 @@ end
198198
@test length(Spaces.all_nodes(hspace)) == 4
199199

200200
@static if on_gpu
201-
adapted_space = adapt(CUDA.KernelAdaptor(), c_space)
201+
adapted_space = Adapt.adapt(CUDA.KernelAdaptor(), c_space)
202202
@test ClimaComms.context(adapted_space) == DeviceSideContext()
203203
@test ClimaComms.device(adapted_space) == DeviceSideDevice()
204204

205-
adapted_hspace = adapt(CUDA.KernelAdaptor(), hspace)
205+
adapted_hspace = Adapt.adapt(CUDA.KernelAdaptor(), hspace)
206206
@test ClimaComms.context(adapted_hspace) == DeviceSideContext()
207207
@test ClimaComms.device(adapted_hspace) == DeviceSideDevice()
208208
end
@@ -245,7 +245,7 @@ end
245245
dss_weights_slab = slab(Spaces.local_dss_weights(space), 1)
246246

247247
@static if on_gpu
248-
adapted_space = adapt(CUDA.KernelAdaptor(), space)
248+
adapted_space = Adapt.adapt(CUDA.KernelAdaptor(), space)
249249
@test ClimaComms.context(adapted_space) == DeviceSideContext()
250250
@test ClimaComms.device(adapted_space) == DeviceSideDevice()
251251
end

0 commit comments

Comments
 (0)