@@ -165,10 +165,14 @@ function get_internal_entry(
165
165
name_pair:: FieldNamePair ,
166
166
key_error,
167
167
)
168
+ # TODO : Add other cases
168
169
if name_pair[1 ] == name_pair[2 ]
169
170
entry
170
- elseif name_pair[2 ] == @name () && has_field (entry, name_pair[1 ])
171
- DiagonalMatrixRow (get_field (entry, name_pair[1 ]))
171
+ elseif name_pair[2 ] == @name () &&
172
+ broadcasted_has_field (eltype (entry), name_pair[1 ])
173
+ DiagonalMatrixRow (
174
+ broadcasted_get_field (entry. entries.:(1 ), name_pair[1 ]),
175
+ )
172
176
elseif is_overlapping_name (name_pair[1 ], name_pair[2 ])
173
177
throw (key_error)
174
178
else
@@ -185,13 +189,32 @@ function get_internal_entry(
185
189
T = eltype (eltype (entry))
186
190
if name_pair == (@name (), @name ())
187
191
entry
188
- elseif name_pair[1 ] == name_pair[2 ] && ! broadcasted_has_field (T, name_pair[1 ])
189
- # multiplication case 3 or 4, first argument
190
- @assert T <: Geometry.SingleValue
192
+ elseif name_pair[1 ] == name_pair[2 ] &&
193
+ ! broadcasted_has_field (T, name_pair[1 ])
194
+ # @show "aa"
195
+ # # multiplication case 3 or 4, first argument
196
+ @assert T <: Number
191
197
entry
192
- elseif name_pair[2 ] == @name () && broadcasted_has_field (T, name_pair[1 ])
193
- # multiplication case 2 or 4, second argument
194
- target_field_eltype = broadcasted_get_field_type (T, name_pair[1 ])
198
+ elseif name_pair[1 ] == @name () || name_pair[2 ] == @name ()
199
+
200
+ target_chain = if name_pair[1 ] == @name ()
201
+ if broadcasted_has_field (T, name_pair[2 ])
202
+ # this case should be dscalar/dvec with T isa vec
203
+ name_pair[2 ]
204
+ else
205
+ # this should be dscalar/dvec with T isa adjoint
206
+ append_internal_name (@name (parent), name_pair[2 ])
207
+ end
208
+ else
209
+ if broadcasted_has_field (T, name_pair[1 ])
210
+ # this case should be dtuple/dscalar or dvec/dscalar with T isa vec
211
+ name_pair[1 ]
212
+ else
213
+ # this should be dvec/dscalar with T isa adjoint
214
+ append_internal_name (@name (parent), name_pair[1 ])
215
+ end
216
+ end
217
+ target_field_eltype = broadcasted_get_field_type (T, target_chain)
195
218
if target_field_eltype == eltype (parent (entry))
196
219
T_band = eltype (entry)
197
220
singleton_datalayout =
@@ -203,7 +226,7 @@ function get_internal_entry(
203
226
)
204
227
field_dim_size = DataLayouts. ncomponents (Fields. field_values (entry))
205
228
scalar_field_offset = get_field_first_index_offset (
206
- name_pair[ 1 ] ,
229
+ target_chain ,
207
230
target_field_eltype,
208
231
T,
209
232
)
@@ -231,57 +254,131 @@ function get_internal_entry(
231
254
else
232
255
Base. broadcasted (entry) do matrix_row
233
256
map (matrix_row) do matrix_row_entry
234
- broadcasted_get_field (matrix_row_entry, name_pair[ 1 ] )
257
+ broadcasted_get_field (matrix_row_entry, target_chain )
235
258
end
236
- end # Note: This assumes that the entry is in a FieldMatrixBroadcasted.
259
+ end
260
+ end
261
+ elseif name_pair[2 ] != @name () && name_pair[1 ] != @name ()
262
+ # this should only be the case with dvec/dvec or dNTuple/dvec
263
+ if T <: Geometry.SingleValue
264
+ # @assert drop_last(name_pair[1]) ==
265
+ # drop_last(name_pair[2]) ==
266
+ # @name(components.data)
267
+ row_index = extract_last (name_pair[1 ])
268
+ col_index = extract_last (name_pair[2 ])
269
+ (n_rows, n_cols) = map (length, axes (T))
270
+ @assert row_index <= n_rows && col_index <= n_cols
271
+ flattened_index = n_rows * (col_index - 1 ) + row_index
272
+ elseif eltype (T) <: Geometry.SingleValue # TODO : nested tuples?
273
+ # @assert drop_last(name_pair[2]) == @name(components.data)
274
+ modified_first_name =
275
+ broadcasted_has_field (T, name_pair[1 ]) ? name_pair[1 ] :
276
+ append_internal_name (@name (parent), name_pair[1 ])
277
+ flattened_index =
278
+ get_field_first_index_offset (
279
+ name_pair[1 ],
280
+ broadcasted_get_field_type (T, name_pair[1 ]),
281
+ T,
282
+ ) + extract_last (name_pair[2 ])
283
+ else
284
+ error (" Cannot get entry for key $name_pair " )
237
285
end
238
- elseif broadcasted_has_field (T, name_pair[1 ]) && broadcasted_has_field (T, name_pair[2 ])
239
- # this should only be the case when both independent and dependent var are axisvectors
240
- @assert T <: Geometry.SingleValue && ! (T <: Number )
241
- @assert drop_last (name_pair[1 ]) == drop_last (name_pair[2 ]) == @name (components. data)
242
- row_index = extract_last (name_pair[1 ])
243
- col_index = extract_last (name_pair[2 ])
244
- (n_rows, n_cols) = map (length, axes (T))
245
- @assert row_index <= n_rows && col_index <= n_cols
246
- flattened_index = n_rows * (col_index - 1 ) + row_index
247
286
band_element_size = div (sizeof (T), sizeof (eltype (T)))
248
287
T_band = eltype (entry)
249
- singleton_datalayout =
250
- DataLayouts. singleton (Fields. field_values (entry))
288
+ singleton_datalayout = DataLayouts. singleton (Fields. field_values (entry))
251
289
# BandMatrixRow with same lowest diagonal and bandwidth as `entry`, with a scalar eltype
252
- scalar_band_type = band_matrix_row_type (
253
- outer_diagonals (T_band)... ,
254
- eltype (T),
255
- )
290
+ scalar_band_type =
291
+ band_matrix_row_type (outer_diagonals (T_band)... , eltype (eltype (T)))
256
292
field_dim_size = DataLayouts. ncomponents (Fields. field_values (entry))
257
293
band_element_size = div (sizeof (T), sizeof (eltype (T)))
258
- @assert band_element_size == n_rows * n_cols
259
294
parent_indices = DataLayouts. to_data_specific_field (
260
- singleton_datalayout,
261
- (
262
- :,
263
- :,
264
- flattened_index: band_element_size: field_dim_size,
265
- :,
266
- :,
267
- ),
268
- )
269
- # Main.@infiltrate
295
+ singleton_datalayout,
296
+ (:, :, flattened_index: band_element_size: field_dim_size, :, :),
297
+ )
298
+
270
299
scalar_data = view (parent (entry), parent_indices... )
271
- values = DataLayouts. union_all (singleton_datalayout){
272
- scalar_band_type,
273
- Base. tail (
274
- DataLayouts. type_params (Fields. field_values (entry)),
275
- )... ,
276
- }(
277
- scalar_data,
278
- )
279
- Fields. Field (values, axes (entry))
300
+
301
+ values = DataLayouts. union_all (singleton_datalayout){
302
+ scalar_band_type,
303
+ Base. tail (DataLayouts. type_params (Fields. field_values (entry)))... ,
304
+ }(
305
+ scalar_data,
306
+ )
307
+ Fields. Field (values, axes (entry))
280
308
else
281
309
throw (key_error)
282
310
end
283
311
end
284
312
313
+ function get_scalar_keys (dict:: FieldMatrix )
314
+ keys_tuple = unrolled_flatmap (keys (dict). values) do key
315
+ entry = dict[unrolled_filter (isequal (key), keys (dict). values)[1 ]]
316
+ entry =
317
+ entry isa ColumnwiseBandMatrixField ? entry. entries.:(1 ) : entry
318
+ unrolled_map (filtered_names (entry) do field
319
+ if field isa UniformScaling
320
+ true
321
+ elseif field isa Fields. Field
322
+ eltype (field) == eltype (eltype (field))
323
+ else
324
+ eltype (field) == typeof (field)
325
+ end
326
+ end ) do name
327
+ (append_internal_name (key[1 ], name), key[2 ])
328
+ end
329
+ end
330
+ return FieldMatrixKeys (keys_tuple)
331
+ end
332
+ # function combine_name_pair(name_pair::Tuple{FieldName, FieldName{()}}, T)
333
+ # end
334
+
335
+ # function combine_name_pair(name_pair::Tuple{FieldName{()}, FieldName}, ::Type{T}) where {T}
336
+ # T <: NamedTuple && error("Cannot return ")
337
+ # # @assert eltype()
338
+ # end
339
+
340
+ # function combine_name_pair(name_pair::FieldNamePair, T)
341
+ # end
342
+
343
+ # function foobarbaz(combined_name_chain, T, entry, target_field_eltype)
344
+ # band_element_size = div(sizeof(T), sizeof(eltype(T)))
345
+ # T_band = eltype(entry)
346
+ # singleton_datalayout =
347
+ # DataLayouts.singleton(Fields.field_values(entry))
348
+ # # BandMatrixRow with same lowest diagonal and bandwidth as `entry`, with a scalar eltype
349
+ # scalar_band_type = band_matrix_row_type(
350
+ # outer_diagonals(T_band)...,
351
+ # eltype(T),
352
+ # )
353
+ # field_dim_size = DataLayouts.ncomponents(Fields.field_values(entry))
354
+ # band_element_size = div(sizeof(T), sizeof(eltype(target_field_eltype)))
355
+ # first_index = get_field_first_index_offset(
356
+ # combined_name_chain,
357
+ # target_field_eltype,
358
+ # T,
359
+ # )
360
+ # parent_indices = DataLayouts.to_data_specific_field(
361
+ # singleton_datalayout,
362
+ # (
363
+ # :,
364
+ # :,
365
+ # first_index:band_element_size:field_dim_size,
366
+ # :,
367
+ # :,
368
+ # ),
369
+ # )
370
+ # target_data = view(parent(entry), parent_indices...)
371
+ # values = DataLayouts.union_all(singleton_datalayout){
372
+ # target_data,
373
+ # Base.tail(
374
+ # DataLayouts.type_params(Fields.field_values(entry)),
375
+ # )...,
376
+ # }(
377
+ # scalar_data,
378
+ # )
379
+ # Fields.Field(values, axes(entry))
380
+ # end
381
+
285
382
# Similar behavior to indexing an array with a slice.
286
383
function Base. getindex (dict:: FieldNameDict , new_keys:: FieldNameSet )
287
384
common_keys = intersect (keys (dict), new_keys)
@@ -368,10 +465,15 @@ function get_scalar_keys(dict::FieldMatrix, Y::Fields.FieldVector)
368
465
entry = dict[unrolled_filter (isequal (key), keys (dict). values)[1 ]]
369
466
if entry isa UniformScaling # uniformscalings can only contain numbers
370
467
(key,)
371
- elseif entry isa ColumnwiseBandMatrixField
468
+ elseif entry isa ColumnwiseBandMatrixField ||
469
+ entry isa DiagonalMatrixRow
372
470
first_band = entry. entries.:(1 )
373
471
target_eltype = eltype (parent (first_band))
374
- if eltype (first_band) == target_eltype
472
+ if entry isa ColumnwiseBandMatrixField &&
473
+ eltype (first_band) <: target_eltype
474
+ (key,)
475
+ elseif entry isa DiagonalMatrixRow &&
476
+ typeof (first_band) <: target_eltype
375
477
(key,)
376
478
else
377
479
dependent_var = get_field (Y, key[1 ])
@@ -380,52 +482,55 @@ function get_scalar_keys(dict::FieldMatrix, Y::Fields.FieldVector)
380
482
independent_type = eltype (independent_var)
381
483
# @Main.infiltrate
382
484
@assert dependent_type <: Geometry.SingleValue ||
383
- independent_type <: Geometry.SingleValue ||
384
- " cannot get scalar keys for key $key "
485
+ independent_type <: Geometry.SingleValue ||
486
+ " cannot get scalar keys for key $key "
385
487
386
488
# figure out if we need to drill into key[1] or key[2], or both
387
489
# @show key
388
- unrolled_flatmap (filtered_names (x -> eltype (x) <: target_eltype , dependent_var)) do dependent_name
389
- unrolled_map (filtered_names (x -> eltype (x) <: target_eltype , independent_var)) do independent_name
390
- (append_internal_name (key[1 ], dependent_name), append_internal_name (key[2 ], independent_name))
490
+ unrolled_flatmap (
491
+ filtered_names (
492
+ x -> eltype (x) <: target_eltype ,
493
+ dependent_var,
494
+ ),
495
+ ) do dependent_name
496
+ unrolled_map (
497
+ filtered_names (
498
+ x -> eltype (x) <: target_eltype ,
499
+ independent_var,
500
+ ),
501
+ ) do independent_name
502
+ (
503
+ append_internal_name (key[1 ], dependent_name),
504
+ append_internal_name (key[2 ], independent_name),
505
+ )
391
506
end
392
507
# @Main.infiltrate
393
508
# key
394
509
end
395
510
# (key,)
396
511
end
512
+ # elseif entry isa DiagonalMatrixRow
513
+ # target_eltype = eltype(parent(get_field(Y, key[1])))
514
+ # # TODO : unify target_eltype
515
+ # (key,)
397
516
else
398
- # TODO : Fix me
399
- (key,)
517
+ error (" Cannot get scalar keys for key $key " )
400
518
end
401
519
402
- # entry =
403
- # entry isa ColumnwiseBandMatrixField ? entry.entries.:(1) : entry
404
- # unrolled_map(filtered_names(entry) do field
405
- # if field isa UniformScaling
406
- # true
407
- # elseif field isa Fields.Field
408
- # eltype(field) == eltype(eltype(field))
409
- # else
410
- # eltype(field) == typeof(field)
411
- # end
412
- # end) do name
413
- # (append_internal_name(key[1], name), key[2])
414
- # end
415
520
end
416
521
return FieldMatrixKeys (keys_tuple)
417
522
end
418
523
419
- function new_get_scalar_keys (dict:: FieldMatrix , Y:: Fields.FieldVector )
420
- scalar_field_vector_keys = MatrixFields. filtered_names (Y) do field
421
- field isa Fields. Field && eltype (field) == eltype (parent (field))
422
- end
423
- map (keys (dict). values) do key
424
- first_key_is_scalar = unrolled_any (isequal (key[1 ]), scalar_field_vector_keys)
425
- second_key_is_scalar = unrolled_any (isequal (key[2 ]), scalar_field_vector_keys)
426
- @assert first_key_is_scalar || second_key_is_scalar " $key "
427
- end
428
- end
524
+ # function new_get_scalar_keys(dict::FieldMatrix, Y::Fields.FieldVector)
525
+ # scalar_field_vector_keys = MatrixFields.filtered_names(Y) do field
526
+ # field isa Fields.Field && eltype(field) == eltype(parent(field))
527
+ # end
528
+ # map(keys(dict).values) do key
529
+ # first_key_is_scalar = unrolled_any(isequal(key[1]), scalar_field_vector_keys)
530
+ # second_key_is_scalar = unrolled_any(isequal(key[2]), scalar_field_vector_keys)
531
+ # @assert first_key_is_scalar || second_key_is_scalar "$key"
532
+ # end
533
+ # end
429
534
430
535
"""
431
536
scalar_fieldmatrix(field_matrix::FieldMatrix)
@@ -467,6 +572,14 @@ function scalar_fieldmatrix(field_matrix::FieldMatrix)
467
572
return FieldNameDict (scalar_keys, entries)
468
573
end
469
574
575
+ function scalar_fieldmatrix (field_matrix:: FieldMatrix , Y:: Fields.FieldVector )
576
+ scalar_keys = get_scalar_keys (field_matrix, Y)
577
+ entries = unrolled_map (scalar_keys. values) do key
578
+ field_matrix[key]
579
+ end
580
+ return FieldNameDict (scalar_keys, entries)
581
+ end
582
+
470
583
replace_name_tree (dict:: FieldNameDict , name_tree) =
471
584
FieldNameDict (replace_name_tree (keys (dict), name_tree), values (dict))
472
585
@@ -776,8 +889,8 @@ function Base.Broadcast.broadcasted(
776
889
)
777
890
product_value = scaling_value (entry1) * scaling_value (entry2)
778
891
product_value isa Number ?
779
- UniformScaling (product_value) :
780
- DiagonalMatrixRow (product_value)
892
+ ( UniformScaling (product_value), ) :
893
+ ( DiagonalMatrixRow (product_value), )
781
894
elseif entry1 isa ScalingFieldMatrixEntry
782
895
Base. Broadcast. broadcasted (* , (scaling_value (entry1),), entry2)
783
896
elseif entry2 isa ScalingFieldMatrixEntry
0 commit comments