@@ -150,14 +150,24 @@ get_internal_key(child_name_pair::FieldNamePair, name_pair::FieldNamePair) = (
150
150
)
151
151
152
152
get_internal_entry (entry, name:: FieldName , key_error) = get_field (entry, name)
153
- get_internal_entry (entry, name_pair:: FieldNamePair , key_error) =
154
- name_pair == (@name (), @name ()) ? entry : throw (key_error)
155
- function get_internal_entry (
156
- entry:: ScalingFieldMatrixEntry ,
153
+ # call get_internal_entry on scaling value, and rebuild entry container
154
+ get_internal_entry (entry:: UniformScaling , name_pair:: FieldNamePair , key_error) =
155
+ UniformScaling (
156
+ get_internal_entry (scaling_value (entry), name_pair, key_error),
157
+ )
158
+ get_internal_entry (
159
+ entry:: DiagonalMatrixRow ,
157
160
name_pair:: FieldNamePair ,
158
161
key_error,
162
+ ) = DiagonalMatrixRow (
163
+ get_internal_entry (scaling_value (entry), name_pair, key_error),
159
164
)
160
- T = eltype (entry)
165
+ # get_internal_entry to be used on the values held inside a `BandMatrixRow`
166
+ function get_internal_entry (
167
+ entry:: T ,
168
+ name_pair:: FieldNamePair ,
169
+ key_error,
170
+ ) where {T}
161
171
if name_pair == (@name (), @name ())
162
172
return entry
163
173
elseif T <: Geometry.Axis2Tensor &&
@@ -170,45 +180,39 @@ function get_internal_entry(
170
180
row_index = extract_first (internal_row_name)
171
181
col_index = extract_first (internal_col_name)
172
182
return get_internal_entry (
173
- DiagonalMatrixRow ( scaling_value ( entry) [row_index, col_index]) ,
183
+ entry[row_index, col_index],
174
184
(drop_first (internal_row_name), drop_first (internal_col_name)),
175
185
key_error,
176
186
)
177
- elseif T <: Geometry.AdjointAxisVector
187
+ elseif T <: Geometry.AdjointAxisVector # bypass parent for adjoint vectors
178
188
return get_internal_entry (
179
- DiagonalMatrixRow ( getfield (scaling_value ( entry) , :parent ) ),
189
+ getfield (entry, :parent ),
180
190
name_pair,
181
191
key_error,
182
192
)
183
- else
184
- child_name, remaining_chain =
185
- if name_pair[1 ] != @name () &&
186
- extract_first (name_pair[1 ]) in fieldnames (T)
187
- @inline (
188
- extract_first (name_pair[1 ]),
189
- (drop_first (name_pair[1 ]), name_pair[2 ]),
190
- )
191
- elseif name_pair[2 ] != @name () &&
192
- extract_first (name_pair[2 ]) in fieldnames (T)
193
- @inline (
194
- extract_first (name_pair[2 ]),
195
- (name_pair[1 ], drop_first (name_pair[2 ])),
196
- )
197
- elseif ! any (isequal (@name ()), name_pair) # implicit tensor structure
198
- return get_internal_entry (
199
- extract_first (name_pair[1 ]) == extract_first (name_pair[2 ]) ?
200
- entry : zero (entry),
201
- (drop_first (name_pair[1 ]), drop_first (name_pair[2 ])),
202
- key_error,
203
- )
204
- else
205
- throw (key_error)
206
- end
193
+ elseif name_pair[1 ] != @name () &&
194
+ extract_first (name_pair[1 ]) in fieldnames (T)
195
+ return get_internal_entry (
196
+ getfield (entry, extract_first (name_pair[1 ])),
197
+ (drop_first (name_pair[1 ]), name_pair[2 ]),
198
+ key_error,
199
+ )
200
+ elseif name_pair[2 ] != @name () &&
201
+ extract_first (name_pair[2 ]) in fieldnames (T)
202
+ return get_internal_entry (
203
+ getfield (entry, extract_first (name_pair[2 ])),
204
+ (name_pair[1 ], drop_first (name_pair[2 ])),
205
+ key_error,
206
+ )
207
+ elseif ! any (isequal (@name ()), name_pair) # implicit tensor structure
207
208
return get_internal_entry (
208
- DiagonalMatrixRow (getfield (scaling_value (entry), child_name)),
209
- remaining_chain,
209
+ extract_first (name_pair[1 ]) == extract_first (name_pair[2 ]) ? entry :
210
+ zero (entry),
211
+ (drop_first (name_pair[1 ]), drop_first (name_pair[2 ])),
210
212
key_error,
211
213
)
214
+ else
215
+ throw (key_error)
212
216
end
213
217
end
214
218
function get_internal_entry (
@@ -241,18 +245,19 @@ function get_internal_entry(
241
245
)
242
246
return Fields. Field (values, axes (entry))
243
247
elseif apply_zero
248
+ zero_value = zero (target_type)
244
249
return Base. broadcasted (entry) do matrix_row
245
250
map (matrix_row) do matrix_row_entry
246
- zero (target_type)
251
+ # zero(target_type)
252
+ zero_value
247
253
end
248
254
end
255
+ elseif target_type == S
256
+ return entry
249
257
else
250
258
return Base. broadcasted (entry) do matrix_row
251
259
map (matrix_row) do matrix_row_entry
252
- broadcasted_get_field (
253
- broadcasted_get_field (matrix_row_entry, name_pair[1 ]),
254
- name_pair[2 ],
255
- )
260
+ get_internal_entry (matrix_row_entry, name_pair, key_error)
256
261
end
257
262
end
258
263
end
@@ -329,7 +334,8 @@ function field_offset_and_type(
329
334
key_error,
330
335
) where {S, T}
331
336
name_pair == (@name (), @name ()) && return (0 , S, false ) # base case
332
- if S <: Geometry.Axis2Tensor # special case to calculate index
337
+ if S <: Geometry.Axis2Tensor &&
338
+ all (n -> is_child_name (n, @name (components. data)), name_pair)# special case to calculate index
333
339
(name_pair[1 ] == @name () || name_pair[2 ] == @name ()) && throw (key_error)
334
340
internal_row_name =
335
341
extract_internal_name (name_pair[1 ], @name (components. data))
@@ -372,7 +378,7 @@ function field_offset_and_type(
372
378
field_offset_and_type (
373
379
(drop_first (name_pair[1 ]), drop_first (name_pair[2 ])),
374
380
T,
375
- fieldtype (S, 1 ) ,
381
+ S ,
376
382
key_error,
377
383
)
378
384
return (
@@ -419,8 +425,8 @@ when indexing `dict`.
419
425
"""
420
426
function get_scalar_keys (dict:: FieldMatrix )
421
427
keys_tuple = unrolled_flatmap (keys (dict). values) do outer_key
422
- @inline unrolled_map (get_scalar_keys (eltype (dict[outer_key]))) do inner_key
423
- @inline (
428
+ unrolled_map (get_scalar_keys (eltype (dict[outer_key]))) do inner_key
429
+ (
424
430
append_internal_name (outer_key[1 ], inner_key[1 ]),
425
431
append_internal_name (outer_key[2 ], inner_key[2 ]),
426
432
)
@@ -463,17 +469,6 @@ function get_scalar_keys(::Type{T}) where {T}
463
469
inner_key[2 ],
464
470
)
465
471
end
466
- # return unrolled_flatmap((:components,)) do inner_field
467
- # @inline unrolled_map(
468
- # # get_scalar_keys(fieldtype(T, inner_field)),
469
- # get_scalar_keys(fieldtype(T, :components)),
470
- # ) do inner_key
471
- # @inline (
472
- # append_internal_name(FieldName(inner_field), inner_key[1]),
473
- # inner_key[2],
474
- # )
475
- # end
476
- # end
477
472
else
478
473
return unrolled_flatmap (fieldnames (T)) do inner_name
479
474
unrolled_map (get_scalar_keys (fieldtype (T, inner_name))) do inner_key
0 commit comments