Skip to content

Commit 3589ca6

Browse files
committed
working state
1 parent e49e341 commit 3589ca6

File tree

2 files changed

+114
-56
lines changed

2 files changed

+114
-56
lines changed

src/MatrixFields/field_name_dict.jl

Lines changed: 49 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,24 @@ get_internal_key(child_name_pair::FieldNamePair, name_pair::FieldNamePair) = (
150150
)
151151

152152
get_internal_entry(entry, name::FieldName, key_error) = get_field(entry, name)
153-
get_internal_entry(entry, name_pair::FieldNamePair, key_error) =
154-
name_pair == (@name(), @name()) ? entry : throw(key_error)
155-
function get_internal_entry(
156-
entry::ScalingFieldMatrixEntry,
153+
# call get_internal_entry on scaling value, and rebuild entry container
154+
get_internal_entry(entry::UniformScaling, name_pair::FieldNamePair, key_error) =
155+
UniformScaling(
156+
get_internal_entry(scaling_value(entry), name_pair, key_error),
157+
)
158+
get_internal_entry(
159+
entry::DiagonalMatrixRow,
157160
name_pair::FieldNamePair,
158161
key_error,
162+
) = DiagonalMatrixRow(
163+
get_internal_entry(scaling_value(entry), name_pair, key_error),
159164
)
160-
T = eltype(entry)
165+
# get_internal_entry to be used on the values held inside a `BandMatrixRow`
166+
function get_internal_entry(
167+
entry::T,
168+
name_pair::FieldNamePair,
169+
key_error,
170+
) where {T}
161171
if name_pair == (@name(), @name())
162172
return entry
163173
elseif T <: Geometry.Axis2Tensor &&
@@ -170,45 +180,39 @@ function get_internal_entry(
170180
row_index = extract_first(internal_row_name)
171181
col_index = extract_first(internal_col_name)
172182
return get_internal_entry(
173-
DiagonalMatrixRow(scaling_value(entry)[row_index, col_index]),
183+
entry[row_index, col_index],
174184
(drop_first(internal_row_name), drop_first(internal_col_name)),
175185
key_error,
176186
)
177-
elseif T <: Geometry.AdjointAxisVector
187+
elseif T <: Geometry.AdjointAxisVector # bypass parent for adjoint vectors
178188
return get_internal_entry(
179-
DiagonalMatrixRow(getfield(scaling_value(entry), :parent)),
189+
getfield(entry, :parent),
180190
name_pair,
181191
key_error,
182192
)
183-
else
184-
child_name, remaining_chain =
185-
if name_pair[1] != @name() &&
186-
extract_first(name_pair[1]) in fieldnames(T)
187-
@inline (
188-
extract_first(name_pair[1]),
189-
(drop_first(name_pair[1]), name_pair[2]),
190-
)
191-
elseif name_pair[2] != @name() &&
192-
extract_first(name_pair[2]) in fieldnames(T)
193-
@inline (
194-
extract_first(name_pair[2]),
195-
(name_pair[1], drop_first(name_pair[2])),
196-
)
197-
elseif !any(isequal(@name()), name_pair) # implicit tensor structure
198-
return get_internal_entry(
199-
extract_first(name_pair[1]) == extract_first(name_pair[2]) ?
200-
entry : zero(entry),
201-
(drop_first(name_pair[1]), drop_first(name_pair[2])),
202-
key_error,
203-
)
204-
else
205-
throw(key_error)
206-
end
193+
elseif name_pair[1] != @name() &&
194+
extract_first(name_pair[1]) in fieldnames(T)
195+
return get_internal_entry(
196+
getfield(entry, extract_first(name_pair[1])),
197+
(drop_first(name_pair[1]), name_pair[2]),
198+
key_error,
199+
)
200+
elseif name_pair[2] != @name() &&
201+
extract_first(name_pair[2]) in fieldnames(T)
202+
return get_internal_entry(
203+
getfield(entry, extract_first(name_pair[2])),
204+
(name_pair[1], drop_first(name_pair[2])),
205+
key_error,
206+
)
207+
elseif !any(isequal(@name()), name_pair) # implicit tensor structure
207208
return get_internal_entry(
208-
DiagonalMatrixRow(getfield(scaling_value(entry), child_name)),
209-
remaining_chain,
209+
extract_first(name_pair[1]) == extract_first(name_pair[2]) ? entry :
210+
zero(entry),
211+
(drop_first(name_pair[1]), drop_first(name_pair[2])),
210212
key_error,
211213
)
214+
else
215+
throw(key_error)
212216
end
213217
end
214218
function get_internal_entry(
@@ -241,18 +245,19 @@ function get_internal_entry(
241245
)
242246
return Fields.Field(values, axes(entry))
243247
elseif apply_zero
248+
zero_value = zero(target_type)
244249
return Base.broadcasted(entry) do matrix_row
245250
map(matrix_row) do matrix_row_entry
246-
zero(target_type)
251+
# zero(target_type)
252+
zero_value
247253
end
248254
end
255+
elseif target_type == S
256+
return entry
249257
else
250258
return Base.broadcasted(entry) do matrix_row
251259
map(matrix_row) do matrix_row_entry
252-
broadcasted_get_field(
253-
broadcasted_get_field(matrix_row_entry, name_pair[1]),
254-
name_pair[2],
255-
)
260+
get_internal_entry(matrix_row_entry, name_pair, key_error)
256261
end
257262
end
258263
end
@@ -329,7 +334,8 @@ function field_offset_and_type(
329334
key_error,
330335
) where {S, T}
331336
name_pair == (@name(), @name()) && return (0, S, false) # base case
332-
if S <: Geometry.Axis2Tensor # special case to calculate index
337+
if S <: Geometry.Axis2Tensor &&
338+
all(n -> is_child_name(n, @name(components.data)), name_pair)# special case to calculate index
333339
(name_pair[1] == @name() || name_pair[2] == @name()) && throw(key_error)
334340
internal_row_name =
335341
extract_internal_name(name_pair[1], @name(components.data))
@@ -372,7 +378,7 @@ function field_offset_and_type(
372378
field_offset_and_type(
373379
(drop_first(name_pair[1]), drop_first(name_pair[2])),
374380
T,
375-
fieldtype(S, 1),
381+
S,
376382
key_error,
377383
)
378384
return (
@@ -419,8 +425,8 @@ when indexing `dict`.
419425
"""
420426
function get_scalar_keys(dict::FieldMatrix)
421427
keys_tuple = unrolled_flatmap(keys(dict).values) do outer_key
422-
@inline unrolled_map(get_scalar_keys(eltype(dict[outer_key]))) do inner_key
423-
@inline (
428+
unrolled_map(get_scalar_keys(eltype(dict[outer_key]))) do inner_key
429+
(
424430
append_internal_name(outer_key[1], inner_key[1]),
425431
append_internal_name(outer_key[2], inner_key[2]),
426432
)
@@ -463,17 +469,6 @@ function get_scalar_keys(::Type{T}) where {T}
463469
inner_key[2],
464470
)
465471
end
466-
# return unrolled_flatmap((:components,)) do inner_field
467-
# @inline unrolled_map(
468-
# # get_scalar_keys(fieldtype(T, inner_field)),
469-
# get_scalar_keys(fieldtype(T, :components)),
470-
# ) do inner_key
471-
# @inline (
472-
# append_internal_name(FieldName(inner_field), inner_key[1]),
473-
# inner_key[2],
474-
# )
475-
# end
476-
# end
477472
else
478473
return unrolled_flatmap(fieldnames(T)) do inner_name
479474
unrolled_map(get_scalar_keys(fieldtype(T, inner_name))) do inner_key

test/MatrixFields/scalar_fieldmatrix.jl

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ include("matrix_field_test_utils.jl")
2424
::Type{S},
2525
expected_offset,
2626
::Type{E},
27-
key_error,
27+
key_error;
28+
apply_zero = false,
2829
) where {T, S, E}
2930
@test_all MatrixFields.field_offset_and_type(name, T, S, key_error) ==
30-
(expected_offset, E, false)
31+
(expected_offset, E, apply_zero)
3132
end
3233
test_field_offset_and_type(
3334
(@name(x), @name()),
@@ -64,6 +65,29 @@ include("matrix_field_test_utils.jl")
6465
TwoFields{FT, Singleton{FT}},
6566
KeyError(@name(y.y.x)),
6667
)
68+
test_field_offset_and_type(
69+
(@name(y.k), @name(y.k)),
70+
FT,
71+
TwoFields{
72+
TwoFields{FT, FT},
73+
TwoFields{FT, TwoFields{FT, Singleton{FT}}},
74+
},
75+
3,
76+
TwoFields{FT, Singleton{FT}},
77+
KeyError(@name(y.y.x)),
78+
)
79+
test_field_offset_and_type(
80+
(@name(y.k.g), @name(y.k.l)),
81+
FT,
82+
TwoFields{
83+
TwoFields{FT, FT},
84+
TwoFields{FT, TwoFields{FT, Singleton{FT}}},
85+
},
86+
3,
87+
TwoFields{FT, Singleton{FT}},
88+
KeyError(@name(y.y.x)),
89+
apply_zero = true,
90+
)
6791
test_field_offset_and_type(
6892
(@name(y.y), @name(x)),
6993
FT,
@@ -117,3 +141,42 @@ end
117141
@test_opt MatrixFields.scalar_fieldmatrix(A)
118142
end
119143
end
144+
145+
@testset "cursed implicit tensor structure optimization indexing" begin
146+
FT = Float64
147+
center_space = test_spaces(FT)[1]
148+
for (maybe_copy, maybe_to_field) in
149+
((identity, identity), (copy, x -> fill(x, center_space)))
150+
A = MatrixFields.FieldMatrix(
151+
(@name(c.uₕ), @name(c.uₕ)) =>
152+
maybe_to_field(DiagonalMatrixRow(FT(2))),
153+
(@name(foo), @name(bar)) => maybe_to_field(
154+
DiagonalMatrixRow(
155+
Geometry.Covariant12Vector(FT(1), FT(2)) *
156+
Geometry.Contravariant12Vector(FT(1), FT(2))',
157+
),
158+
),
159+
)
160+
@test_all A[(
161+
@name(c.uₕ.components.data.:1),
162+
@name(c.uₕ.components.data.:1)
163+
)] == A[(@name(c.uₕ), @name(c.uₕ))]
164+
@test maybe_copy(
165+
A[(@name(c.uₕ.components.data.:2), @name(c.uₕ.components.data.:1))],
166+
) == maybe_to_field(DiagonalMatrixRow(FT(0)))
167+
@test maybe_copy(A[(@name(foo.dog), @name(bar.dog))]) ==
168+
A[(@name(foo), @name(bar))]
169+
@test maybe_copy(A[(@name(foo.cat), @name(bar.dog))]) ==
170+
zero(A[(@name(foo), @name(bar))])
171+
@test A[(
172+
@name(foo.dog.components.data.:1),
173+
@name(bar.dog.components.data.:2)
174+
)] == maybe_to_field(DiagonalMatrixRow(FT(2)))
175+
@test maybe_copy(
176+
A[(
177+
@name(foo.dog.components.data.:1),
178+
@name(bar.cat.components.data.:2)
179+
)],
180+
) == maybe_to_field(DiagonalMatrixRow(FT(0)))
181+
end
182+
end

0 commit comments

Comments
 (0)