Skip to content

Commit 6fa26df

Browse files
author
Charlie Kawczynski
committed
Remove DeviceIntervalTopology
1 parent 8a26a48 commit 6fa26df

File tree

8 files changed

+51
-38
lines changed

8 files changed

+51
-38
lines changed

ext/cuda/adapt.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@ Adapt.adapt_structure(to::CUDA.KernelAdaptor, space::Spaces.PointSpace) =
4242
Adapt.adapt_structure(
4343
to::CUDA.KernelAdaptor,
4444
topology::Topologies.IntervalTopology,
45-
) = Topologies.DeviceIntervalTopology(topology.boundaries)
45+
) = Topologies.IntervalTopology(
46+
Adapt.adapt(to, ClimaComms.SingletonCommsContext(ClimaComms.device(topology.context))),
47+
Adapt.adapt(to, topology.mesh),
48+
Adapt.adapt(to, topology.boundaries),
49+
)
4650

4751
Adapt.adapt_structure(
4852
to::CUDA.KernelAdaptor,

src/Domains/Domains.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ struct IntervalDomain{CT, B} <: AbstractDomain where {
3232
} where {FT}
3333
coord_min::CT
3434
coord_max::CT
35-
boundary_names::B
3635
end
3736

38-
isperiodic(domain::IntervalDomain) = isnothing(domain.boundary_names)
39-
boundary_names(domain::IntervalDomain) =
40-
isperiodic(domain) ? () : unique(domain.boundary_names)
37+
isperiodic(::IntervalDomain{CT, B}) where {CT, B} = B == nothing
38+
boundary_names(domain::IntervalDomain{CT, B}) where {CT, B} =
39+
isperiodic(domain) ? () : unique(B)
40+
boundary_names_type(::IntervalDomain{CT, B}) where {CT, B} = B
4141

4242
"""
4343
IntervalDomain(coord⁻, coord⁺; periodic=true)
@@ -60,7 +60,13 @@ function IntervalDomain(
6060
),
6161
)
6262
end
63-
IntervalDomain(promote(coord_min, coord_max)..., boundary_names)
63+
c = promote(coord_min, coord_max)
64+
boundary_names = if isnothing(boundary_names)
65+
boundary_names
66+
else
67+
Tuple(boundary_names)
68+
end
69+
IntervalDomain{eltype(c), boundary_names}(c...)
6470
end
6571
IntervalDomain(coords::IntervalSets.ClosedInterval; kwargs...) =
6672
IntervalDomain(coords.left, coords.right; kwargs...)
@@ -95,7 +101,7 @@ function print_interval(io::IO, domain::IntervalDomain{CT}) where {CT}
95101
if isperiodic(domain)
96102
print(io, "(periodic)")
97103
else
98-
print(io, domain.boundary_names)
104+
print(io, boundary_names_type(domain))
99105
end
100106
end
101107
function Base.show(io::IO, domain::IntervalDomain)

src/Grids/finitedifference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function _FiniteDifferenceGrid(topology::Topologies.IntervalTopology)
7373
)
7474

7575
return FiniteDifferenceGrid(
76-
topology,
76+
Adapt.adapt(ArrayType, topology),
7777
global_geometry,
7878
Adapt.adapt(ArrayType, center_local_geometry),
7979
Adapt.adapt(ArrayType, face_local_geometry),

src/InputOutput/writers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,10 @@ function write_new!(
166166
)
167167
write_attribute(group, "coord_min", Geometry.component(domain.coord_min, 1))
168168
write_attribute(group, "coord_max", Geometry.component(domain.coord_max, 1))
169-
!isnothing(domain.boundary_names) && write_attribute(
169+
!isnothing(Domains.boundary_names(domain)) && write_attribute(
170170
group,
171171
"boundary_names",
172-
[String(bname) for bname in domain.boundary_names],
172+
[String(bname) for bname in Domains.boundary_names(domain)],
173173
)
174174
return name
175175
end

src/Meshes/interval.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ struct IntervalMesh{S, I <: IntervalDomain, V <: AbstractVector, M} <:
2626
meta::M
2727
end
2828

29+
using Adapt
30+
Adapt.@adapt_structure IntervalMesh
31+
2932
# implies isequal
3033
Base.:(==)(mesh1::IntervalMesh, mesh2::IntervalMesh) =
3134
mesh1.domain == mesh2.domain && mesh1.faces == mesh2.faces
@@ -85,9 +88,9 @@ end
8588
function boundary_face_name(mesh::IntervalMesh, elem::Integer, face)
8689
if !Domains.isperiodic(mesh.domain)
8790
if elem == 1 && face == 1
88-
return mesh.domain.boundary_names[1]
91+
return Domains.boundary_names(mesh.domain)[1]
8992
elseif elem == nelements(mesh) && face == 2
90-
return mesh.domain.boundary_names[2]
93+
return Domains.boundary_names(mesh.domain)[2]
9194
end
9295
end
9396
return nothing
@@ -468,7 +471,7 @@ function truncate_mesh(
468471
new_domain = IntervalDomain(
469472
z_bottom,
470473
Geometry.ZPoint{FT}(z_top),
471-
boundary_names = trunc_domain.boundary_names,
474+
boundary_names = Domains.boundary_names_type(trunc_domain),
472475
)
473476
return IntervalMesh(new_domain, new_stretch; nelems = new_nelems)
474477
end

src/Topologies/interval.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,10 @@ end
1818

1919
Adapt.@adapt_structure IntervalTopology
2020

21-
## gpu
22-
struct DeviceIntervalTopology{B} <: AbstractIntervalTopology
23-
boundaries::B
24-
end
25-
26-
ClimaComms.context(topology::DeviceIntervalTopology) = DeviceSideContext()
27-
ClimaComms.device(topology::DeviceIntervalTopology) = DeviceSideDevice()
21+
ClimaComms.context(topology::IntervalTopology) = topology.context
22+
ClimaComms.device(topology::IntervalTopology) =
23+
ClimaComms.device(topology.context)
2824

29-
ClimaComms.device(topology::IntervalTopology) = topology.context.device
3025
ClimaComms.array_type(topology::IntervalTopology) =
3126
ClimaComms.array_type(topology.context.device)
3227

@@ -46,14 +41,19 @@ function _IntervalTopology(
4641
)
4742
# currently only support SingletonCommsContext
4843
@assert context isa ClimaComms.SingletonCommsContext
49-
if Domains.isperiodic(mesh.domain)
44+
domain = mesh.domain
45+
if Domains.isperiodic(domain)
5046
boundaries = NamedTuple()
51-
elseif mesh.domain.boundary_names[1] == mesh.domain.boundary_names[2]
52-
boundaries = NamedTuple{(mesh.domain.boundary_names[1],)}(1)
5347
else
54-
boundaries = NamedTuple{mesh.domain.boundary_names}((1, 2))
48+
bn = Domains.boundary_names_type(domain)
49+
boundaries = if bn[1] == bn[2]
50+
NamedTuple{(bn[1],)}(1)
51+
else
52+
NamedTuple{bn}((1, 2))
53+
end
5554
end
56-
IntervalTopology(context, mesh, boundaries)
55+
ArrayType = ClimaComms.array_type(ClimaComms.device(context))
56+
IntervalTopology(context, Adapt.adapt(ArrayType, mesh), boundaries)
5757
end
5858

5959
IntervalTopology(device::ClimaComms.AbstractDevice, mesh::Meshes.IntervalMesh) =

test/Operators/finitedifference/unit_columnwise.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,11 @@ import ClimaCore.MatrixFields
3030
import ClimaCore.Spaces
3131
import ClimaCore.Fields
3232

33-
Operators.fd_shmem_is_supported(bc::Base.Broadcast.Broadcasted) = false
34-
ClimaCore.Operators.use_fd_shmem() = false
33+
# Operators.fd_shmem_is_supported(bc::Base.Broadcast.Broadcasted) = false
34+
# ClimaCore.Operators.use_fd_shmem() = false
3535
# The existing implementation limits our ability to apply
3636
# the same expressions from within kernels
37-
ClimaComms.device(topology::Topologies.DeviceIntervalTopology) =
38-
ClimaComms.CUDADevice()
39-
Fields.error_mismatched_spaces(::Type, ::Type) = nothing # causes unsupported dynamic function invocation
37+
# Fields.error_mismatched_spaces(::Type, ::Type) = nothing # causes unsupported dynamic function invocation
4038

4139
const C1 = Geometry.Covariant1Vector
4240
const C2 = Geometry.Covariant2Vector
@@ -91,12 +89,14 @@ Base.Broadcast.broadcastable(x::RayleighSponge) = tuple(x)
9189

9290
function rayleigh_sponge_tendency_uₕ(ᶜuₕ, s)
9391
s isa Nothing && return NullBroadcasted()
94-
(; ᶜz, ᶠz) = z_coordinate_fields(axes(ᶜuₕ))
95-
zmax = z_max(axes(ᶠz))
92+
ᶜz = Fields.coordinate_field(axes(ᶜuₕ)).z
93+
ᶠz = Fields.coordinate_field(Spaces.face_space(axes(ᶜuₕ))).z
94+
# (; ᶜz, ᶠz) = z_coordinate_fields(axes(ᶜuₕ))
95+
zmax = Spaces.z_max(axes(ᶠz))
9696
return @. lazy(-β_rayleigh_uₕ(s, ᶜz, zmax) * ᶜuₕ)
9797
end
9898

99-
function compute_kinetic(uₕ::Fields.Field, uᵥ::Fields.Field)
99+
function compute_kinetic(uₕ, uᵥ)
100100
@assert eltype(uₕ) <: Union{C1, C2, C12}
101101
@assert eltype(uᵥ) <: C3
102102
FT = Spaces.undertype(axes(uₕ))
@@ -167,9 +167,9 @@ function ᶜimplicit_tendency_bc(ᶜY, ᶠY, p, t)
167167
ᶠu³ = @. lazy(ᶠuₕ³ + CT3(ᶠu₃))
168168
tend_ρ_1 = @. lazy(ᶜdivᵥ(ᶠwinterp(ᶜJ, ᶜρ) * ᶠuₕ³))
169169
tend_ρe_tot_1 = vertical_transport(ᶜρ, ᶠu³, ᶜh_tot, dt, Val(:none))
170-
ᶜuₕ₀ = (zero(eltype(ᶜuₕ)),)
170+
ᶜuₕ₀ = rayleigh_sponge_tendency_uₕ(ᶜuₕ, rayleigh_sponge)
171171

172-
return @. lazy(ᶜtendencies(-tend_ρ_1, - ᶜuₕ₀, tend_ρe_tot_1))
172+
return @. lazy(ᶜtendencies(-tend_ρ_1, ᶜuₕ₀, tend_ρe_tot_1))
173173
end
174174

175175
function ᶠimplicit_tendency_bc(ᶜY, ᶠY, p, t)

test/Spaces/unit_spaces.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,8 @@ end
330330

331331
@static if on_gpu
332332
adapted_space = Adapt.adapt(CUDA.KernelAdaptor(), c_space)
333-
@test ClimaComms.context(adapted_space) == DeviceSideContext()
334-
@test ClimaComms.device(adapted_space) == DeviceSideDevice()
333+
@test ClimaComms.context(adapted_space) == ClimaComms.context(c_space)
334+
@test ClimaComms.device(adapted_space) == ClimaComms.device(c_space)
335335

336336
adapted_hspace = Adapt.adapt(CUDA.KernelAdaptor(), hspace)
337337
@test ClimaComms.context(adapted_hspace) == DeviceSideContext()

0 commit comments

Comments
 (0)