Skip to content

Remove DeviceIntervalTopology #2343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/APIs/domains_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ Domains.SphereDomain

```@docs
Domains.boundary_names
Domains.unique_boundary_names
```
9 changes: 8 additions & 1 deletion ext/cuda/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,14 @@ Adapt.adapt_structure(to::CUDA.KernelAdaptor, space::Spaces.PointSpace) =
Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
topology::Topologies.IntervalTopology,
) = Topologies.DeviceIntervalTopology(topology.boundaries)
) = Topologies.IntervalTopology(
Adapt.adapt(
to,
ClimaComms.SingletonCommsContext(ClimaComms.device(topology.context)),
),
Adapt.adapt(to, topology.mesh),
Adapt.adapt(to, topology.boundaries),
)

Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
Expand Down
40 changes: 26 additions & 14 deletions src/Domains/Domains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,27 @@ float_type(domain::AbstractDomain) = float_type(coordinate_type(domain))
"""
boundary_names(obj::Union{AbstractDomain, AbstractMesh, AbstractTopology})

A tuple or vector of unique boundary names of a spatial domain.
The boundary names passed to the IntervalDomain (a tuple, or `nothing`).
"""
function boundary_names end

struct IntervalDomain{CT, B} <: AbstractDomain where {
CT <: Geometry.Abstract1DPoint{FT},
B <: BCTagType,
} where {FT}
"""
unique_boundary_names(obj::Union{AbstractDomain, AbstractMesh, AbstractTopology})

A tuple or vector of unique boundary names of a spatial domain.
"""
function unique_boundary_names end

struct IntervalDomain{CT, B} <:
AbstractDomain where {CT <: Geometry.Abstract1DPoint{FT}, B} where {FT}
coord_min::CT
coord_max::CT
boundary_names::B
end

isperiodic(domain::IntervalDomain) = isnothing(domain.boundary_names)
boundary_names(domain::IntervalDomain) =
isperiodic(domain) ? () : unique(domain.boundary_names)
isperiodic(::IntervalDomain{CT, B}) where {CT, B} = B == nothing
unique_boundary_names(domain::IntervalDomain{CT, B}) where {CT, B} =
isperiodic(domain) ? Symbol[] : unique(B)
boundary_names(::IntervalDomain{CT, B}) where {CT, B} = B

