Skip to content

Commit 8f033c7

Browse files
Make dss and weighted_dss no-ops for empty fields.
1 parent 0167e6f commit 8f033c7

File tree

2 files changed

+42
-13
lines changed

2 files changed

+42
-13
lines changed

src/Spaces/dss.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ function create_dss_buffer(
6363
convert_to_array = DA isa Array ? false : true
6464
(_, _, _, Nv, nelems) = Base.size(data)
6565
Np = Spaces.nperimeter(perimeter)
66-
Nf = cld(length(parent(data)), (Nij * Nij * Nv * nelems))
66+
Nf =
67+
length(parent(data)) == 0 ? 0 :
68+
cld(length(parent(data)), (Nij * Nij * Nv * nelems))
6769
nfacedof = Nij - 2
6870
T = eltype(parent(data))
6971
TS = _transformed_type(data, local_geometry, local_weights, DA) # extract transformed type
@@ -269,6 +271,7 @@ function weighted_dss_start!(
269271
hspace::SpectralElementSpace2D{<:Topology2D},
270272
dss_buffer::DSSBuffer,
271273
)
274+
length(parent(data)) == 0 && return nothing
272275
device = ClimaComms.device(hspace.topology)
273276
dss_transform!(
274277
device,
@@ -335,6 +338,7 @@ function weighted_dss_internal!(
335338
hspace::AbstractSpectralElementSpace,
336339
dss_buffer::Union{DSSBuffer, Nothing},
337340
)
341+
length(parent(data)) == 0 && return nothing
338342
if hspace isa SpectralElementSpace1D
339343
dss_1d!(
340344
hspace.topology,
@@ -413,6 +417,7 @@ function weighted_dss_ghost!(
413417
hspace::SpectralElementSpace2D{<:Topology2D},
414418
dss_buffer::DSSBuffer,
415419
)
420+
length(parent(data)) == 0 && return data
416421
device = ClimaComms.device(hspace.topology)
417422
ClimaComms.finish(dss_buffer.graph_context)
418423
load_from_recv_buffer!(device, dss_buffer)
@@ -1002,6 +1007,7 @@ end
10021007
Computed unweighted/pure DSS of `data`.
10031008
"""
10041009
function dss!(data, topology, quadrature_style)
1010+
length(parent(data)) == 0 && return nothing
10051011
device = ClimaComms.device(topology)
10061012
perimeter = Perimeter2D(Quadratures.degrees_of_freedom(quadrature_style))
10071013
# create dss buffer

test/Spaces/ddss1.jl

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@ using Logging
22
using Test
33

44
import ClimaCore:
5-
Domains, Fields, Geometry, Meshes, Operators, Spaces, Topologies
5+
Domains,
6+
Fields,
7+
Geometry,
8+
Meshes,
9+
Operators,
10+
Spaces,
11+
Topologies,
12+
DataLayouts
613

714
using ClimaComms
815
const device = ClimaComms.device()
@@ -39,6 +46,9 @@ function distributed_space(
3946
return (space, context)
4047
end
4148

49+
init_state_scalar(local_geometry, p) = (; ρ = 1.0)
50+
init_state_vector(local_geometry, p) = Geometry.Covariant12Vector(1.0, -1.0)
51+
4252
#=
4353
_
4454
|1|
@@ -61,41 +71,54 @@ end
6171
@test Topologies.local_neighboring_elements(space.topology, 3) == [2, 4]
6272
@test Topologies.local_neighboring_elements(space.topology, 4) == [1, 3]
6373

64-
init_state(local_geometry, p) == 1.0)
65-
y0 = init_state.(Fields.local_geometry_field(space), Ref(nothing))
74+
y0 = init_state_scalar.(Fields.local_geometry_field(space), Ref(nothing))
6675
nel = Topologies.nlocalelems(Spaces.topology(space))
6776
yarr = parent(y0)
6877
yarr .= reshape(1:(Nq * Nq * nel), (Nq, Nq, 1, nel))
6978

70-
dss2_buffer = Spaces.create_dss_buffer(y0)
71-
Spaces.weighted_dss!(y0, dss2_buffer) # DSS2
79+
dss_buffer = Spaces.create_dss_buffer(y0)
80+
Spaces.weighted_dss!(y0, dss_buffer) # DSS2
7281
#! format: off
7382
@test Array(yarr[:]) == [18.5, 5.0, 9.5, 18.5, 5.0, 9.5, 18.5, 5.0, 9.5, 9.5,
7483
14.0, 18.5, 9.5, 14.0, 18.5, 9.5, 14.0, 18.5, 18.5,
7584
23.0, 27.5, 18.5, 23.0, 27.5, 18.5, 23.0, 27.5, 27.5,
7685
32.0, 18.5, 27.5, 32.0, 18.5, 27.5, 32.0, 18.5]
7786
#! format: on
7887

79-
p = @allocated Spaces.weighted_dss!(y0, dss2_buffer)
88+
p = @allocated Spaces.weighted_dss!(y0, dss_buffer)
8089
@show p
8190
#=
8291
@test p == 0
8392
=#
8493
end
8594

86-
@testset "4x1 element mesh on 2 processes - vector field" begin
95+
@testset "test if dss is no-op on an empty field" begin
96+
Nq = 3
97+
space, comms_ctx = distributed_space((4, 1), (true, true), (Nq, 1, 1))
98+
y0 = init_state_scalar.(Fields.local_geometry_field(space), Ref(nothing))
99+
100+
dims = (Nq, Nq, 0, 4)
101+
array = similar(parent(y0), dims)
102+
data = DataLayouts.rebuild(Fields.field_values(y0), array)
103+
space = axes(y0)
104+
empty_field = similar(y0, Tuple{})
105+
dss_buffer = Spaces.create_dss_buffer(empty_field)
106+
@test empty_field == Spaces.weighted_dss!(empty_field)
107+
end
108+
109+
110+
@testset "4x1 element mesh on 1 process - vector field" begin
87111
Nq = 3
88112
space, comms_ctx = distributed_space((4, 1), (true, true), (Nq, 1, 2))
89-
init_state(local_geometry, p) = Geometry.Covariant12Vector(1.0, -1.0)
90-
y0 = init_state.(Fields.local_geometry_field(space), Ref(nothing))
113+
y0 = init_state_vector.(Fields.local_geometry_field(space), Ref(nothing))
91114
yx = copy(y0)
92115

93-
dss2_buffer = Spaces.create_dss_buffer(y0)
94-
Spaces.weighted_dss!(y0, dss2_buffer)
116+
dss_buffer = Spaces.create_dss_buffer(y0)
117+
Spaces.weighted_dss!(y0, dss_buffer)
95118

96119
@test parent(yx) parent(y0)
97120

98-
p = @allocated Spaces.weighted_dss!(y0, dss2_buffer)
121+
p = @allocated Spaces.weighted_dss!(y0, dss_buffer)
99122
@show p
100123
#@test p == 0
101124
end

0 commit comments

Comments
 (0)