Skip to content

Commit 65b036a

Browse files
committed
Add scalar_fieldmatrix
Add a function to convert a FieldMatrix where each matrix entry has an eltype of some struct into a FieldMatrix where each entry has an eltype of a scalar. Add additional tests for scalar_matrixfields Use @test_all in tests
1 parent b47dffb commit 65b036a

File tree

6 files changed

+441
-12
lines changed

6 files changed

+441
-12
lines changed

docs/src/matrix_fields.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ preconditioner_cache
8989
check_preconditioner
9090
lazy_or_concrete_preconditioner
9191
apply_preconditioner
92+
scalar_keys
93+
get_field_first_index_offset
94+
broadcasted_get_field_type
9295
```
9396

9497
## Utilities
@@ -98,4 +101,5 @@ column_field2array
98101
column_field2array_view
99102
field2arrays
100103
field2arrays_view
104+
scalar_fieldmatrix
101105
```

src/MatrixFields/MatrixFields.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ import ..Utilities: PlusHalf, half
5858
import ..RecursiveApply:
5959
rmap, rmaptype, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv
6060
import ..RecursiveApply: , ,
61+
import ..DataLayouts
6162
import ..DataLayouts: AbstractData
6263
import ..DataLayouts: vindex
6364
import ..Geometry

src/MatrixFields/field_name.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,24 @@ get_field(x, ::FieldName{()}) = x
5959
get_field(x, name::FieldName) =
6060
get_field(getproperty(x, extract_first(name)), drop_first(name))
6161

62+
"""
63+
broadcasted_get_field_type(::Type{X}, name::FieldName)
64+
65+
Returns the type of the field accessed by `name` in the type `X`.
66+
"""
67+
broadcasted_get_field_type(::Type{X}, ::FieldName{()}) where {X} = X
68+
broadcasted_get_field_type(::Type{X}, name::FieldName) where {X} =
69+
broadcasted_get_field_type(
70+
fieldtype(X, extract_first(name)),
71+
drop_first(name),
72+
)
73+
if hasfield(Method, :recursion_relation)
74+
dont_limit = (args...) -> true
75+
for m in methods(broadcasted_get_field_type)
76+
m.recursion_relation = dont_limit
77+
end
78+
end
79+
6280
broadcasted_has_field(::Type{X}, ::FieldName{()}) where {X} = true
6381
broadcasted_has_field(::Type{X}, name::FieldName) where {X} =
6482
extract_first(name) in fieldnames(X) &&

src/MatrixFields/field_name_dict.jl

