diff --git a/docs/src/APIs/domains_api.md b/docs/src/APIs/domains_api.md index 3d90aab192..e01995ba56 100644 --- a/docs/src/APIs/domains_api.md +++ b/docs/src/APIs/domains_api.md @@ -17,4 +17,5 @@ Domains.SphereDomain ```@docs Domains.boundary_names +Domains.unique_boundary_names ``` diff --git a/ext/cuda/adapt.jl b/ext/cuda/adapt.jl index 14c0b645db..6f5357a340 100644 --- a/ext/cuda/adapt.jl +++ b/ext/cuda/adapt.jl @@ -42,7 +42,14 @@ Adapt.adapt_structure(to::CUDA.KernelAdaptor, space::Spaces.PointSpace) = Adapt.adapt_structure( to::CUDA.KernelAdaptor, topology::Topologies.IntervalTopology, -) = Topologies.DeviceIntervalTopology(topology.boundaries) +) = Topologies.IntervalTopology( + Adapt.adapt( + to, + ClimaComms.SingletonCommsContext(ClimaComms.device(topology.context)), + ), + Adapt.adapt(to, topology.mesh), + Adapt.adapt(to, topology.boundaries), +) Adapt.adapt_structure( to::CUDA.KernelAdaptor, diff --git a/src/Domains/Domains.jl b/src/Domains/Domains.jl index 4bd1c5e972..e54fea27b1 100644 --- a/src/Domains/Domains.jl +++ b/src/Domains/Domains.jl @@ -22,22 +22,27 @@ float_type(domain::AbstractDomain) = float_type(coordinate_type(domain)) """ boundary_names(obj::Union{AbstractDomain, AbstractMesh, AbstractTopology}) -A tuple or vector of unique boundary names of a spatial domain. +The boundary names passed to the IntervalDomain (a tuple, or `nothing`). """ function boundary_names end -struct IntervalDomain{CT, B} <: AbstractDomain where { - CT <: Geometry.Abstract1DPoint{FT}, - B <: BCTagType, -} where {FT} +""" + unique_boundary_names(obj::Union{AbstractDomain, AbstractMesh, AbstractTopology}) + +A tuple or vector of unique boundary names of a spatial domain. +""" +function unique_boundary_names end + +struct IntervalDomain{CT, B} <: + AbstractDomain where {CT <: Geometry.Abstract1DPoint{FT}, B} where {FT} coord_min::CT coord_max::CT - boundary_names::B end -isperiodic(domain::IntervalDomain) = isnothing(domain.boundary_names) -boundary_names(domain::IntervalDomain) = - isperiodic(domain) ? () : unique(domain.boundary_names) +isperiodic(::IntervalDomain{CT, B}) where {CT, B} = B == nothing +unique_boundary_names(domain::IntervalDomain{CT, B}) where {CT, B} = + isperiodic(domain) ? Symbol[] : unique(B) +boundary_names(::IntervalDomain{CT, B}) where {CT, B} = B """ IntervalDomain(coord⁻, coord⁺; periodic=true) @@ -60,7 +65,13 @@ function IntervalDomain( ), ) end - IntervalDomain(promote(coord_min, coord_max)..., boundary_names) + c = promote(coord_min, coord_max) + boundary_names = if isnothing(boundary_names) + boundary_names + else + Tuple(boundary_names) + end + IntervalDomain{eltype(c), boundary_names}(c...) end IntervalDomain(coords::IntervalSets.ClosedInterval; kwargs...) = IntervalDomain(coords.left, coords.right; kwargs...) @@ -95,7 +106,7 @@ function print_interval(io::IO, domain::IntervalDomain{CT}) where {CT} if isperiodic(domain) print(io, "(periodic)") else - print(io, domain.boundary_names) + print(io, boundary_names(domain)) end end function Base.show(io::IO, domain::IntervalDomain) @@ -111,10 +122,10 @@ end Base.:*(interval1::IntervalDomain, interval2::IntervalDomain) = RectangleDomain(interval1, interval2) -boundary_names(domain::RectangleDomain) = unique( +unique_boundary_names(domain::RectangleDomain) = unique( Symbol[ - boundary_names(domain.interval1)..., - boundary_names(domain.interval2)..., + unique_boundary_names(domain.interval1)..., + unique_boundary_names(domain.interval2)..., ], )::Vector{Symbol} @@ -171,6 +182,7 @@ Base.show(io::IO, domain::SphereDomain) = print(io, nameof(typeof(domain)), ": radius = ", domain.radius) boundary_names(::SphereDomain) = () +unique_boundary_names(::SphereDomain) = Symbol[] coordinate_type(::SphereDomain{FT}) where {FT} = Geometry.Cartesian123Point{FT} end # module diff --git a/src/Grids/finitedifference.jl b/src/Grids/finitedifference.jl index 33535ff99a..f06544bbd9 100644 --- a/src/Grids/finitedifference.jl +++ b/src/Grids/finitedifference.jl @@ -73,7 +73,7 @@ function _FiniteDifferenceGrid(topology::Topologies.IntervalTopology) ) return FiniteDifferenceGrid( - topology, + Adapt.adapt(ArrayType, topology), global_geometry, Adapt.adapt(ArrayType, center_local_geometry), Adapt.adapt(ArrayType, face_local_geometry), diff --git a/src/InputOutput/writers.jl b/src/InputOutput/writers.jl index 8bd10fefc2..f954acac63 100644 --- a/src/InputOutput/writers.jl +++ b/src/InputOutput/writers.jl @@ -166,10 +166,10 @@ function write_new!( ) write_attribute(group, "coord_min", Geometry.component(domain.coord_min, 1)) write_attribute(group, "coord_max", Geometry.component(domain.coord_max, 1)) - !isnothing(domain.boundary_names) && write_attribute( + !isnothing(Domains.boundary_names(domain)) && write_attribute( group, "boundary_names", - [String(bname) for bname in domain.boundary_names], + [String(bname) for bname in Domains.boundary_names(domain)], ) return name end @@ -209,10 +209,15 @@ function write_new!( write_attribute(group, "faces_type", "Range") else write_attribute(group, "faces_type", "Array") + faces = if ClimaComms.device(writer.context) isa ClimaComms.AbstractCPUDevice + mesh.faces + else + Array(mesh.faces) + end write_attribute( group, "faces", - [getfield(mesh.faces[i], 1) for i in 1:length(mesh.faces)], + [getfield(faces[i], 1) for i in 1:length(faces)], ) end (; stretch) = mesh diff --git a/src/Meshes/Meshes.jl b/src/Meshes/Meshes.jl index a42585742d..b68d0894d2 100644 --- a/src/Meshes/Meshes.jl +++ b/src/Meshes/Meshes.jl @@ -14,6 +14,7 @@ import ..Domains: RectangleDomain, SphereDomain, boundary_names, + unique_boundary_names, coordinate_type import ..Geometry import SparseArrays, CubedSphere, LinearAlgebra, StaticArrays diff --git a/src/Meshes/common.jl b/src/Meshes/common.jl index 57aef72fa5..3680448a99 100644 --- a/src/Meshes/common.jl +++ b/src/Meshes/common.jl @@ -1,4 +1,5 @@ boundary_names(mesh::AbstractMesh) = boundary_names(domain(mesh)) +unique_boundary_names(mesh::AbstractMesh) = unique_boundary_names(domain(mesh)) coordinate_type(mesh::AbstractMesh) = coordinate_type(domain(mesh)) """ diff --git a/src/Meshes/interval.jl b/src/Meshes/interval.jl index 2d8a0d87c6..f3bae82e6e 100644 --- a/src/Meshes/interval.jl +++ b/src/Meshes/interval.jl @@ -26,6 +26,9 @@ struct IntervalMesh{S, I <: IntervalDomain, V <: AbstractVector, M} <: meta::M end +using Adapt +Adapt.@adapt_structure IntervalMesh + # implies isequal Base.:(==)(mesh1::IntervalMesh, mesh2::IntervalMesh) = mesh1.domain == mesh2.domain && mesh1.faces == mesh2.faces @@ -85,9 +88,9 @@ end function boundary_face_name(mesh::IntervalMesh, elem::Integer, face) if !Domains.isperiodic(mesh.domain) if elem == 1 && face == 1 - return mesh.domain.boundary_names[1] + return Domains.unique_boundary_names(mesh.domain)[1] elseif elem == nelements(mesh) && face == 2 - return mesh.domain.boundary_names[2] + return Domains.unique_boundary_names(mesh.domain)[2] end end return nothing @@ -468,7 +471,7 @@ function truncate_mesh( new_domain = IntervalDomain( z_bottom, Geometry.ZPoint{FT}(z_top), - boundary_names = trunc_domain.boundary_names, + boundary_names = Domains.boundary_names(trunc_domain), ) return IntervalMesh(new_domain, new_stretch; nelems = new_nelems) end diff --git a/src/Remapping/Remapping.jl b/src/Remapping/Remapping.jl index 801d6da803..568ed28ed0 100644 --- a/src/Remapping/Remapping.jl +++ b/src/Remapping/Remapping.jl @@ -16,6 +16,7 @@ import ..DataLayouts, ..Hypsography import ClimaCore.Utilities: half import ClimaCore.Spaces: cuda_synchronize +import Adapt using ..RecursiveApply diff --git a/src/Remapping/distributed_remapping.jl b/src/Remapping/distributed_remapping.jl index cbf4dcf9ff..a84af117e1 100644 --- a/src/Remapping/distributed_remapping.jl +++ b/src/Remapping/distributed_remapping.jl @@ -399,10 +399,16 @@ function _Remapper( ) num_dims = num_hdims else + device = ClimaComms.device(space) + cpu_space = if device isa ClimaComms.AbstractCPUDevice + space + else + Adapt.adapt(Array, space) + end vert_interpolation_weights = - ArrayType(vertical_interpolation_weights(space, target_zcoords)) + ArrayType(vertical_interpolation_weights(cpu_space, target_zcoords)) vert_bounding_indices = - ArrayType(vertical_bounding_indices(space, target_zcoords)) + ArrayType(vertical_bounding_indices(cpu_space, target_zcoords)) # We have to add one extra dimension with respect to the bitmask/local_horiz_indices # because we are going to store the values for the columns @@ -463,10 +469,21 @@ function _Remapper( FT = Spaces.undertype(space) ArrayType = ClimaComms.array_type(space) + cpu_space = if ClimaComms.device(space) isa ClimaComms.AbstractCPUDevice + space + else + device = ClimaComms.device(space) + cpu_space = if device isa ClimaComms.AbstractCPUDevice + space + else + Adapt.adapt(Array, space) + end + end + vert_interpolation_weights = - ArrayType(vertical_interpolation_weights(space, target_zcoords)) + ArrayType(vertical_interpolation_weights(cpu_space, target_zcoords)) vert_bounding_indices = - ArrayType(vertical_bounding_indices(space, target_zcoords)) + ArrayType(vertical_bounding_indices(cpu_space, target_zcoords)) local_interpolated_values = ArrayType(zeros(FT, (length(target_zcoords), buffer_length))) diff --git a/src/Topologies/Topologies.jl b/src/Topologies/Topologies.jl index 06c12d4053..97e62a3869 100644 --- a/src/Topologies/Topologies.jl +++ b/src/Topologies/Topologies.jl @@ -5,7 +5,7 @@ import ClimaComms, Adapt import ..ClimaCore import ..Utilities: Cache, cart_ind, linear_ind import ..Geometry -import ..Domains: Domains, coordinate_type +import ..Domains: Domains, coordinate_type, unique_boundary_names import ..Meshes: Meshes, domain, coordinates import ..DataLayouts import ..DataLayouts: slab_index diff --git a/src/Topologies/interval.jl b/src/Topologies/interval.jl index 72c584d70a..5b76a5b7e3 100644 --- a/src/Topologies/interval.jl +++ b/src/Topologies/interval.jl @@ -18,15 +18,10 @@ end Adapt.@adapt_structure IntervalTopology -## gpu -struct DeviceIntervalTopology{B} <: AbstractIntervalTopology - boundaries::B -end - -ClimaComms.context(topology::DeviceIntervalTopology) = DeviceSideContext() -ClimaComms.device(topology::DeviceIntervalTopology) = DeviceSideDevice() +ClimaComms.context(topology::IntervalTopology) = topology.context +ClimaComms.device(topology::IntervalTopology) = + ClimaComms.device(topology.context) -ClimaComms.device(topology::IntervalTopology) = topology.context.device ClimaComms.array_type(topology::IntervalTopology) = ClimaComms.array_type(topology.context.device) @@ -46,12 +41,16 @@ function _IntervalTopology( ) # currently only support SingletonCommsContext @assert context isa ClimaComms.SingletonCommsContext - if Domains.isperiodic(mesh.domain) + domain = mesh.domain + if Domains.isperiodic(domain) boundaries = NamedTuple() - elseif mesh.domain.boundary_names[1] == mesh.domain.boundary_names[2] - boundaries = NamedTuple{(mesh.domain.boundary_names[1],)}(1) else - boundaries = NamedTuple{mesh.domain.boundary_names}((1, 2)) + bn = Domains.boundary_names(domain) + boundaries = if bn[1] == bn[2] + NamedTuple{(bn[1],)}(1) + else + NamedTuple{bn}((1, 2)) + end end IntervalTopology(context, mesh, boundaries) end diff --git a/src/Topologies/topology2d.jl b/src/Topologies/topology2d.jl index 816e798b2e..7967e4f70b 100644 --- a/src/Topologies/topology2d.jl +++ b/src/Topologies/topology2d.jl @@ -462,7 +462,7 @@ function _Topology2D( # 5) faces boundaries = NamedTuple( boundary_name => Tuple{Int, Int}[] for - boundary_name in Meshes.boundary_names(mesh) + boundary_name in Meshes.unique_boundary_names(mesh) ) interior_faces = Tuple{Int, Int, Int, Int, Bool}[] ghost_faces = Tuple{Int, Int, Int, Int, Bool}[] @@ -772,12 +772,13 @@ end neighbors(topology::Topology2D) = topology.neighbor_pids -boundary_names(topology::Topology2D) = keys(topology.boundaries) -boundary_tags(topology::Topology2D) = NamedTuple{boundary_names(topology)}( - ntuple(i -> i, length(topology.boundaries)), -) +unique_boundary_names(topology::Topology2D) = keys(topology.boundaries) +boundary_tags(topology::Topology2D) = + NamedTuple{unique_boundary_names(topology)}( + ntuple(i -> i, length(topology.boundaries)), + ) boundary_tag(topology::Topology2D, boundary_name::Symbol) = - findfirst(==(boundary_name), boundary_names(topology)) + findfirst(==(boundary_name), unique_boundary_names(topology)) boundary_faces(topology::Topology2D, boundary) = topology.boundaries[boundary] diff --git a/test/Operators/finitedifference/unit_columnwise.jl b/test/Operators/finitedifference/unit_columnwise.jl index 4251dd6fc4..ce0d4a3ad1 100644 --- a/test/Operators/finitedifference/unit_columnwise.jl +++ b/test/Operators/finitedifference/unit_columnwise.jl @@ -30,13 +30,11 @@ import ClimaCore.MatrixFields import ClimaCore.Spaces import ClimaCore.Fields -Operators.fd_shmem_is_supported(bc::Base.Broadcast.Broadcasted) = false -ClimaCore.Operators.use_fd_shmem() = false +# Operators.fd_shmem_is_supported(bc::Base.Broadcast.Broadcasted) = false +# ClimaCore.Operators.use_fd_shmem() = false # The existing implementation limits our ability to apply # the same expressions from within kernels -ClimaComms.device(topology::Topologies.DeviceIntervalTopology) = - ClimaComms.CUDADevice() -Fields.error_mismatched_spaces(::Type, ::Type) = nothing # causes unsupported dynamic function invocation +# Fields.error_mismatched_spaces(::Type, ::Type) = nothing # causes unsupported dynamic function invocation const C1 = Geometry.Covariant1Vector const C2 = Geometry.Covariant2Vector @@ -91,12 +89,14 @@ Base.Broadcast.broadcastable(x::RayleighSponge) = tuple(x) function rayleigh_sponge_tendency_uₕ(ᶜuₕ, s) s isa Nothing && return NullBroadcasted() - (; ᶜz, ᶠz) = z_coordinate_fields(axes(ᶜuₕ)) - zmax = z_max(axes(ᶠz)) + ᶜz = Fields.coordinate_field(axes(ᶜuₕ)).z + ᶠz = Fields.coordinate_field(Spaces.face_space(axes(ᶜuₕ))).z + # (; ᶜz, ᶠz) = z_coordinate_fields(axes(ᶜuₕ)) + zmax = Spaces.z_max(axes(ᶠz)) return @. lazy(-β_rayleigh_uₕ(s, ᶜz, zmax) * ᶜuₕ) end -function compute_kinetic(uₕ::Fields.Field, uᵥ::Fields.Field) +function compute_kinetic(uₕ, uᵥ) @assert eltype(uₕ) <: Union{C1, C2, C12} @assert eltype(uᵥ) <: C3 FT = Spaces.undertype(axes(uₕ)) @@ -167,9 +167,9 @@ function ᶜimplicit_tendency_bc(ᶜY, ᶠY, p, t) ᶠu³ = @. lazy(ᶠuₕ³ + CT3(ᶠu₃)) tend_ρ_1 = @. lazy(ᶜdivᵥ(ᶠwinterp(ᶜJ, ᶜρ) * ᶠuₕ³)) tend_ρe_tot_1 = vertical_transport(ᶜρ, ᶠu³, ᶜh_tot, dt, Val(:none)) - ᶜuₕ₀ = (zero(eltype(ᶜuₕ)),) + ᶜuₕ₀ = rayleigh_sponge_tendency_uₕ(ᶜuₕ, rayleigh_sponge) - return @. lazy(ᶜtendencies(-tend_ρ_1, - ᶜuₕ₀, tend_ρe_tot_1)) + return @. lazy(ᶜtendencies(-tend_ρ_1, ᶜuₕ₀, tend_ρe_tot_1)) end function ᶠimplicit_tendency_bc(ᶜY, ᶠY, p, t) diff --git a/test/Spaces/unit_spaces.jl b/test/Spaces/unit_spaces.jl index 7f79b6e6bf..c2a2f73386 100644 --- a/test/Spaces/unit_spaces.jl +++ b/test/Spaces/unit_spaces.jl @@ -330,8 +330,8 @@ end @static if on_gpu adapted_space = Adapt.adapt(CUDA.KernelAdaptor(), c_space) - @test ClimaComms.context(adapted_space) == DeviceSideContext() - @test ClimaComms.device(adapted_space) == DeviceSideDevice() + @test ClimaComms.context(adapted_space) == ClimaComms.context(c_space) + @test ClimaComms.device(adapted_space) == ClimaComms.device(c_space) adapted_hspace = Adapt.adapt(CUDA.KernelAdaptor(), hspace) @test ClimaComms.context(adapted_hspace) == DeviceSideContext()