Skip to content

Commit 2bce288

Browse files
Specialize field constructor for boolean fields (#2239)
1 parent 575f067 commit 2bce288

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ ClimaCore.jl Release Notes
44
main
55
-------
66

7+
- `Fields.Field(Bool, ::AbstractSpace)` is now supported. PR [2239](https://github.com/CliMA/ClimaCore.jl/pull/2239).
8+
79
- `SpectralElementSpace2D` constructors now support nodal masks. PR [2201](https://github.com/CliMA/ClimaCore.jl/pull/2201). See its documentation [here](https://clima.github.io/ClimaCore.jl/dev/masks). Note that it does not yet support restarts.
810

911
- Added support for InputOutput with PointSpaces

src/Fields/Fields.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ Field(values::V, space::S) where {V <: AbstractData, S <: AbstractSpace} =
5353
Field(::Type{T}, space::S) where {T, S <: AbstractSpace} =
5454
Field(similar(Spaces.coordinates_data(space), T), space)
5555

56+
function Field(::Type{Bool}, space::S) where {S <: AbstractSpace}
57+
FT = Spaces.undertype(space)
58+
data = similar(Spaces.coordinates_data(space), FT)
59+
bool_data = DataLayouts.replace_basetype(data, Bool)
60+
return Field(bool_data, space)
61+
end
62+
5663
local_geometry_type(::Field{V, S}) where {V, S} = local_geometry_type(S)
5764

5865
ClimaComms.context(field::Field) = ClimaComms.context(axes(field))

test/Fields/unit_field.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,6 +1118,34 @@ end
11181118
@. f += 1
11191119
end
11201120

1121+
@testset "Boolean fields" begin
1122+
FT = Float32
1123+
space = ExtrudedCubedSphereSpace(;
1124+
z_elem = 10,
1125+
z_min = 0,
1126+
z_max = 1,
1127+
radius = 10,
1128+
h_elem = 10,
1129+
n_quad_points = 4,
1130+
staggering = Grids.CellCenter(),
1131+
)
1132+
bf = Fields.Field(Bool, space)
1133+
@. bf = true
1134+
@test all(x -> x == true, Array(parent(bf)))
1135+
@. bf = 1
1136+
@test all(x -> x == true, Array(parent(bf)))
1137+
@. bf = 0
1138+
@test all(x -> x == false, Array(parent(bf)))
1139+
@. bf = 0 + bf # test copyto!(bf, ::Braodcasted)
1140+
@test all(x -> x == false, Array(parent(bf)))
1141+
if ClimaComms.device() isa ClimaComms.AbstractCPUDevice
1142+
@test_throws InexactError begin
1143+
@. bf = 2.0 # no error on gpu
1144+
end
1145+
end
1146+
bf_new = @. bf # test copy()
1147+
end
1148+
11211149
include("unit_field_multi_broadcast_fusion.jl")
11221150

11231151
nothing

0 commit comments

Comments
 (0)