Lines changed: 135 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,51 @@ function get_internal_entry(
181181
entry
182182
elseif name_pair[2] == @name() && broadcasted_has_field(T, name_pair[1])
183183
# multiplication case 2 or 4, second argument
184-
Base.broadcasted(entry) do matrix_row
185-
map(matrix_row) do matrix_row_entry
186-
broadcasted_get_field(matrix_row_entry, name_pair[1])
187-
end
188-
end # Note: This assumes that the entry is in a FieldMatrixBroadcasted.
184+
target_field_eltype = broadcasted_get_field_type(T, name_pair[1])
185+
if target_field_eltype <: Number
186+
T_band = eltype(entry)
187+
singleton_datalayout =
188+
DataLayouts.singleton(Fields.field_values(entry))
189+
# BandMatrixRow with same lowest diagonal and bandwidth as `entry`, with a scalar eltype
190+
scalar_band_type = BandMatrixRow{
191+
T_band.parameters[1],
192+
T_band.parameters[2],
193+
eltype(parent(entry)),
194+
}
195+
field_dim_size = DataLayouts.ncomponents(Fields.field_values(entry))
196+
scalar_field_offset = get_field_first_index_offset(
197+
name_pair[1],
198+
target_field_eltype,
199+
T,
200+
)
201+
band_element_size = Int(div(sizeof(T), sizeof(target_field_eltype)))
202+
parent_indices = DataLayouts.to_data_specific_field(
203+
singleton_datalayout,
204+
(
205+
:,
206+
:,
207+
(1 + scalar_field_offset):band_element_size:field_dim_size,
208+
:,
209+
:,
210+
),
211+
)
212+
scalar_data = view(parent(entry), parent_indices...)
213+
values = DataLayouts.union_all(singleton_datalayout){
214+
scalar_band_type,
215+
Base.tail(
216+
DataLayouts.type_params(Fields.field_values(entry)),
217+
)...,
218+
}(
219+
scalar_data,
220+
)
221+
Fields.Field(values, axes(entry))
222+
else
223+
Base.broadcasted(entry) do matrix_row
224+
map(matrix_row) do matrix_row_entry
225+
broadcasted_get_field(matrix_row_entry, name_pair[1])
226+
end
227+
end # Note: This assumes that the entry is in a FieldMatrixBroadcasted.
228+
end
189229
else
190230
throw(key_error)
191231
end
@@ -237,6 +277,96 @@ function Base.one(matrix::FieldMatrix)
237277
return FieldNameDict(inferred_diagonal_keys, entries)
238278
end
239279

280+
"""
281+
get_field_first_index_offset(name::FieldName, ::Type{T}, ::Type{S})
282+
283+
Returns the offset of the the field with name `name` in an object of type `S`
284+
in multiples of `sizeof(T)`.
285+
"""
286+
function get_field_first_index_offset(
287+
name::FieldName,
288+
::Type{T},
289+
::Type{S},
290+
) where {T, S}
291+
if name == @name()
292+
return 0
293+
end
294+
child_name = extract_first(name)
295+
child_type = fieldtype(S, child_name)
296+
remaining_field_chain = drop_first(name)
297+
field_index =
298+
unrolled_filter(i -> fieldname(S, i) == child_name, 1:fieldcount(S))[1]
299+
return DataLayouts.fieldtypeoffset(T, S, field_index) +
300+
get_field_first_index_offset(remaining_field_chain, T, child_type)
301+
end
302+
if hasfield(Method, :recursion_relation)
303+
dont_limit = (args...) -> true
304+
for m in methods(get_field_first_index_offset)
305+
m.recursion_relation = dont_limit
306+
end
307+
end
308+
309+
"""
310+
get_scalar_keys(dict::FieldMatrix)
311+
312+
Returns a `FieldMatrixKeys` object that contains the keys of all the scalar
313+
entries in the `FieldMatrix` `dict`.
314+
"""
315+
function get_scalar_keys(dict::FieldMatrix)
316+
keys_tuple = unrolled_flatmap(keys(dict).values) do key
317+
_, entry = unrolled_filter(pair -> key == pair[1], pairs(dict))[1]
318+
entry =
319+
entry isa ColumnwiseBandMatrixField ? entry.entries.:(1) : entry
320+
unrolled_map(
321+
filtered_child_names(
322+
field -> eltype(field) <: Number,
323+
entry,
324+
@name()
325+
),
326+
) do name
327+
(append_internal_name(key[1], name), key[2])
328+
end
329+
end
330+
return FieldMatrixKeys(keys_tuple)
331+
end
332+
333+
"""
334+
scalar_fieldmatrix(field_matrix::FieldMatrix)
335+
336+
Constructs a `FieldNameDict` where the keys and entries are views
337+
of the entries of `field_matrix`, which corresponding to the
338+
scalar components of entries of `field_matrix`.
339+
340+
# Example usage
341+
```julia
342+
struct foo{T1, T2}
343+
a::T
344+
b::T2
345+
end
346+
mat1 = fill(DiagonalMatrixRow(ClimaCore.Geometry.Covariant12Vector(1.0, 2.0)), space)
347+
mat2 = fill(DiagonalMatrixRow(foo(foo(1.0, 2.0), 3.0)), space)
348+
A = MatrixFields.FieldMatrix(
349+
(@name(biz), @name(baz)) => mat1,
350+
(@name(bip), @name(bop)) => mat2,
351+
)
352+
A_scalar = MatrixFields.scalar_fieldmatrix(A)
353+
keys(A_scalar)
354+
# Output:
355+
# (@name(biz.components.data.:(1)), @name(baz))
356+
# (@name(biz.components.data.:(2)), @name(baz))
357+
# (@name(bip.a.a), @name(bop))
358+
# (@name(bip.a.b), @name(bop))
359+
# (@name(bip.b), @name(bop))
360+
```
361+
"""
362+
function scalar_fieldmatrix(field_matrix::FieldMatrix)
363+
scalar_keys = get_scalar_keys(field_matrix)
364+
entries = unrolled_map(scalar_keys.values) do key
365+
field_matrix[key]
366+
end
367+
return FieldNameDict(scalar_keys, entries)
368+
end
369+
240370
replace_name_tree(dict::FieldNameDict, name_tree) =
241371
FieldNameDict(replace_name_tree(keys(dict), name_tree), values(dict))
242372

test/MatrixFields/field_names.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -770,9 +770,9 @@ end
770770
(@name(a), @name(a)) => -I_CT3XC3,
771771
)
772772

773-
for (vector, matrix, I_foo, I_a) in (
774-
(vector_of_scalars, matrix_of_scalars, I, I),
775-
(vector_of_vectors, matrix_of_tensors, I_C12XCT12, I_CT3XC3),
773+
for (vector, matrix, I_foo, I_a, is_scalar_test) in (
774+
(vector_of_scalars, matrix_of_scalars, I, I, true),
775+
(vector_of_vectors, matrix_of_tensors, I_C12XCT12, I_CT3XC3, false),
776776
)
777777
@test_all MatrixFields.field_vector_view(vector) ==
778778
MatrixFields.FieldVectorView(
@@ -842,10 +842,13 @@ end
842842
@test_all matrix[@name(a.c), @name(a.b)] == zero(I_a)
843843
@test_all matrix[@name(foo._value), @name(foo._value)] ==
844844
matrix[@name(foo), @name(foo)]
845-
846-
@test_all matrix[@name(foo._value), @name(a.b)] isa
847-
Base.AbstractBroadcasted
848-
@test Base.materialize(matrix[@name(foo._value), @name(a.b)]) == map(
845+
entry = matrix[@name(foo._value), @name(a.b)]
846+
@test_all entry isa (
847+
is_scalar_test ? MatrixFields.ColumnwiseBandMatrixField :
848+
Base.AbstractBroadcasted
849+
)
850+
entry = is_scalar_test ? entry : Base.materialize(entry)
851+
@test entry == map(
849852
row -> map(foo -> foo.value, row),
850853
matrix[@name(foo), @name(a.b)],
851854
)

0 commit comments

Comments
 (0)