Skip to content

Commit e020b17

Browse files
Updates to masked spaces (#2235)
1 parent b21bd9f commit e020b17

File tree

4 files changed

+30
-5
lines changed

4 files changed

+30
-5
lines changed

src/DataLayouts/DataLayouts.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2330,8 +2330,17 @@ function ColumnMask(
23302330
@assert Nq isa Integer
23312331
@assert Nh isa Integer
23322332
T = horizontal_layout_type
2333-
is_active = replace_basetype(T{FT, Nq}(Array{FT}, Nh), Bool)
2334-
Nijh = Nq * Nq * Nh
2333+
is_active = replace_basetype(T{FT, Nq}(DA{FT}, Nh), Bool)
2334+
parent(is_active) .= true
2335+
return IJHMask(is_active)
2336+
end
2337+
2338+
union_all_type(::Type{T}) where {T} = T.name.wrapper
2339+
2340+
function IJHMask(is_active::Union{IJFH, IJHF})
2341+
DA = union_all_type(typeof(parent(is_active)))
2342+
(Ni, Nj, _, _, Nh) = size(is_active)
2343+
Nijh = Ni * Nj * Nh
23352344
i_map = zeros(Int, Nijh)
23362345
j_map = zeros(Int, Nijh)
23372346
h_map = zeros(Int, Nijh)

src/Fields/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ function Base.copy(
9292
throw(BroadcastInferenceError(bc))
9393
end
9494
# We can trust it and defer to the simpler `copyto!`
95-
return copyto!(similar(bc, ElType), bc, DataLayouts.NoMask())
95+
return copyto!(similar(bc, ElType), bc, Spaces.get_mask(axes(bc)))
9696
end
9797

9898
Base.@propagate_inbounds function slab(

src/Spaces/spectralelement.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ function Base.show(io::IO, space::AbstractSpectralElementSpace)
2020
indent = get(io, :indent, 0)
2121
iio = IOContext(io, :indent => indent + 2)
2222
println(io, nameof(typeof(space)), ":")
23+
if get_mask(space) isa DataLayouts.NoMask
24+
println(iio, " "^(indent + 2), "mask_enabled: false")
25+
else
26+
println(iio, " "^(indent + 2), "mask_enabled: true")
27+
end
2328
if hasfield(typeof(grid(space)), :topology)
2429
# some reduced spaces (like slab space) do not have topology
2530
print(iio, " "^(indent + 2), "context: ")

test/Spaces/unit_spaces.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,15 @@ import ClimaCore:
2323
DeviceSideDevice
2424

2525
using ClimaCore.CommonSpaces
26+
using ClimaCore.Utilities.Cache
2627
import ClimaCore.DataLayouts: IJFH, VF, slab_index
2728

2829
on_gpu = ClimaComms.device() isa ClimaComms.CUDADevice
2930

3031
@testset "2D spaces with mask" begin
32+
# We need to test a fresh instance of the spaces, since
33+
# masked spaces include data set by users.
34+
Cache.clean_cache!()
3135
FT = Float64
3236
context = ClimaComms.context()
3337
x_max = FT(1)
@@ -57,6 +61,7 @@ on_gpu = ClimaComms.device() isa ClimaComms.CUDADevice
5761
hspace = Spaces.SpectralElementSpace2D(htopology, quad; enable_mask = true)
5862
mask = Spaces.get_mask(hspace)
5963
@test mask isa DataLayouts.IJHMask
64+
@test all(x -> x == true, parent(mask.is_active)) # test that default is true
6065
Spaces.set_mask!(hspace) do coords
6166
coords.x > 0.5
6267
end
@@ -71,6 +76,9 @@ on_gpu = ClimaComms.device() isa ClimaComms.CUDADevice
7176
@. f = 1 + ᶜx * 0 # tests copyto!
7277
@test count(iszero, parent(f)) == 2
7378

79+
fbc = @. 1 + ᶜx * 0 # tests copy
80+
@test Spaces.get_mask(axes(fbc)) isa DataLayouts.IJHMask
81+
7482
FT = Float64
7583
ᶜspace = ExtrudedCubedSphereSpace(
7684
FT;
@@ -92,8 +100,7 @@ on_gpu = ClimaComms.device() isa ClimaComms.CUDADevice
92100
# Test that mask-field assignment works:
93101
# TODO: we should make this easier
94102
is_active = similar(mask.is_active)
95-
_is_active = Fields.field_values(float.(hᶠcoords.lat .> 0.5))
96-
is_active .= DataLayouts.replace_basetype(_is_active, Bool)
103+
parent(is_active) .= parent(hᶠcoords.lat) .> 0.5
97104
Spaces.set_mask!(ᶜspace, is_active)
98105

99106
Spaces.set_mask!(ᶜspace) do coords
@@ -131,6 +138,8 @@ on_gpu = ClimaComms.device() isa ClimaComms.CUDADevice
131138
ᶠspace_no_mask = Spaces.face_space(ᶜspace_no_mask)
132139
ᶠcoords_no_mask = Fields.coordinate_field(ᶠspace_no_mask)
133140
c_no_mask = Fields.Field(FT, ᶜspace_no_mask)
141+
@test_throws ErrorException("Broacasted spaces are not the same.") @. c_no_mask +
142+
ᶜf
134143
ᶠf_no_mask = Fields.Field(FT, ᶠspace_no_mask)
135144
if ClimaComms.device(ᶜspace_no_mask) isa ClimaComms.CUDADevice
136145
@. c_no_mask = div(Geometry.WVector(foo(ᶠf_no_mask, ᶠcoords_no_mask)))
@@ -162,6 +171,7 @@ end
162171

163172
expected_repr = """
164173
SpectralElementSpace1D:
174+
mask_enabled: false
165175
context: SingletonCommsContext using $(nameof(typeof(device)))
166176
mesh: 1-element IntervalMesh of IntervalDomain: x ∈ [-3.0,5.0] (periodic)
167177
quadrature: 4-point Gauss-Legendre-Lobatto quadrature"""
@@ -348,6 +358,7 @@ end
348358
space = Spaces.SpectralElementSpace2D(grid_topology, quad)
349359
@test repr(space) == """
350360
SpectralElementSpace2D:
361+
mask_enabled: false
351362
context: SingletonCommsContext using CPUSingleThreaded
352363
mesh: 1×1-element RectilinearMesh of RectangleDomain: x ∈ [-3.0,5.0] (periodic) × y ∈ [-2.0,8.0] (:south, :north)
353364
quadrature: 4-point Gauss-Legendre-Lobatto quadrature"""

0 commit comments

Comments
 (0)