Skip to content

Commit b47dffb

Browse files
Fix fill with mask (#2285)
1 parent 009a497 commit b47dffb

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

src/Fields/Fields.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import ..Grids: ColumnIndex, local_geometry_type
2323
import ..Spaces: Spaces, AbstractSpace, AbstractPointSpace, cuda_synchronize
2424
import ..Spaces: nlevels, ncolumns
2525
import ..Spaces: get_mask, set_mask!
26+
import ..DataLayouts: AbstractMask
2627
import ..Geometry: Geometry, Cartesian12Vector
2728
import ..Utilities: PlusHalf, half
2829

@@ -285,20 +286,25 @@ Base.deepcopy_internal(field::Field, stackdict::IdDict) =
285286
function Base.copyto!(
286287
dest::Field{V, M},
287288
src::Field{V, M},
288-
mask = DataLayouts.NoMask,
289+
mask = DataLayouts.NoMask(),
289290
) where {V, M}
290291
@assert axes(dest) == axes(src)
291292
copyto!(field_values(dest), field_values(src), mask)
292293
return dest
293294
end
294295

295296
"""
296-
fill!(field::Field, value)
297+
fill!(field::Field, value, mask = get_mask(axes(field)))
297298
298-
Fill `field` with `value`.
299+
Fill `field` with `value`. The mask is extracted from the field's space,
300+
and `fill!` is only applied where the `mask` is true.
299301
"""
300-
function Base.fill!(field::Field, value)
301-
fill!(field_values(field), value)
302+
function Base.fill!(
303+
field::Field,
304+
value,
305+
mask::AbstractMask = get_mask(axes(field)),
306+
)
307+
fill!(field_values(field), value, mask)
302308
return field
303309
end
304310
"""

test/Spaces/unit_spaces.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#=
2-
julia --project
2+
julia --project=.buildkite
33
using Revise; include(joinpath("test", "Spaces", "unit_spaces.jl"))
44
=#
55
using Test
@@ -70,8 +70,9 @@ on_gpu = ClimaComms.device() isa ClimaComms.CUDADevice
7070

7171
f = Fields.Field(FT, hspace)
7272
fill!(parent(f), 0)
73-
@. f = 1 # tests fill!
73+
fill!(f, 1)
7474
@test count(iszero, parent(f)) == 2
75+
@test count(x -> x == 1, parent(f)) == 2
7576
ᶜx = Fields.coordinate_field(hspace).x
7677
@. f = 1 + ᶜx * 0 # tests copyto!
7778
@test count(iszero, parent(f)) == 2
@@ -109,7 +110,7 @@ on_gpu = ClimaComms.device() isa ClimaComms.CUDADevice
109110
@test count(parent(mask.is_active)) == 4640
110111
@test length(parent(mask.is_active)) == 9600
111112
ᶜf = zeros(ᶜspace)
112-
@. ᶜf = 1 # tests fill!
113+
fill!(ᶜf, 1)
113114
@test count(x -> x == 1, parent(ᶜf)) == 4640 * Spaces.nlevels(axes(ᶜf))
114115
@test length(parent(ᶜf)) == 9600 * Spaces.nlevels(axes(ᶜf))
115116
ᶜz = Fields.coordinate_field(ᶜspace).z

0 commit comments

Comments
 (0)