diff --git a/src/cache/precomputed_quantities.jl b/src/cache/precomputed_quantities.jl index f713f3fa34..798c254cf3 100644 --- a/src/cache/precomputed_quantities.jl +++ b/src/cache/precomputed_quantities.jl @@ -256,6 +256,7 @@ end # Interpolates the third contravariant component of Y.c.uₕ to cell faces. function compute_ᶠuₕ³(ᶜuₕ, ᶜρ) + assert_eltype(ᶜuₕ, Geometry.Covariant12Vector) ᶜJ = Fields.local_geometry_field(ᶜρ).J return @. lazy(ᶠwinterp(ᶜρ * ᶜJ, CT3(ᶜuₕ))) end @@ -281,10 +282,15 @@ function set_velocity_at_surface!(Y, ᶠuₕ³, turbconv_model) end function surface_velocity(ᶠu₃, ᶠuₕ³) + assert_eltype(ᶠu₃, Geometry.Covariant3Vector) + assert_eltype(ᶠuₕ³, Geometry.Contravariant3Vector) + ᶠlg = Fields.local_geometry_field(axes(ᶠu₃)) sfc_u₃ = Fields.level(ᶠu₃.components.data.:1, half) sfc_uₕ³ = Fields.level(ᶠuₕ³.components.data.:1, half) sfc_g³³ = g³³_field(sfc_u₃) - return @. lazy(-sfc_uₕ³ / sfc_g³³) # u³ = uₕ³ + w³ = uₕ³ + w₃ * g³³ + w₃ = @. lazy(-C3(sfc_uₕ³ / sfc_g³³, ᶠlg)) # u³ = uₕ³ + w³ = uₕ³ + w₃ * g³³ + assert_eltype(w₃, Geometry.Covariant3Vector) + return w₃ end """ @@ -313,6 +319,9 @@ end # This is used to set the grid-scale velocity quantities ᶜu, ᶠu³, ᶜK based on # ᶠu₃, and it is also used to set the SGS quantities based on ᶠu₃⁰ and ᶠu₃ʲ. function set_velocity_quantities!(ᶜu, ᶠu³, ᶜK, ᶠu₃, ᶜuₕ, ᶠuₕ³) + assert_eltype(ᶠu₃, Geometry.Covariant3Vector) + assert_eltype(ᶠuₕ³, Geometry.Contravariant3Vector) + assert_eltype(ᶠu³, Geometry.Contravariant3Vector) @. ᶜu = C123(ᶜuₕ) + ᶜinterp(C123(ᶠu₃)) @. ᶠu³ = ᶠuₕ³ + CT3(ᶠu₃) ᶜK .= compute_kinetic(ᶜuₕ, ᶠu₃) diff --git a/src/utils/utilities.jl b/src/utils/utilities.jl index dba3faf059..bc158afe1f 100644 --- a/src/utils/utilities.jl +++ b/src/utils/utilities.jl @@ -493,3 +493,26 @@ function issphere(space) return Meshes.domain(Spaces.topology(Spaces.horizontal_space(space))) isa Domains.SphereDomain end + +import ClimaCore.DataLayouts + +""" + assert_eltype(x, ::Type{T}) where {T} + +Assert that the eltype of `x` is of type `T` +""" +function assert_eltype end + +assert_eltype(bc::Base.AbstractBroadcasted, ::Type{T}) where {T} = + assert_eltype(eltype(bc), T) +assert_eltype(f::Fields.Field, ::Type{T}) where {T} = + assert_eltype(Fields.field_values(f), T) +assert_eltype(data::DataLayouts.AbstractData, ::Type{T}) where {T} = + assert_eltype(eltype(data), T) + +function assert_eltype(::Type{S}, ::Type{T}) where {S, T} + if !(S <: T) + error("Type $S should be a subtype of $T") + end + return nothing +end