"""
IntervalDomain(coord⁻, coord⁺; periodic=true)
Expand All @@ -60,7 +65,13 @@ function IntervalDomain(
),
)
end
IntervalDomain(promote(coord_min, coord_max)..., boundary_names)
c = promote(coord_min, coord_max)
boundary_names = if isnothing(boundary_names)
boundary_names
else
Tuple(boundary_names)
end
IntervalDomain{eltype(c), boundary_names}(c...)
end
IntervalDomain(coords::IntervalSets.ClosedInterval; kwargs...) =
IntervalDomain(coords.left, coords.right; kwargs...)
Expand Down Expand Up @@ -95,7 +106,7 @@ function print_interval(io::IO, domain::IntervalDomain{CT}) where {CT}
if isperiodic(domain)
print(io, "(periodic)")
else
print(io, domain.boundary_names)
print(io, boundary_names(domain))
end
end
function Base.show(io::IO, domain::IntervalDomain)
Expand All @@ -111,10 +122,10 @@ end
Base.:*(interval1::IntervalDomain, interval2::IntervalDomain) =
RectangleDomain(interval1, interval2)

boundary_names(domain::RectangleDomain) = unique(
unique_boundary_names(domain::RectangleDomain) = unique(
Symbol[
boundary_names(domain.interval1)...,
boundary_names(domain.interval2)...,
unique_boundary_names(domain.interval1)...,
unique_boundary_names(domain.interval2)...,
],
)::Vector{Symbol}

Expand Down Expand Up @@ -171,6 +182,7 @@ Base.show(io::IO, domain::SphereDomain) =
print(io, nameof(typeof(domain)), ": radius = ", domain.radius)

boundary_names(::SphereDomain) = ()
unique_boundary_names(::SphereDomain) = Symbol[]
coordinate_type(::SphereDomain{FT}) where {FT} = Geometry.Cartesian123Point{FT}

end # module
2 changes: 1 addition & 1 deletion src/Grids/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ function _FiniteDifferenceGrid(topology::Topologies.IntervalTopology)
)

return FiniteDifferenceGrid(
topology,
Adapt.adapt(ArrayType, topology),
global_geometry,
Adapt.adapt(ArrayType, center_local_geometry),
Adapt.adapt(ArrayType, face_local_geometry),
Expand Down
11 changes: 8 additions & 3 deletions src/InputOutput/writers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,10 @@ function write_new!(
)
write_attribute(group, "coord_min", Geometry.component(domain.coord_min, 1))
write_attribute(group, "coord_max", Geometry.component(domain.coord_max, 1))
!isnothing(domain.boundary_names) && write_attribute(
!isnothing(Domains.boundary_names(domain)) && write_attribute(
group,
"boundary_names",
[String(bname) for bname in domain.boundary_names],
[String(bname) for bname in Domains.boundary_names(domain)],
)
return name
end
Expand Down Expand Up @@ -209,10 +209,15 @@ function write_new!(
write_attribute(group, "faces_type", "Range")
else
write_attribute(group, "faces_type", "Array")
faces = if ClimaComms.device(writer.context) isa ClimaComms.AbstractCPUDevice
mesh.faces
else
Array(mesh.faces)
end
write_attribute(
group,
"faces",
[getfield(mesh.faces[i], 1) for i in 1:length(mesh.faces)],
[getfield(faces[i], 1) for i in 1:length(faces)],
)
end
(; stretch) = mesh
Expand Down
1 change: 1 addition & 0 deletions src/Meshes/Meshes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import ..Domains:
RectangleDomain,
SphereDomain,
boundary_names,
unique_boundary_names,
coordinate_type
import ..Geometry
import SparseArrays, CubedSphere, LinearAlgebra, StaticArrays
Expand Down
1 change: 1 addition & 0 deletions src/Meshes/common.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
boundary_names(mesh::AbstractMesh) = boundary_names(domain(mesh))
unique_boundary_names(mesh::AbstractMesh) = unique_boundary_names(domain(mesh))
coordinate_type(mesh::AbstractMesh) = coordinate_type(domain(mesh))

"""
Expand Down
9 changes: 6 additions & 3 deletions src/Meshes/interval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ struct IntervalMesh{S, I <: IntervalDomain, V <: AbstractVector, M} <:
meta::M
end

using Adapt
Adapt.@adapt_structure IntervalMesh

