@@ -163,31 +163,52 @@ function get_internal_entry(
163
163
elseif T <: Geometry.Axis2Tensor &&
164
164
all (n -> is_child_name (n, @name (components. data)), name_pair)
165
165
# two indices needed to index into a 2d tensor (one can be Colon())
166
- row_index = extract_first (
167
- extract_internal_name (name_pair[1 ], @name (components. data)),
166
+ internal_row_name =
167
+ extract_internal_name (name_pair[1 ], @name (components. data))
168
+ internal_col_name =
169
+ extract_internal_name (name_pair[2 ], @name (components. data))
170
+ row_index = extract_first (internal_row_name)
171
+ col_index = extract_first (internal_col_name)
172
+ return get_internal_entry (
173
+ DiagonalMatrixRow (scaling_value (entry)[row_index, col_index]),
174
+ (drop_first (internal_row_name), drop_first (internal_col_name)),
175
+ key_error,
168
176
)
169
- col_index = extract_first (
170
- extract_internal_name (name_pair[2 ], @name (components. data)),
177
+ elseif T <: Geometry.AdjointAxisVector
178
+ return get_internal_entry (
179
+ DiagonalMatrixRow (getfield (scaling_value (entry), :parent )),
180
+ name_pair,
181
+ key_error,
171
182
)
172
- return DiagonalMatrixRow (scaling_value (entry)[row_index, col_index])
173
183
else
174
- combined_chain = append_internal_name (name_pair[1 ], name_pair[2 ])
175
- modified_chain =
176
- T <: Geometry.AdjointAxisVector && # indexing adjoint of axis vector
177
- extract_first (combined_chain) != :parent ? # is the same as its parent
178
- extract_internal_name (combined_chain, @name (components. data)) :
179
- combined_chain
180
- if ! broadcasted_has_field (T, modified_chain) # implicit tensor structure case
181
- T <: Geometry.SingleValue || throw (key_error) # optimization only works with scalars
182
- hasfield (T, extract_first (combined_chain)) && throw (key_error)
183
- name_pair[1 ] == name_pair[2 ] && return entry # multiplication case 3, 4 first arg
184
- is_overlapping_name (name_pair[1 ], name_pair[2 ]) && throw (key_error)
185
- return zero (entry) # off diagonal
186
- else
187
- return DiagonalMatrixRow (
188
- broadcasted_get_field (scaling_value (entry), modified_chain),
189
- ) # multiplication case 2, 4 second arg
190
- end
184
+ child_name, remaining_chain =
185
+ if name_pair[1 ] != @name () &&
186
+ extract_first (name_pair[1 ]) in fieldnames (T)
187
+ @inline (
188
+ extract_first (name_pair[1 ]),
189
+ (drop_first (name_pair[1 ]), name_pair[2 ]),
190
+ )
191
+ elseif name_pair[2 ] != @name () &&
192
+ extract_first (name_pair[2 ]) in fieldnames (T)
193
+ @inline (
194
+ extract_first (name_pair[2 ]),
195
+ (name_pair[1 ], drop_first (name_pair[2 ])),
196
+ )
197
+ elseif ! any (isequal (@name ()), name_pair) # implicit tensor structure
198
+ return get_internal_entry (
199
+ extract_first (name_pair[1 ]) == extract_first (name_pair[2 ]) ?
200
+ entry : zero (entry),
201
+ (drop_first (name_pair[1 ]), drop_first (name_pair[2 ])),
202
+ key_error,
203
+ )
204
+ else
205
+ throw (key_error)
206
+ end
207
+ return get_internal_entry (
208
+ DiagonalMatrixRow (getfield (scaling_value (entry), child_name)),
209
+ remaining_chain,
210
+ key_error,
211
+ )
191
212
end
192
213
end
193
214
function get_internal_entry (
@@ -198,27 +219,9 @@ function get_internal_entry(
198
219
name_pair == (@name (), @name ()) && return entry
199
220
S = eltype (eltype (entry))
200
221
T = eltype (parent (entry))
201
- first_name = extract_first (append_internal_name (name_pair... ))
202
- if ! hasfield (S, first_name) &&
203
- ! (hasfield (S, :parent ) && hasfield (fieldtype (S, :parent ), first_name)) &&
204
- S <: Geometry.SingleValue
205
- if name_pair[1 ] == name_pair[2 ]
206
- return entry # multiplication case 3 or 4, first arg
207
- elseif is_overlapping_name (name_pair[1 ], name_pair[2 ])
208
- throw (key_error)
209
- else
210
- throw (key_error)
211
- return Base. broadcasted (entry) do matrix_row # off diagonal
212
- map (matrix_row) do matrix_row_entry
213
- zero (S)
214
- end
215
- end
216
- end
217
- end
218
- # multiplication case 2 or 4, second arg
219
- (index_offset, target_type) =
222
+ (start_offset, target_type, apply_zero) =
220
223
field_offset_and_type (name_pair, T, S, key_error)
221
- if target_type <: eltype (parent (entry))
224
+ if target_type <: eltype (parent (entry)) && ! apply_zero
222
225
band_element_size =
223
226
DataLayouts. typesize (eltype (parent (entry)), eltype (eltype (entry)))
224
227
singleton_datalayout = DataLayouts. singleton (Fields. field_values (entry))
@@ -227,7 +230,7 @@ function get_internal_entry(
227
230
field_dim_size = DataLayouts. ncomponents (Fields. field_values (entry))
228
231
parent_indices = DataLayouts. to_data_specific_field (
229
232
singleton_datalayout,
230
- (:, :, (index_offset + 1 ): band_element_size: field_dim_size, :, :),
233
+ (:, :, (start_offset + 1 ): band_element_size: field_dim_size, :, :),
231
234
)
232
235
scalar_data = view (parent (entry), parent_indices... )
233
236
values = DataLayouts. union_all (singleton_datalayout){
@@ -237,6 +240,12 @@ function get_internal_entry(
237
240
scalar_data,
238
241
)
239
242
return Fields. Field (values, axes (entry))
243
+ elseif apply_zero
244
+ return Base. broadcasted (entry) do matrix_row
245
+ map (matrix_row) do matrix_row_entry
246
+ zero (target_type)
247
+ end
248
+ end
240
249
else
241
250
return Base. broadcasted (entry) do matrix_row
242
251
map (matrix_row) do matrix_row_entry
@@ -248,6 +257,12 @@ function get_internal_entry(
248
257
end
249
258
end
250
259
end
260
+ if hasfield (Method, :recursion_relation )
261
+ dont_limit = (args... ) -> true
262
+ for m in methods (get_internal_entry)
263
+ m. recursion_relation = dont_limit
264
+ end
265
+ end
251
266
252
267
# Similar behavior to indexing an array with a slice.
253
268
function Base. getindex (dict:: FieldNameDict , new_keys:: FieldNameSet )
@@ -313,40 +328,68 @@ function field_offset_and_type(
313
328
:: Type{S} ,
314
329
key_error,
315
330
) where {S, T}
316
- name_pair == (@name (), @name ()) && return (0 , S) # base case
317
- if S <: Geometry.Axis2Tensor{T} # special case to calculate index
331
+ name_pair == (@name (), @name ()) && return (0 , S, false ) # base case
332
+ if S <: Geometry.Axis2Tensor # special case to calculate index
318
333
(name_pair[1 ] == @name () || name_pair[2 ] == @name ()) && throw (key_error)
319
- row_index = extract_first (
320
- extract_internal_name (name_pair[1 ], @name (components. data)),
321
- )
322
- col_index = extract_first (
323
- extract_internal_name (name_pair[ 2 ], @name (components . data)),
324
- )
334
+ internal_row_name =
335
+ extract_internal_name (name_pair[1 ], @name (components. data))
336
+ internal_col_name =
337
+ extract_internal_name (name_pair[ 2 ], @name (components . data))
338
+ row_index = extract_first (internal_row_name)
339
+ col_index = extract_first (internal_col_name )
325
340
((row_index isa Number) && (col_index isa Number)) || throw (key_error) # slicing not supported
326
341
(n_rows, n_cols) = map (length, axes (S))
342
+ (remaining_offset, end_type, apply_zero) = field_offset_and_type (
343
+ (drop_first (internal_row_name), drop_first (internal_col_name)),
344
+ T,
345
+ eltype (S),
346
+ key_error,
347
+ )
327
348
(row_index <= n_rows && col_index <= n_cols) || throw (key_error)
328
- return (n_rows * (col_index - 1 ) + row_index - 1 , T)
349
+ return (
350
+ (n_rows * (col_index - 1 ) + row_index - 1 ) + remaining_offset,
351
+ end_type,
352
+ apply_zero,
353
+ )
354
+ elseif S <: Geometry.AdjointAxisVector
355
+ return field_offset_and_type (name_pair, T, fieldtype (S, 1 ), key_error)
329
356
else
330
- child_name, remaining_field_chain = if name_pair[1 ] != @name ()
331
- extract_first (name_pair[1 ]), (drop_first (name_pair[1 ]), name_pair[2 ])
332
- else
333
- extract_first (name_pair[2 ]), (@name (), drop_first (name_pair[2 ]))
334
- end
335
- # indexing adjoint of axis vector is the same as indexing its parent
336
- (S <: Geometry.AdjointAxisVector && child_name != :parent ) &&
337
- return field_offset_and_type (
338
- name_pair,
339
- T,
340
- fieldtype (S, 1 ),
341
- key_error,
342
- )
343
- (child_name in fieldnames (S)) || throw (key_error)
357
+ child_name, remaining_field_chain =
358
+ if name_pair[1 ] != @name () &&
359
+ extract_first (name_pair[1 ]) in fieldnames (S)
360
+ @inline (
361
+ extract_first (name_pair[1 ]),
362
+ (drop_first (name_pair[1 ]), name_pair[2 ]),
363
+ )
364
+ elseif name_pair[2 ] != @name () &&
365
+ extract_first (name_pair[2 ]) in fieldnames (S)
366
+ @inline (
367
+ extract_first (name_pair[2 ]),
368
+ (name_pair[1 ], drop_first (name_pair[2 ])),
369
+ )
370
+ elseif ! any (isequal (@name ()), name_pair) # implicit tensor structure
371
+ (remaining_offset, end_type, apply_zero) =
372
+ field_offset_and_type (
373
+ (drop_first (name_pair[1 ]), drop_first (name_pair[2 ])),
374
+ T,
375
+ fieldtype (S, 1 ),
376
+ key_error,
377
+ )
378
+ return (
379
+ remaining_offset,
380
+ end_type,
381
+ extract_first (name_pair[1 ]) == extract_first (name_pair[2 ]) ?
382
+ apply_zero : true ,
383
+ )
384
+ else
385
+ throw (key_error)
386
+ end
344
387
child_type = fieldtype (S, child_name)
345
388
field_index = unrolled_filter (
346
389
i -> fieldname (S, i) == child_name,
347
390
1 : fieldcount (S),
348
391
)[1 ]
349
- (remaining_offset, end_type) = field_offset_and_type (
392
+ (remaining_offset, end_type, apply_zero ) = field_offset_and_type (
350
393
remaining_field_chain,
351
394
T,
352
395
child_type,
@@ -355,7 +398,9 @@ function field_offset_and_type(
355
398
return (
356
399
DataLayouts. fieldtypeoffset (T, S, field_index) + remaining_offset,
357
400
end_type,
401
+ apply_zero,
358
402
)
403
+
359
404
end
360
405
end
361
406
if hasfield (Method, :recursion_relation )
@@ -366,16 +411,16 @@ if hasfield(Method, :recursion_relation)
366
411
end
367
412
368
413
"""
369
- get_scalar_keys(dict::FieldMatrix, FT )
414
+ get_scalar_keys(dict::FieldMatrix)
370
415
371
416
Returns a `FieldMatrixKeys` object that contains the keys that result in
372
- a `ScalingFieldMatrixEntry{FT }` or a `ColumnwiseBandMatrixField` with bands of eltype `FT `
417
+ a `ScalingFieldMatrixEntry{<:Number }` or a `ColumnwiseBandMatrixField` with bands of eltype `< :Number `
373
418
when indexing `dict`.
374
419
"""
375
- function get_scalar_keys (dict:: FieldMatrix , :: Type{FT} ) where {FT}
420
+ function get_scalar_keys (dict:: FieldMatrix )
376
421
keys_tuple = unrolled_flatmap (keys (dict). values) do outer_key
377
- unrolled_map (get_scalar_keys (eltype (dict[outer_key]), FT )) do inner_key
378
- (
422
+ @inline unrolled_map (get_scalar_keys (eltype (dict[outer_key]))) do inner_key
423
+ @inline (
379
424
append_internal_name (outer_key[1 ], inner_key[1 ]),
380
425
append_internal_name (outer_key[2 ], inner_key[2 ]),
381
426
)
@@ -385,16 +430,16 @@ function get_scalar_keys(dict::FieldMatrix, ::Type{FT}) where {FT}
385
430
end
386
431
387
432
"""
388
- get_scalar_keys(T::Type, FT::Type )
433
+ get_scalar_keys(T::Type)
389
434
390
435
Returns a tuple of `FieldNamePair` objects that correspond to any children
391
436
of `T` that are of type `FT`.
392
437
"""
393
- function get_scalar_keys (:: Type{T} , :: Type{FT} ) where {T, FT }
394
- if T <: FT || T <: Bool # identity has eltype Bool
438
+ function get_scalar_keys (:: Type{T} ) where {T}
439
+ if T <: Number # TODO : is this tight enough of a Type? what about complex and duals and plushalfs
395
440
return ((@name (), @name ()),)
396
441
elseif T <: BandMatrixRow
397
- return get_scalar_keys (eltype (T), FT )
442
+ return get_scalar_keys (eltype (T))
398
443
elseif T <: Geometry.Axis2Tensor
399
444
return unrolled_flatmap (1 : length (axes (T)[1 ])) do row_component
400
445
unrolled_map (1 : length (axes (T)[2 ])) do col_component
@@ -405,28 +450,35 @@ function get_scalar_keys(::Type{T}, ::Type{FT}) where {T, FT}
405
450
end
406
451
end
407
452
elseif T <: Geometry.AdjointAxisVector
408
- return unrolled_map (
409
- get_scalar_keys (fieldtype (T, :parent ), FT),
410
- ) do inner_key
453
+ return unrolled_map (get_scalar_keys (fieldtype (T, :parent ))) do inner_key
411
454
(inner_key[2 ], inner_key[1 ]) # assumes that adjoints only appear with d/dvec
412
455
end
413
456
elseif T <: Geometry.AxisVector # special case to avoid recursing into the axis field
414
457
# TODO : this should be able to be combined with the else case, but it causes runtime dispatch
415
458
return unrolled_map (
416
- get_scalar_keys (fieldtype (T, :components ), FT ),
459
+ get_scalar_keys (fieldtype (T, :components )),
417
460
) do inner_key
418
461
(
419
462
append_internal_name (@name (components), inner_key[1 ]),
420
463
inner_key[2 ],
421
464
)
422
465
end
466
+ # return unrolled_flatmap((:components,)) do inner_field
467
+ # @inline unrolled_map(
468
+ # # get_scalar_keys(fieldtype(T, inner_field)),
469
+ # get_scalar_keys(fieldtype(T, :components)),
470
+ # ) do inner_key
471
+ # @inline (
472
+ # append_internal_name(FieldName(inner_field), inner_key[1]),
473
+ # inner_key[2],
474
+ # )
475
+ # end
476
+ # end
423
477
else
424
- return unrolled_flatmap (fieldnames (T)) do inner_field
425
- unrolled_map (
426
- get_scalar_keys (fieldtype (T, inner_field), FT),
427
- ) do inner_key
478
+ return unrolled_flatmap (fieldnames (T)) do inner_name
479
+ unrolled_map (get_scalar_keys (fieldtype (T, inner_name))) do inner_key
428
480
(
429
- append_internal_name (FieldName (inner_field ), inner_key[1 ]),
481
+ append_internal_name (FieldName (inner_name ), inner_key[1 ]),
430
482
inner_key[2 ],
431
483
)
432
484
end
442
494
443
495
444
496
"""
445
- scalar_fieldmatrix(field_matrix::FieldMatrix, FT )
497
+ scalar_fieldmatrix(field_matrix::FieldMatrix)
446
498
447
499
Constructs a `FieldNameDict` where the keys and entries are views
448
500
of the entries of `field_matrix`, which corresponding to the
@@ -464,7 +516,7 @@ A = MatrixFields.FieldMatrix(
464
516
(@name(c.uₕ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_uₕ_scalar,
465
517
)
466
518
467
- A_scalar = MatrixFields.scalar_fieldmatrix(A, Float64 )
519
+ A_scalar = MatrixFields.scalar_fieldmatrix(A)
468
520
keys(A_scalar)
469
521
# Output:
470
522
# (@name(c.ρχ.ρq_liq), @name(f.u₃.:(1)))
@@ -473,8 +525,8 @@ keys(A_scalar)
473
525
# (@name(c.uₕ.:(2)), @name(c.sgsʲs.:(1).ρa))
474
526
```
475
527
"""
476
- function scalar_fieldmatrix (field_matrix:: FieldMatrix , :: Type{FT} ) where {FT}
477
- scalar_keys = get_scalar_keys (field_matrix, FT )
528
+ function scalar_fieldmatrix (field_matrix:: FieldMatrix )
529
+ scalar_keys = get_scalar_keys (field_matrix)
478
530
entries = unrolled_map (scalar_keys. values) do key
479
531
field_matrix[key]
480
532
end
0 commit comments