Skip to content

Commit 8bf9524

Browse files
Merge pull request #1680 from CliMA/ck/local_geometry_type
Define `local_geometry_type`
2 parents d352572 + 759f794 commit 8bf9524

16 files changed

+128
-5
lines changed

src/Fields/Fields.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import ..DataLayouts: DataLayouts, AbstractData, DataStyle
77
import ..Domains
88
import ..Topologies
99
import ..Quadratures
10-
import ..Grids: ColumnIndex
10+
import ..Grids: ColumnIndex, local_geometry_type
1111
import ..Spaces: Spaces, AbstractSpace, AbstractPointSpace
1212
import ..Geometry: Geometry, Cartesian12Vector
1313
import ..Utilities: PlusHalf, half, UnrolledFunctions
@@ -39,6 +39,8 @@ Field(values::V, space::S) where {V <: AbstractData, S <: AbstractSpace} =
3939
Field(::Type{T}, space::S) where {T, S <: AbstractSpace} =
4040
Field(similar(Spaces.coordinates_data(space), T), space)
4141

42+
local_geometry_type(::Field{V, S}) where {V, S} = local_geometry_type(S)
43+
4244
ClimaComms.context(field::Field) = ClimaComms.context(axes(field))
4345

4446
ClimaComms.context(topology::Topologies.Topology2D) = topology.context

