Skip to content

Commit 1a656a8

Browse files
Merge pull request #1937 from CliMA/ck/inference2
Improve inference in grid constructors
2 parents 2f26e9a + 5cd7a71 commit 1a656a8

File tree

2 files changed

+38
-20
lines changed

2 files changed

+38
-20
lines changed

src/Grids/spectralelement.jl

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,24 +39,32 @@ function SpectralElementGrid1D(
3939
end
4040
end
4141

42-
function _SpectralElementGrid1D(
42+
_SpectralElementGrid1D(
4343
topology::Topologies.IntervalTopology,
4444
quadrature_style::Quadratures.QuadratureStyle,
45+
) = _SpectralElementGrid1D(
46+
topology,
47+
quadrature_style,
48+
Val(Topologies.nlocalelems(topology)),
4549
)
50+
51+
function _SpectralElementGrid1D(
52+
topology::Topologies.IntervalTopology,
53+
quadrature_style::Quadratures.QuadratureStyle,
54+
::Val{Nh},
55+
) where {Nh}
4656
global_geometry = Geometry.CartesianGlobalGeometry()
4757
CoordType = Topologies.coordinate_type(topology)
4858
AIdx = Geometry.coordinate_axis(CoordType)
4959
FT = eltype(CoordType)
50-
nelements = Topologies.nlocalelems(topology)
51-
Nh = nelements
5260
Nq = Quadratures.degrees_of_freedom(quadrature_style)
5361

5462
LG = Geometry.LocalGeometry{AIdx, CoordType, FT, SMatrix{1, 1, FT, 1}}
5563
local_geometry = DataLayouts.IFH{LG, Nq, Nh}(Array{FT})
5664
quad_points, quad_weights =
5765
Quadratures.quadrature_points(FT, quadrature_style)
5866

59-
for elem in 1:nelements
67+
for elem in 1:Nh
6068
local_geometry_slab = slab(local_geometry, elem)
6169
for i in 1:Nq
6270
ξ = quad_points[i]
@@ -182,12 +190,24 @@ function get_CoordType2D(topology)
182190
end
183191
end
184192

185-
function _SpectralElementGrid2D(
193+
_SpectralElementGrid2D(
186194
topology::Topologies.Topology2D,
187195
quadrature_style::Quadratures.QuadratureStyle;
188196
enable_bubble::Bool,
197+
) = _SpectralElementGrid2D(
198+
topology,
199+
quadrature_style,
200+
Val(Topologies.nlocalelems(topology));
201+
enable_bubble,
189202
)
190203

204+
function _SpectralElementGrid2D(
205+
topology::Topologies.Topology2D,
206+
quadrature_style::Quadratures.QuadratureStyle,
207+
::Val{Nh};
208+
enable_bubble::Bool,
209+
) where {Nh}
210+
191211
# 1. compute localgeom for local elememts
192212
# 2. ghost exchange of localgeom
193213
# 3. do a round of dss on WJs
@@ -213,8 +233,6 @@ function _SpectralElementGrid2D(
213233
end
214234
CoordType2D = get_CoordType2D(topology)
215235
AIdx = Geometry.coordinate_axis(CoordType2D)
216-
nlelems = Topologies.nlocalelems(topology)
217-
Nh = nlelems
218236
ngelems = Topologies.nghostelems(topology)
219237
Nq = Quadratures.degrees_of_freedom(quadrature_style)
220238
high_order_quadrature_style = Quadratures.GLL{Nq * 2}()

test/Spaces/opt_spaces.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
#=
22
julia --project=.buildkite
3+
ENV["CLIMACOMMS_DEVICE"] = "CUDA";
34
using Revise; include(joinpath("test", "Spaces", "opt_spaces.jl"))
45
=#
56
import ClimaCore
6-
import ClimaCore: Spaces, Grids
7+
import ClimaCore: Spaces, Grids, Topologies
78
using Test
89
include(
910
joinpath(pkgdir(ClimaCore), "test", "TestUtilities", "TestUtilities.jl"),
@@ -34,19 +35,19 @@ end
3435
#! format: off
3536
if ClimaComms.device(context) isa ClimaComms.CUDADevice
3637
test_n_failures(86, TU.PointSpace, context)
37-
test_n_failures(144, TU.SpectralElementSpace1D, context)
38+
test_n_failures(141, TU.SpectralElementSpace1D, context)
3839
test_n_failures(1141, TU.SpectralElementSpace2D, context)
39-
test_n_failures(123, TU.ColumnCenterFiniteDifferenceSpace, context)
40-
test_n_failures(123, TU.ColumnFaceFiniteDifferenceSpace, context)
41-
test_n_failures(1131, TU.SphereSpectralElementSpace, context)
42-
test_n_failures(1139, TU.CenterExtrudedFiniteDifferenceSpace, context)
43-
test_n_failures(1139, TU.FaceExtrudedFiniteDifferenceSpace, context)
40+
test_n_failures(3, TU.ColumnCenterFiniteDifferenceSpace, context)
41+
test_n_failures(4, TU.ColumnFaceFiniteDifferenceSpace, context)
42+
test_n_failures(1147, TU.SphereSpectralElementSpace, context)
43+
test_n_failures(1146, TU.CenterExtrudedFiniteDifferenceSpace, context)
44+
test_n_failures(1146, TU.FaceExtrudedFiniteDifferenceSpace, context)
4445
else
4546
test_n_failures(0, TU.PointSpace, context)
4647
test_n_failures(137, TU.SpectralElementSpace1D, context)
4748
test_n_failures(310, TU.SpectralElementSpace2D, context)
48-
test_n_failures(118, TU.ColumnCenterFiniteDifferenceSpace, context)
49-
test_n_failures(118, TU.ColumnFaceFiniteDifferenceSpace, context)
49+
test_n_failures(4, TU.ColumnCenterFiniteDifferenceSpace, context)
50+
test_n_failures(5, TU.ColumnFaceFiniteDifferenceSpace, context)
5051
test_n_failures(316, TU.SphereSpectralElementSpace, context)
5152
test_n_failures(321, TU.CenterExtrudedFiniteDifferenceSpace, context)
5253
test_n_failures(321, TU.FaceExtrudedFiniteDifferenceSpace, context)
@@ -56,11 +57,10 @@ end
5657
# separately:
5758

5859
space = TU.CenterExtrudedFiniteDifferenceSpace(Float32; context=ClimaComms.context())
59-
# @test_opt Grids._SpectralElementGrid2D(Spaces.topology(space), Spaces.quadrature_style(space); enable_bubble=false)
60-
61-
result = JET.@report_opt Grids._SpectralElementGrid2D(Spaces.topology(space), Spaces.quadrature_style(space); enable_bubble=false)
60+
Nh = Val(Topologies.nlocalelems(Spaces.topology(space)))
61+
result = JET.@report_opt Grids._SpectralElementGrid2D(Spaces.topology(space), Spaces.quadrature_style(space), Val(Nh); enable_bubble=false)
6262
n_found = length(JET.get_reports(result.analyzer, result.result))
63-
n_allowed = 187
63+
n_allowed = 0
6464
@test n_found n_allowed
6565
if n_found < n_allowed
6666
@info "Inference may have improved for _SpectralElementGrid2D: (n_found, n_allowed) = ($n_found, $n_allowed)"

0 commit comments

Comments
 (0)