@@ -152,62 +152,42 @@ get_internal_key(child_name_pair::FieldNamePair, name_pair::FieldNamePair) = (
152
152
get_internal_entry (entry, name:: FieldName , key_error) = get_field (entry, name)
153
153
get_internal_entry (entry, name_pair:: FieldNamePair , key_error) =
154
154
name_pair == (@name (), @name ()) ? entry : throw (key_error)
155
- get_internal_entry (entry:: UniformScaling , name_pair:: FieldNamePair , key_error) =
156
- if name_pair[1 ] == name_pair[2 ]
157
- entry
158
- elseif is_overlapping_name (name_pair[1 ], name_pair[2 ])
159
- throw (key_error)
160
- else
161
- zero (entry)
162
- end
163
155
function get_internal_entry (
164
- entry:: DiagonalMatrixRow{T} ,
156
+ entry:: ScalingFieldMatrixEntry ,
165
157
name_pair:: FieldNamePair ,
166
158
key_error,
167
- ) where {T}
159
+ )
160
+ T = eltype (entry)
168
161
if name_pair == (@name (), @name ())
169
162
return entry
170
- elseif T <: Geometry.Axis2Tensor
171
- (name_pair[1 ] == @name () || name_pair[2 ] == @name ()) && throw (key_error) # Cannot slice a 2D tensor
172
- row_index = extract_first (name_pair[1 ])
173
- col_index = extract_first (name_pair[2 ])
174
- if (! (row_index isa Number) || ! (col_index isa Number))
175
- name_pair[1 ] == name_pair[2 ] && return entry # multiplication case 3 or 4, first argument
176
- is_overlapping_name (name_pair[1 ], name_pair[2 ]) ||
177
- return zero (entry)
178
- throw (key_error)
179
- end
180
- (n_rows, n_cols) = map (length, axes (scaling_value (entry)))
181
- @assert row_index <= n_rows && col_index <= n_cols
182
- return DiagonalMatrixRow (
183
- broadcasted_get_field (
184
- scaling_value (entry). components. data,
185
- FieldName (n_rows * (col_index - 1 ) + row_index),
186
- ),
163
+ elseif T <: Geometry.Axis2Tensor &&
164
+ all (n -> is_child_name (n, @name (components. data)), name_pair)
165
+ # two indices needed to index into a 2d tensor (one can be Colon())
166
+ row_index = extract_first (
167
+ extract_internal_name (name_pair[1 ], @name (components. data)),
187
168
)
188
- elseif T <: Geometry.AxisVectorOrAdj
189
- return get_internal_entry (
190
- DiagonalMatrixRow (
191
- getfield (
192
- getfield (
193
- T <: Geometry.AxisVector ? scaling_value (entry) :
194
- scaling_value (entry). parent,
195
- :components ,
196
- ),
197
- :data ,
198
- ),
199
- ),
200
- name_pair,
201
- key_error,
169
+ col_index = extract_first (
170
+ extract_internal_name (name_pair[2 ], @name (components. data)),
202
171
)
203
- elseif name_pair[1 ] == name_pair[2 ] &&
204
- ! broadcasted_has_field (eltype (entry), name_pair[1 ])
205
- return entry
172
+ return DiagonalMatrixRow (scaling_value (entry)[row_index, col_index])
206
173
else
207
- non_empty_chain = name_pair[1 ] != @name () ? name_pair[1 ] : name_pair[2 ]
208
- return DiagonalMatrixRow (
209
- broadcasted_get_field (scaling_value (entry), non_empty_chain),
210
- )
174
+ combined_chain = append_internal_name (name_pair[1 ], name_pair[2 ])
175
+ modified_chain =
176
+ T <: Geometry.AdjointAxisVector && # indexing adjoint of axis vector
177
+ extract_first (combined_chain) != :parent ? # is the same as its parent
178
+ extract_internal_name (combined_chain, @name (components. data)) :
179
+ combined_chain
180
+ if ! broadcasted_has_field (T, modified_chain) # implicit tensor structure case
181
+ T <: Geometry.SingleValue || throw (key_error) # optimization only works with scalars
182
+ hasfield (T, extract_first (combined_chain)) && throw (key_error)
183
+ name_pair[1 ] == name_pair[2 ] && return entry # multiplication case 3, 4 first arg
184
+ is_overlapping_name (name_pair[1 ], name_pair[2 ]) && throw (key_error)
185
+ return zero (entry) # off diagonal
186
+ else
187
+ return DiagonalMatrixRow (
188
+ broadcasted_get_field (scaling_value (entry), modified_chain),
189
+ ) # multiplication case 2, 4 second arg
190
+ end
211
191
end
212
192
end
213
193
function get_internal_entry (
@@ -216,13 +196,28 @@ function get_internal_entry(
216
196
key_error,
217
197
)
218
198
name_pair == (@name (), @name ()) && return entry
219
- (index_offset, target_type) = field_offset_and_type (
220
- name_pair,
221
- eltype (parent (entry)),
222
- eltype (eltype (entry)),
223
- key_error,
224
- )
225
- target_type <: eltype (eltype (entry)) && return entry # multiplication case 3 or 4, first argument
199
+ S = eltype (eltype (entry))
200
+ T = eltype (parent (entry))
201
+ first_name = extract_first (append_internal_name (name_pair... ))
202
+ if ! hasfield (S, first_name) &&
203
+ ! (hasfield (S, :parent ) && hasfield (fieldtype (S, :parent ), first_name)) &&
204
+ S <: Geometry.SingleValue
205
+ if name_pair[1 ] == name_pair[2 ]
206
+ return entry # multiplication case 3 or 4, first arg
207
+ elseif is_overlapping_name (name_pair[1 ], name_pair[2 ])
208
+ throw (key_error)
209
+ else
210
+ throw (key_error)
211
+ return Base. broadcasted (entry) do matrix_row # off diagonal
212
+ map (matrix_row) do matrix_row_entry
213
+ zero (S)
214
+ end
215
+ end
216
+ end
217
+ end
218
+ # multiplication case 2 or 4, second arg
219
+ (index_offset, target_type) =
220
+ field_offset_and_type (name_pair, T, S, key_error)
226
221
if target_type <: eltype (parent (entry))
227
222
band_element_size =
228
223
DataLayouts. typesize (eltype (parent (entry)), eltype (eltype (entry)))
@@ -254,7 +249,6 @@ function get_internal_entry(
254
249
end
255
250
end
256
251
257
-
258
252
# Similar behavior to indexing an array with a slice.
259
253
function Base. getindex (dict:: FieldNameDict , new_keys:: FieldNameSet )
260
254
common_keys = intersect (keys (dict), new_keys)
@@ -320,76 +314,34 @@ function field_offset_and_type(
320
314
key_error,
321
315
) where {S, T}
322
316
name_pair == (@name (), @name ()) && return (0 , S) # base case
323
- if S <: Geometry.Axis2Tensor{T}
324
- (name_pair[1 ] == @name () || name_pair[2 ] == @name ()) && throw (key_error) # Cannot slice a 2D tensor
325
- row_index = extract_first (name_pair[1 ])
326
- col_index = extract_first (name_pair[2 ])
327
- if (! (row_index isa Number) || ! (col_index isa Number))
328
- name_pair[1 ] == name_pair[2 ] && return (0 , S) # multiplication case 3 or 4, first argument
329
- throw (key_error)
330
- end
331
- (n_rows, n_cols) = map (length, axes (S))
332
- @assert row_index <= n_rows && col_index <= n_cols
333
- return (n_rows * (col_index - 1 ) + row_index - 1 , T)
334
- elseif (
335
- name_pair[1 ] == name_pair[2 ] && ! broadcasted_has_field (S, name_pair[1 ])
336
- )
337
- return (0 , S)
338
- elseif name_pair[1 ] == @name ()
339
- return field_offset_and_type (name_pair[2 ], T, S, key_error)
340
- elseif name_pair[2 ] == @name ()
341
- return field_offset_and_type (name_pair[1 ], T, S, key_error)
342
- else
343
- child_name = extract_first (name_pair[1 ])
344
- (child_name in fieldnames (S)) || throw (key_error)
345
- child_type = fieldtype (S, child_name)
346
- remaining_field_chain = (drop_first (name_pair[1 ]), name_pair[2 ])
347
- field_index = unrolled_filter (
348
- i -> fieldname (S, i) == child_name,
349
- 1 : fieldcount (S),
350
- )[1 ]
351
- (remaining_offset, end_type) = field_offset_and_type (
352
- remaining_field_chain,
353
- T,
354
- child_type,
355
- key_error,
317
+ if S <: Geometry.Axis2Tensor{T} # special case to calculate index
318
+ (name_pair[1 ] == @name () || name_pair[2 ] == @name ()) && throw (key_error)
319
+ row_index = extract_first (
320
+ extract_internal_name (name_pair[1 ], @name (components. data)),
356
321
)
357
- return (
358
- DataLayouts. fieldtypeoffset (T, S, field_index) + remaining_offset,
359
- end_type,
360
- )
361
- end
362
- end
363
-
364
- """
365
- field_offset_and_type(name::FieldName, ::Type{T}, ::Type{S}, key_error)
366
-
367
- Returns the offset of the the field with name `name` in an object of type `S`
368
- in multiples of `sizeof(T)` and the type of the field with name `name` in an object of type `S`
369
- in multiples of `sizeof(T)`
370
- """
371
- function field_offset_and_type (
372
- name:: FieldName ,
373
- :: Type{T} ,
374
- :: Type{S} ,
375
- key_error,
376
- ) where {T, S}
377
- name == @name () && return (0 , S) # base case
378
- if S <: Geometry.AdjointAxisVector
379
- return field_offset_and_type (name, T, fieldtype (S, :parent ), key_error)
380
- elseif S <: Geometry.AxisVector
381
- (remaining_offset, end_type) = field_offset_and_type (
382
- name,
383
- T,
384
- fieldtype (fieldtype (S, :components ), :data ),
385
- key_error,
322
+ col_index = extract_first (
323
+ extract_internal_name (name_pair[2 ], @name (components. data)),
386
324
)
387
- return (remaining_offset, end_type)
325
+ ((row_index isa Number) && (col_index isa Number)) || throw (key_error) # slicing not supported
326
+ (n_rows, n_cols) = map (length, axes (S))
327
+ (row_index <= n_rows && col_index <= n_cols) || throw (key_error)
328
+ return (n_rows * (col_index - 1 ) + row_index - 1 , T)
388
329
else
389
- child_name = extract_first (name)
330
+ child_name, remaining_field_chain = if name_pair[1 ] != @name ()
331
+ extract_first (name_pair[1 ]), (drop_first (name_pair[1 ]), name_pair[2 ])
332
+ else
333
+ extract_first (name_pair[2 ]), (@name (), drop_first (name_pair[2 ]))
334
+ end
335
+ # indexing adjoint of axis vector is the same as indexing its parent
336
+ (S <: Geometry.AdjointAxisVector && child_name != :parent ) &&
337
+ return field_offset_and_type (
338
+ name_pair,
339
+ T,
340
+ fieldtype (S, 1 ),
341
+ key_error,
342
+ )
390
343
(child_name in fieldnames (S)) || throw (key_error)
391
344
child_type = fieldtype (S, child_name)
392
- remaining_field_chain = drop_first (name)
393
345
field_index = unrolled_filter (
394
346
i -> fieldname (S, i) == child_name,
395
347
1 : fieldcount (S),
@@ -422,7 +374,6 @@ when indexing `dict`.
422
374
"""
423
375
function get_scalar_keys (dict:: FieldMatrix , :: Type{FT} ) where {FT}
424
376
keys_tuple = unrolled_flatmap (keys (dict). values) do outer_key
425
- # target_eltype = eltype(Y)
426
377
unrolled_map (get_scalar_keys (eltype (dict[outer_key]), FT)) do inner_key
427
378
(
428
379
append_internal_name (outer_key[1 ], inner_key[1 ]),
@@ -444,24 +395,33 @@ function get_scalar_keys(::Type{T}, ::Type{FT}) where {T, FT}
444
395
return ((@name (), @name ()),)
445
396
elseif T <: BandMatrixRow
446
397
return get_scalar_keys (eltype (T), FT)
447
- elseif T <: Geometry.AxisVector
448
- return unrolled_map (1 : length (axes (T)[1 ])) do component
449
- (FieldName (component), @name ())
398
+ elseif T <: Geometry.Axis2Tensor
399
+ return unrolled_flatmap (1 : length (axes (T)[1 ])) do row_component
400
+ unrolled_map (1 : length (axes (T)[2 ])) do col_component
401
+ append_internal_name .(
402
+ Ref (@name (components. data)),
403
+ (FieldName (row_component), FieldName (col_component)),
404
+ )
405
+ end
450
406
end
451
407
elseif T <: Geometry.AdjointAxisVector
452
408
return unrolled_map (
453
- 1 : length ( axes ( fieldtype (T, :parent ))[ 1 ] ),
454
- ) do component
455
- (@name (), FieldName (component))
409
+ get_scalar_keys ( fieldtype (T, :parent ), FT ),
410
+ ) do inner_key
411
+ (inner_key[ 2 ], inner_key[ 1 ]) # assumes that adjoints only appear with d/dvec
456
412
end
457
- elseif T <: Geometry.Axis2Tensor
458
- unrolled_flatmap (1 : length (axes (T)[1 ])) do row_component
459
- unrolled_map (1 : length (axes (T)[2 ])) do col_component
460
- (FieldName (row_component), FieldName (col_component))
461
- end
413
+ elseif T <: Geometry.AxisVector # special case to avoid recursing into the axis field
414
+ # TODO : this should be able to be combined with the else case, but it causes runtime dispatch
415
+ return unrolled_map (
416
+ get_scalar_keys (fieldtype (T, :components ), FT),
417
+ ) do inner_key
418
+ (
419
+ append_internal_name (@name (components), inner_key[1 ]),
420
+ inner_key[2 ],
421
+ )
462
422
end
463
423
else
464
- unrolled_flatmap (fieldnames (T)) do inner_field
424
+ return unrolled_flatmap (fieldnames (T)) do inner_field
465
425
unrolled_map (
466
426
get_scalar_keys (fieldtype (T, inner_field), FT),
467
427
) do inner_key
0 commit comments