src/Grids/Grids.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,16 @@ If the grid is not staggered, `staggering` should be `nothing`.
4949
"""
5050
function local_geometry_data end
5151

52+
"""
53+
Grids.local_geometry_type(::Type)
54+
55+
Get the `LocalGeometry` type.
56+
"""
57+
function local_geometry_type end
58+
59+
# Fallback, but this requires user error-handling
60+
local_geometry_type(::Type{T}) where {T} = Union{}
61+
5262
function local_dss_weights end
5363
function quadrature_style end
5464
function vertical_topology end

src/Grids/column.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ struct ColumnGrid{
3636
colidx::C
3737
end
3838

39+
local_geometry_type(::Type{ColumnGrid{G, C}}) where {G, C} =
40+
local_geometry_type(G)
3941

4042
column(grid::AbstractExtrudedFiniteDifferenceGrid, colidx::ColumnIndex) =
4143
ColumnGrid(grid, colidx)

src/Grids/extruded.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ mutable struct ExtrudedFiniteDifferenceGrid{
3737
face_local_geometry::LG
3838
end
3939

40+
local_geometry_type(
41+
::Type{ExtrudedFiniteDifferenceGrid{H, V, A, GG, LG}},
42+
) where {H, V, A, GG, LG} = eltype(LG) # calls eltype from DataLayouts
43+
4044
function ExtrudedFiniteDifferenceGrid(
4145
horizontal_grid::Union{SpectralElementGrid1D, SpectralElementGrid2D},
4246
vertical_grid::FiniteDifferenceGrid,
@@ -125,7 +129,6 @@ vertical_topology(grid::ExtrudedFiniteDifferenceGrid) =
125129
local_dss_weights(grid::ExtrudedFiniteDifferenceGrid) =
126130
local_dss_weights(grid.horizontal_grid)
127131

128-
129132
local_geometry_data(grid::AbstractExtrudedFiniteDifferenceGrid, ::CellCenter) =
130133
grid.center_local_geometry
131134
local_geometry_data(grid::AbstractExtrudedFiniteDifferenceGrid, ::CellFace) =
@@ -147,6 +150,10 @@ struct DeviceExtrudedFiniteDifferenceGrid{VT, Q, GG, LG} <:
147150
face_local_geometry::LG
148151
end
149152

153+
local_geometry_type(
154+
::Type{DeviceExtrudedFiniteDifferenceGrid{VT, Q, GG, LG}},
155+
) where {VT, Q, GG, LG} = eltype(LG) # calls eltype from DataLayouts
156+
150157
Adapt.adapt_structure(to, grid::ExtrudedFiniteDifferenceGrid) =
151158
DeviceExtrudedFiniteDifferenceGrid(
152159
Adapt.adapt(to, vertical_topology(grid)),

src/Grids/finitedifference.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@ FiniteDifferenceGrid(mesh::Meshes.IntervalMesh) =
168168
topology(grid::FiniteDifferenceGrid) = grid.topology
169169
vertical_topology(grid::FiniteDifferenceGrid) = grid.topology
170170

171+
local_geometry_type(::Type{FiniteDifferenceGrid{T, GG, LG}}) where {T, GG, LG} =
172+
eltype(LG) # calls eltype from DataLayouts
173+
171174
local_geometry_data(grid::FiniteDifferenceGrid, ::CellCenter) =
172175
grid.center_local_geometry
173176
local_geometry_data(grid::FiniteDifferenceGrid, ::CellFace) =
@@ -182,6 +185,10 @@ struct DeviceFiniteDifferenceGrid{T, GG, LG} <: AbstractFiniteDifferenceGrid
182185
face_local_geometry::LG
183186
end
184187

188+
local_geometry_type(
189+
::Type{DeviceFiniteDifferenceGrid{T, GG, LG}},
190+
) where {T, GG, LG} = eltype(LG) # calls eltype from DataLayouts
191+
185192
Adapt.adapt_structure(to, grid::FiniteDifferenceGrid) =
186193
DeviceFiniteDifferenceGrid(
187194
Adapt.adapt(to, grid.topology),

src/Grids/level.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ topology(levelgrid::LevelGrid) = topology(levelgrid.full_grid)
1818

1919
local_dss_weights(grid::LevelGrid) = local_dss_weights(grid.full_grid)
2020

21+
local_geometry_type(::Type{LevelGrid{G, L}}) where {G, L} =
22+
local_geometry_type(G)
2123
local_geometry_data(levelgrid::LevelGrid{<:Any, Int}, ::Nothing) = level(
2224
local_geometry_data(levelgrid.full_grid, CellCenter()),
2325
levelgrid.level,

src/Grids/spectralelement.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ mutable struct SpectralElementGrid1D{
2121
dss_weights::D
2222
end
2323

24+
local_geometry_type(
25+
::Type{SpectralElementGrid1D{T, Q, GG, LG}},
26+
) where {T, Q, GG, LG} = eltype(LG) # calls eltype from DataLayouts
27+
2428
# non-view grids are cached based on their input arguments
2529
# this means that if data is saved in two different files, reloading will give fields which live on the same grid
2630
function SpectralElementGrid1D(
@@ -118,6 +122,10 @@ mutable struct SpectralElementGrid2D{
118122
boundary_surface_geometries::BS
119123
end
120124

125+
local_geometry_type(
126+
::Type{SpectralElementGrid2D{T, Q, GG, LG, D, IS, BS}},
127+
) where {T, Q, GG, LG, D, IS, BS} = eltype(LG) # calls eltype from DataLayouts
128+
121129
"""
122130
SpectralElementSpace2D(topology, quadrature_style; enable_bubble)
123131

src/MatrixFields/MatrixFields.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ import ..DataLayouts: AbstractData
6060
import ..Geometry
6161
import ..Topologies
6262
import ..Spaces
63+
import ..Spaces: local_geometry_type
6364
import ..Fields
6465
import ..Operators
6566

