Skip to content

Commit cbbee35

Browse files
committed
implicit flat fields
1 parent 8fd60da commit cbbee35

File tree

5 files changed

+62
-9
lines changed

5 files changed

+62
-9
lines changed

src/MatrixFields/MatrixFields.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,7 @@ const ColumnwiseBandMatrixField{V, S} = Fields.Field{
8888
} where {
8989
V <: AbstractData{<:BandMatrixRow},
9090
S <: Union{
91-
Spaces.FiniteDifferenceSpace,
92-
Spaces.ExtrudedFiniteDifferenceSpace,
91+
Spaces.AbstractSpace,
9392
Operators.PlaceholderSpace, # so that this can exist inside cuda kernels
9493
},
9594
}

src/MatrixFields/matrix_shape.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,34 @@ struct Square <: AbstractMatrixShape end
33
struct FaceToCenter <: AbstractMatrixShape end
44
struct CenterToFace <: AbstractMatrixShape end
55

6+
matrix_shape(matrix_field) = matrix_shape(matrix_field, axes(matrix_field))
7+
68
"""
79
matrix_shape(matrix_field, [matrix_space])
810
9-
Returns either `Square()`, `FaceToCenter()`, or `CenterToFace()`, depending on
10-
whether the diagonal indices of `matrix_field` are `Int`s or `PlusHalf`s and
11-
whether `matrix_space` is on cell centers or cell faces. By default,
11+
Returns the matrix shape for a matrix field defined on the `matrix_space`. By default,
1212
`matrix_space` is set to `axes(matrix_field)`.
13+
14+
When the matrix_space is a finite difference space (extruded or otherwise): the shape is
15+
either `Square()`, `FaceToCenter()`, or `CenterToFace()`, depending on
16+
whether the diagonal indices of `matrix_field` are `Int`s or `PlusHalf`s and
17+
whether `matrix_space` is on cell centers or cell faces.
18+
19+
When the matrix_space is a spectral element or point space: only a Square() shape is supported.
1320
"""
14-
matrix_shape(matrix_field, matrix_space = axes(matrix_field)) = _matrix_shape(
21+
matrix_shape(matrix_field, matrix_space) = _matrix_shape(
1522
eltype(outer_diagonals(eltype(matrix_field))),
1623
matrix_space.staggering,
1724
)
1825

26+
function matrix_shape(
27+
matrix_field,
28+
matrix_space::Union{Spaces.AbstractSpectralElementSpace, Spaces.PointSpace},
29+
)
30+
@assert eltype(matrix_field) <: DiagonalMatrixRow
31+
Square()
32+
end
33+
1934
_matrix_shape(::Type{Int}, _) = Square()
2035
_matrix_shape(::Type{PlusHalf{Int}}, ::Spaces.CellCenter) = FaceToCenter()
2136
_matrix_shape(::Type{PlusHalf{Int}}, ::Spaces.CellFace) = CenterToFace()

src/MatrixFields/single_field_solver.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,10 @@ end
9898

9999
function _single_field_solve_col!(
100100
::ClimaComms.AbstractCPUDevice,
101-
cache::Fields.ColumnField,
102-
x::Fields.ColumnField,
101+
cache,
102+
x,
103103
A,
104-
b::Fields.ColumnField,
104+
b,
105105
)
106106
if A isa Fields.ColumnField
107107
band_matrix_solve!(

test/MatrixFields/flat_spaces.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import ClimaCore
2+
include(
3+
joinpath(pkgdir(ClimaCore), "test", "TestUtilities", "TestUtilities.jl"),
4+
)
5+
import .TestUtilities as TU
6+
7+
include("matrix_field_test_utils.jl")
8+
import ClimaCore.MatrixFields: @name,
9+
10+
@testset "Matrix Fields with Spectral Element and Point Spaces" begin
11+
get_j_field(space, FT) = fill(MatrixFields.DiagonalMatrixRow(FT(1)), space)
12+
13+
implicit_vars = (@name(tmp.v1), @name(tmp.v2))
14+
for FT in (Float32, Float64)
15+
comms_ctx = ClimaComms.SingletonCommsContext(comms_device)
16+
ps = TU.PointSpace(FT; context = comms_ctx)
17+
ses = TU.SpectralElementSpace2D(FT; context = comms_ctx)
18+
v1 = Fields.zeros(ps)
19+
v2 = Fields.zeros(ses)
20+
Y = Fields.FieldVector(; :tmp => (; :v1 => v1, :v2 => v2))
21+
implicit_blocks = MatrixFields.unrolled_map(
22+
var ->
23+
(var, var) =>
24+
get_j_field(axes(MatrixFields.get_field(Y, var)), FT),
25+
implicit_vars,
26+
)
27+
matrix = MatrixFields.FieldMatrix(implicit_blocks...)
28+
alg = MatrixFields.BlockDiagonalSolve()
29+
solver = MatrixFields.FieldMatrixSolver(alg, matrix, Y)
30+
b1 = random_field(FT, ps)
31+
b2 = random_field(FT, ses)
32+
x = similar(Y)
33+
b = Fields.FieldVector(; :tmp => (; :v1 => b1, :v2 => b2))
34+
MatrixFields.field_matrix_solve!(solver, x, matrix, b)
35+
@test x == b
36+
end
37+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ UnitTest("MatrixFields - non-scalar broadcasting (1)" ,"MatrixFields/matrix_fiel
8181
UnitTest("MatrixFields - non-scalar broadcasting (2)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_2.jl"),
8282
UnitTest("MatrixFields - non-scalar broadcasting (3)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_3.jl"),
8383
UnitTest("MatrixFields - non-scalar broadcasting (4)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_4.jl"),
84+
UnitTest("MatrixFields - flat spaces" ,"MatrixFields/flat_spaces.jl"),
85+
8486
# UnitTest("MatrixFields - matrix field broadcast" ,"MatrixFields/matrix_field_broadcasting.jl"), # too long
8587
# UnitTest("MatrixFields - operator matrices" ,"MatrixFields/operator_matrices.jl"), # too long
8688
# UnitTest("MatrixFields - field matrix solvers" ,"MatrixFields/field_matrix_solvers.jl"), # too long

0 commit comments

Comments
 (0)