# implies isequal
Base.:(==)(mesh1::IntervalMesh, mesh2::IntervalMesh) =
mesh1.domain == mesh2.domain && mesh1.faces == mesh2.faces
Expand Down Expand Up @@ -85,9 +88,9 @@ end
function boundary_face_name(mesh::IntervalMesh, elem::Integer, face)
if !Domains.isperiodic(mesh.domain)
if elem == 1 && face == 1
return mesh.domain.boundary_names[1]
return Domains.unique_boundary_names(mesh.domain)[1]
elseif elem == nelements(mesh) && face == 2
return mesh.domain.boundary_names[2]
return Domains.unique_boundary_names(mesh.domain)[2]
end
end
return nothing
Expand Down Expand Up @@ -468,7 +471,7 @@ function truncate_mesh(
new_domain = IntervalDomain(
z_bottom,
Geometry.ZPoint{FT}(z_top),
boundary_names = trunc_domain.boundary_names,
boundary_names = Domains.boundary_names(trunc_domain),
)
return IntervalMesh(new_domain, new_stretch; nelems = new_nelems)
end
1 change: 1 addition & 0 deletions src/Remapping/Remapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import ..DataLayouts,
..Hypsography
import ClimaCore.Utilities: half
import ClimaCore.Spaces: cuda_synchronize
import Adapt

using ..RecursiveApply

Expand Down
25 changes: 21 additions & 4 deletions src/Remapping/distributed_remapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,10 +399,16 @@ function _Remapper(
)
num_dims = num_hdims
else
device = ClimaComms.device(space)
cpu_space = if device isa ClimaComms.AbstractCPUDevice
space
else
Adapt.adapt(Array, space)
end
vert_interpolation_weights =
ArrayType(vertical_interpolation_weights(space, target_zcoords))
ArrayType(vertical_interpolation_weights(cpu_space, target_zcoords))
vert_bounding_indices =
ArrayType(vertical_bounding_indices(space, target_zcoords))
ArrayType(vertical_bounding_indices(cpu_space, target_zcoords))

# We have to add one extra dimension with respect to the bitmask/local_horiz_indices
# because we are going to store the values for the columns
Expand Down Expand Up @@ -463,10 +469,21 @@ function _Remapper(
FT = Spaces.undertype(space)
ArrayType = ClimaComms.array_type(space)

cpu_space = if ClimaComms.device(space) isa ClimaComms.AbstractCPUDevice
space
else
device = ClimaComms.device(space)
cpu_space = if device isa ClimaComms.AbstractCPUDevice
space
else
Adapt.adapt(Array, space)
end
end

vert_interpolation_weights =
ArrayType(vertical_interpolation_weights(space, target_zcoords))
ArrayType(vertical_interpolation_weights(cpu_space, target_zcoords))
vert_bounding_indices =
ArrayType(vertical_bounding_indices(space, target_zcoords))
ArrayType(vertical_bounding_indices(cpu_space, target_zcoords))

local_interpolated_values =
ArrayType(zeros(FT, (length(target_zcoords), buffer_length)))
Expand Down
2 changes: 1 addition & 1 deletion src/Topologies/Topologies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import ClimaComms, Adapt
import ..ClimaCore
import ..Utilities: Cache, cart_ind, linear_ind
import ..Geometry
import ..Domains: Domains, coordinate_type
import ..Domains: Domains, coordinate_type, unique_boundary_names
import ..Meshes: Meshes, domain, coordinates
import ..DataLayouts
import ..DataLayouts: slab_index
Expand Down
23 changes: 11 additions & 12 deletions src/Topologies/interval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,10 @@ end

Adapt.@adapt_structure IntervalTopology

## gpu
struct DeviceIntervalTopology{B} <: AbstractIntervalTopology
boundaries::B
end

ClimaComms.context(topology::DeviceIntervalTopology) = DeviceSideContext()
ClimaComms.device(topology::DeviceIntervalTopology) = DeviceSideDevice()
ClimaComms.context(topology::IntervalTopology) = topology.context
ClimaComms.device(topology::IntervalTopology) =
ClimaComms.device(topology.context)

ClimaComms.device(topology::IntervalTopology) = topology.context.device
ClimaComms.array_type(topology::IntervalTopology) =
ClimaComms.array_type(topology.context.device)

Expand All @@ -46,12 +41,16 @@ function _IntervalTopology(
)
# currently only support SingletonCommsContext
@assert context isa ClimaComms.SingletonCommsContext
if Domains.isperiodic(mesh.domain)
domain = mesh.domain
if Domains.isperiodic(domain)
boundaries = NamedTuple()
elseif mesh.domain.boundary_names[1] == mesh.domain.boundary_names[2]
boundaries = NamedTuple{(mesh.domain.boundary_names[1],)}(1)
else
boundaries = NamedTuple{mesh.domain.boundary_names}((1, 2))
bn = Domains.boundary_names(domain)
boundaries = if bn[1] == bn[2]
NamedTuple{(bn[1],)}(1)
else
NamedTuple{bn}((1, 2))
end
end
IntervalTopology(context, mesh, boundaries)
end
Expand Down
13 changes: 7 additions & 6 deletions src/Topologies/topology2d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ function _Topology2D(
# 5) faces
boundaries = NamedTuple(
boundary_name => Tuple{Int, Int}[] for
boundary_name in Meshes.boundary_names(mesh)
boundary_name in Meshes.unique_boundary_names(mesh)
)
interior_faces = Tuple{Int, Int, Int, Int, Bool}[]
ghost_faces = Tuple{Int, Int, Int, Int, Bool}[]
Expand Down Expand Up @@ -772,12 +772,13 @@ end


neighbors(topology::Topology2D) = topology.neighbor_pids
boundary_names(topology::Topology2D) = keys(topology.boundaries)
boundary_tags(topology::Topology2D) = NamedTuple{boundary_names(topology)}(
ntuple(i -> i, length(topology.boundaries)),
)
unique_boundary_names(topology::Topology2D) = keys(topology.boundaries)
boundary_tags(topology::Topology2D) =
NamedTuple{unique_boundary_names(topology)}(
ntuple(i -> i, length(topology.boundaries)),
)
boundary_tag(topology::Topology2D, boundary_name::Symbol) =
findfirst(==(boundary_name), boundary_names(topology))
findfirst(==(boundary_name), unique_boundary_names(topology))

boundary_faces(topology::Topology2D, boundary) = topology.boundaries[boundary]

Expand Down
20 changes: 10 additions & 10 deletions test/Operators/finitedifference/unit_columnwise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,11 @@ import ClimaCore.MatrixFields
import ClimaCore.Spaces
import ClimaCore.Fields

Operators.fd_shmem_is_supported(bc::Base.Broadcast.Broadcasted) = false
ClimaCore.Operators.use_fd_shmem() = false
# Operators.fd_shmem_is_supported(bc::Base.Broadcast.Broadcasted) = false
# ClimaCore.Operators.use_fd_shmem() = false
# The existing implementation limits our ability to apply
# the same expressions from within kernels
ClimaComms.device(topology::Topologies.DeviceIntervalTopology) =
ClimaComms.CUDADevice()
Fields.error_mismatched_spaces(::Type, ::Type) = nothing # causes unsupported dynamic function invocation
# Fields.error_mismatched_spaces(::Type, ::Type) = nothing # causes unsupported dynamic function invocation

const C1 = Geometry.Covariant1Vector
const C2 = Geometry.Covariant2Vector
Expand Down Expand Up @@ -91,12 +89,14 @@ Base.Broadcast.broadcastable(x::RayleighSponge) = tuple(x)

function rayleigh_sponge_tendency_uₕ(ᶜuₕ, s)
s isa Nothing && return NullBroadcasted()
(; ᶜz, ᶠz) = z_coordinate_fields(axes(ᶜuₕ))
zmax = z_max(axes(ᶠz))
ᶜz = Fields.coordinate_field(axes(ᶜuₕ)).z
ᶠz = Fields.coordinate_field(Spaces.face_space(axes(ᶜuₕ))).z
# (; ᶜz, ᶠz) = z_coordinate_fields(axes(ᶜuₕ))
zmax = Spaces.z_max(axes(ᶠz))
return @. lazy(-β_rayleigh_uₕ(s, ᶜz, zmax) * ᶜuₕ)
end

function compute_kinetic(uₕ::Fields.Field, uᵥ::Fields.Field)
function compute_kinetic(uₕ, uᵥ)
@assert eltype(uₕ) <: Union{C1, C2, C12}
@assert eltype(uᵥ) <: C3
FT = Spaces.undertype(axes(uₕ))
Expand Down Expand Up @@ -167,9 +167,9 @@ function ᶜimplicit_tendency_bc(ᶜY, ᶠY, p, t)
ᶠu³ = @. lazy(ᶠuₕ³ + CT3(ᶠu₃))
tend_ρ_1 = @. lazy(ᶜdivᵥ(ᶠwinterp(ᶜJ, ᶜρ) * ᶠuₕ³))
tend_ρe_tot_1 = vertical_transport(ᶜρ, ᶠu³, ᶜh_tot, dt, Val(:none))
ᶜuₕ₀ = (zero(eltype(ᶜuₕ)),)
ᶜuₕ₀ = rayleigh_sponge_tendency_uₕ(ᶜuₕ, rayleigh_sponge)

return @. lazy(ᶜtendencies(-tend_ρ_1, - ᶜuₕ₀, tend_ρe_tot_1))
return @. lazy(ᶜtendencies(-tend_ρ_1, ᶜuₕ₀, tend_ρe_tot_1))
end

function ᶠimplicit_tendency_bc(ᶜY, ᶠY, p, t)
Expand Down
4 changes: 2 additions & 2 deletions test/Spaces/unit_spaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,8 @@ end

@static if on_gpu
adapted_space = Adapt.adapt(CUDA.KernelAdaptor(), c_space)
@test ClimaComms.context(adapted_space) == DeviceSideContext()
@test ClimaComms.device(adapted_space) == DeviceSideDevice()
@test ClimaComms.context(adapted_space) == ClimaComms.context(c_space)
@test ClimaComms.device(adapted_space) == ClimaComms.device(c_space)

adapted_hspace = Adapt.adapt(CUDA.KernelAdaptor(), hspace)
@test ClimaComms.context(adapted_hspace) == DeviceSideContext()
Expand Down
Loading