src/MatrixFields/matrix_multiplication.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,11 +270,29 @@ function Operators.right_interior_idx(
270270
end
271271
end
272272

273+
pick_inferred_type(
274+
::Type{Union{}},
275+
::Type{Y},
276+
) where {Y <: Geometry.LocalGeometry} = Y
277+
pick_inferred_type(
278+
::Type{X},
279+
::Type{Union{}},
280+
) where {X <: Geometry.LocalGeometry} = X
281+
pick_inferred_type(::Type{T}, ::Type{T}) where {T <: Geometry.LocalGeometry} = T
282+
pick_inferred_type(::Type{Union{}}, ::Type{Union{}}) =
283+
error("Both LGs are not inferred")
284+
pick_inferred_type(::Type{X}, ::Type{Y}) where {X, Y} =
285+
error("LGs do not match: X=$X, Y=$Y")
286+
273287
function Operators.return_eltype(
274288
::MultiplyColumnwiseBandMatrixField,
275289
matrix1,
276290
arg,
277291
)
292+
# LG1 = local_geometry_type(typeof(axes(matrix1)))
293+
# LG2 = local_geometry_type(typeof(axes(arg)))
294+
# LG = pick_inferred_type(LG1, LG2)
295+
# return Operators.return_eltype(op, matrix1, arg, LG)
278296
eltype(matrix1) <: BandMatrixRow || error(
279297
"The first argument of ⋅ must have elements of type BandMatrixRow, but \
280298
the given argument has elements of type $(eltype(matrix1))",
@@ -293,6 +311,39 @@ function Operators.return_eltype(
293311
end
294312
end
295313

314+
function Operators.return_eltype(
315+
::MultiplyColumnwiseBandMatrixField,
316+
matrix1,
317+
arg,
318+
::Type{LG},
319+
) where {LG}
320+
eltype(matrix1) <: BandMatrixRow || error(
321+
"The first argument of ⋅ must have elements of type BandMatrixRow, but \
322+
the given argument has elements of type $(eltype(matrix1))",
323+
)
324+
if eltype(arg) <: BandMatrixRow # matrix-matrix multiplication
325+
matrix2 = arg
326+
ld1, ud1 = outer_diagonals(eltype(matrix1))
327+
ld2, ud2 = outer_diagonals(eltype(matrix2))
328+
prod_ld, prod_ud = ld1 + ld2, ud1 + ud2
329+
prod_value_type = Base.promote_op(
330+
rmul_with_projection,
331+
eltype(eltype(matrix1)),
332+
eltype(eltype(matrix2)),
333+
LG,
334+
)
335+
return band_matrix_row_type(prod_ld, prod_ud, prod_value_type)
336+
else # matrix-vector multiplication
337+
vector = arg
338+
prod_value_type = Base.promote_op(
339+
rmul_with_projection,
340+
eltype(eltype(matrix1)),
341+
eltype(vector),
342+
LG,
343+
)
344+
end
345+
end
346+
296347
Operators.return_space(::MultiplyColumnwiseBandMatrixField, space1, space2) =
297348
space1
298349

@@ -314,6 +365,8 @@ boundary_modified_ud(::BottomRightMatrixCorner, ud, column_space, i) =
314365
# matrix field broadcast expressions to take roughly 3 or 4 times longer to
315366
# evaluate, but this is less significant than the decrease in compilation time.
316367
function multiply_matrix_at_index(loc, space, idx, hidx, matrix1, arg, bc)
368+
# lg = Geometry.LocalGeometry(space, idx, hidx)
369+
# prod_type = Operators.return_eltype(⋅, matrix1, arg, typeof(lg))
317370
prod_type = Operators.return_eltype(, matrix1, arg)
318371

319372
column_space1 = column_axes(matrix1, space)

src/MatrixFields/single_field_solver.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@ x_eltype(A::ColumnwiseBandMatrixField, b) =
1313
x_eltype(eltype(eltype(A)), eltype(b))
1414
x_eltype(::Type{T_A}, ::Type{T_b}) where {T_A, T_b} =
1515
rmul_return_type(inv_return_type(T_A), T_b)
16+
# Base.promote_op(rmul_with_projection, inv_return_type(T_A), T_b, LG)
1617

17-
unit_eltype(A::UniformScaling) = unit_eltype(eltype(A))
18-
unit_eltype(A::ColumnwiseBandMatrixField) = unit_eltype(eltype(eltype(A)))
19-
unit_eltype(::Type{T_A}) where {T_A} =
18+
unit_eltype(A::UniformScaling) = eltype(A)
19+
unit_eltype(A::ColumnwiseBandMatrixField) =
20+
unit_eltype(eltype(eltype(A)), local_geometry_type(A))
21+
unit_eltype(::Type{T_A}, ::Type{LG}) where {T_A, LG} =
2022
rmul_return_type(inv_return_type(T_A), T_A)
23+
# Base.promote_op(rmul_with_projection, inv_return_type(T_A), T_A, LG)
2124

2225
################################################################################
2326

0 commit comments

Comments
 (0)