@@ -181,11 +181,51 @@ function get_internal_entry(
181
181
entry
182
182
elseif name_pair[2 ] == @name () && broadcasted_has_field (T, name_pair[1 ])
183
183
# multiplication case 2 or 4, second argument
184
- Base. broadcasted (entry) do matrix_row
185
- map (matrix_row) do matrix_row_entry
186
- broadcasted_get_field (matrix_row_entry, name_pair[1 ])
187
- end
188
- end # Note: This assumes that the entry is in a FieldMatrixBroadcasted.
184
+ target_field_eltype = broadcasted_get_field_type (T, name_pair[1 ])
185
+ if target_field_eltype <: Number
186
+ T_band = eltype (entry)
187
+ singleton_datalayout =
188
+ DataLayouts. singleton (Fields. field_values (entry))
189
+ # BandMatrixRow with same lowest diagonal and bandwidth as `entry`, with a scalar eltype
190
+ scalar_band_type = BandMatrixRow{
191
+ T_band. parameters[1 ],
192
+ T_band. parameters[2 ],
193
+ eltype (parent (entry)),
194
+ }
195
+ field_dim_size = DataLayouts. ncomponents (Fields. field_values (entry))
196
+ scalar_field_offset = get_field_first_index_offset (
197
+ name_pair[1 ],
198
+ target_field_eltype,
199
+ T,
200
+ )
201
+ band_element_size = Int (div (sizeof (T), sizeof (target_field_eltype)))
202
+ parent_indices = DataLayouts. to_data_specific_field (
203
+ singleton_datalayout,
204
+ (
205
+ :,
206
+ :,
207
+ (1 + scalar_field_offset): band_element_size: field_dim_size,
208
+ :,
209
+ :,
210
+ ),
211
+ )
212
+ scalar_data = view (parent (entry), parent_indices... )
213
+ values = DataLayouts. union_all (singleton_datalayout){
214
+ scalar_band_type,
215
+ Base. tail (
216
+ DataLayouts. type_params (Fields. field_values (entry)),
217
+ )... ,
218
+ }(
219
+ scalar_data,
220
+ )
221
+ Fields. Field (values, axes (entry))
222
+ else
223
+ Base. broadcasted (entry) do matrix_row
224
+ map (matrix_row) do matrix_row_entry
225
+ broadcasted_get_field (matrix_row_entry, name_pair[1 ])
226
+ end
227
+ end # Note: This assumes that the entry is in a FieldMatrixBroadcasted.
228
+ end
189
229
else
190
230
throw (key_error)
191
231
end
@@ -237,6 +277,74 @@ function Base.one(matrix::FieldMatrix)
237
277
return FieldNameDict (inferred_diagonal_keys, entries)
238
278
end
239
279
280
+ """
281
+ get_field_first_index_offset(name::FieldName, ::Type{T}, ::Type{S})
282
+
283
+ Returns the offset of the the field with name `name` in an object of type `S`
284
+ in multiples of `sizeof(T)`.
285
+ """
286
+ function get_field_first_index_offset (
287
+ name:: FieldName ,
288
+ :: Type{T} ,
289
+ :: Type{S} ,
290
+ ) where {T, S}
291
+ if name == @name ()
292
+ return 0
293
+ end
294
+ child_name = extract_first (name)
295
+ child_type = fieldtype (S, child_name)
296
+ remaining_field_chain = drop_first (name)
297
+ field_index =
298
+ unrolled_filter (i -> fieldname (S, i) == child_name, 1 : fieldcount (S))[1 ]
299
+ return DataLayouts. fieldtypeoffset (T, S, field_index) +
300
+ get_field_first_index_offset (remaining_field_chain, T, child_type)
301
+ end
302
+ if hasfield (Method, :recursion_relation )
303
+ dont_limit = (args... ) -> true
304
+ for m in methods (get_field_first_index_offset)
305
+ m. recursion_relation = dont_limit
306
+ end
307
+ end
308
+
309
+ """
310
+ get_scalar_keys(dict::FieldMatrix)
311
+
312
+ Returns a `FieldMatrixKeys` object that contains the keys of all the scalar
313
+ entries in the `FieldMatrix` `dict`.
314
+ """
315
+ function get_scalar_keys (dict:: FieldMatrix )
316
+ keys_tuple = unrolled_flatmap (keys (dict). values) do key
317
+ _, entry = unrolled_filter (pair -> key == pair[1 ], pairs (dict))[1 ]
318
+ entry =
319
+ entry isa ColumnwiseBandMatrixField ? entry. entries.:(1 ) : entry
320
+ unrolled_map (
321
+ filtered_child_names (
322
+ field -> eltype (field) <: Number ,
323
+ entry,
324
+ @name ()
325
+ ),
326
+ ) do name
327
+ (append_internal_name (key[1 ], name), key[2 ])
328
+ end
329
+ end
330
+ return FieldMatrixKeys (keys_tuple)
331
+ end
332
+
333
+ """
334
+ scalar_fieldmatrix(field_matrix::FieldMatrix)
335
+
336
+ Constructs a `FieldNameDict` where the keys and entries are views
337
+ of the entries of `field_matrix`, which corresponding to the
338
+ scalar components of entries of `field_matrix`.
339
+ """
340
+ function scalar_fieldmatrix (field_matrix:: FieldMatrix )
341
+ scalar_keys = get_scalar_keys (field_matrix)
342
+ entries = unrolled_map (scalar_keys. values) do key
343
+ field_matrix[key]
344
+ end
345
+ return FieldNameDict (scalar_keys, entries)
346
+ end
347
+
240
348
replace_name_tree (dict:: FieldNameDict , name_tree) =
241
349
FieldNameDict (replace_name_tree (keys (dict), name_tree), values (dict))
242
350
0 commit comments