Skip to content

Commit e49e341

Browse files
committed
WIPP1
1 parent 8f92898 commit e49e341

File tree

2 files changed

+149
-101
lines changed

2 files changed

+149
-101
lines changed

src/MatrixFields/field_name_dict.jl

Lines changed: 141 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -163,31 +163,52 @@ function get_internal_entry(
163163
elseif T <: Geometry.Axis2Tensor &&
164164
all(n -> is_child_name(n, @name(components.data)), name_pair)
165165
# two indices needed to index into a 2d tensor (one can be Colon())
166-
row_index = extract_first(
167-
extract_internal_name(name_pair[1], @name(components.data)),
166+
internal_row_name =
167+
extract_internal_name(name_pair[1], @name(components.data))
168+
internal_col_name =
169+
extract_internal_name(name_pair[2], @name(components.data))
170+
row_index = extract_first(internal_row_name)
171+
col_index = extract_first(internal_col_name)
172+
return get_internal_entry(
173+
DiagonalMatrixRow(scaling_value(entry)[row_index, col_index]),
174+
(drop_first(internal_row_name), drop_first(internal_col_name)),
175+
key_error,
168176
)
169-
col_index = extract_first(
170-
extract_internal_name(name_pair[2], @name(components.data)),
177+
elseif T <: Geometry.AdjointAxisVector
178+
return get_internal_entry(
179+
DiagonalMatrixRow(getfield(scaling_value(entry), :parent)),
180+
name_pair,
181+
key_error,
171182
)
172-
return DiagonalMatrixRow(scaling_value(entry)[row_index, col_index])
173183
else
174-
combined_chain = append_internal_name(name_pair[1], name_pair[2])
175-
modified_chain =
176-
T <: Geometry.AdjointAxisVector && # indexing adjoint of axis vector
177-
extract_first(combined_chain) != :parent ? # is the same as its parent
178-
extract_internal_name(combined_chain, @name(components.data)) :
179-
combined_chain
180-
if !broadcasted_has_field(T, modified_chain) # implicit tensor structure case
181-
T <: Geometry.SingleValue || throw(key_error) # optimization only works with scalars
182-
hasfield(T, extract_first(combined_chain)) && throw(key_error)
183-
name_pair[1] == name_pair[2] && return entry # multiplication case 3, 4 first arg
184-
is_overlapping_name(name_pair[1], name_pair[2]) && throw(key_error)
185-
return zero(entry) # off diagonal
186-
else
187-
return DiagonalMatrixRow(
188-
broadcasted_get_field(scaling_value(entry), modified_chain),
189-
) # multiplication case 2, 4 second arg
190-
end
184+
child_name, remaining_chain =
185+
if name_pair[1] != @name() &&
186+
extract_first(name_pair[1]) in fieldnames(T)
187+
@inline (
188+
extract_first(name_pair[1]),
189+
(drop_first(name_pair[1]), name_pair[2]),
190+
)
191+
elseif name_pair[2] != @name() &&
192+
extract_first(name_pair[2]) in fieldnames(T)
193+
@inline (
194+
extract_first(name_pair[2]),
195+
(name_pair[1], drop_first(name_pair[2])),
196+
)
197+
elseif !any(isequal(@name()), name_pair) # implicit tensor structure
198+
return get_internal_entry(
199+
extract_first(name_pair[1]) == extract_first(name_pair[2]) ?
200+
entry : zero(entry),
201+
(drop_first(name_pair[1]), drop_first(name_pair[2])),
202+
key_error,
203+
)
204+
else
205+
throw(key_error)
206+
end
207+
return get_internal_entry(
208+
DiagonalMatrixRow(getfield(scaling_value(entry), child_name)),
209+
remaining_chain,
210+
key_error,
211+
)
191212
end
192213
end
193214
function get_internal_entry(
@@ -198,27 +219,9 @@ function get_internal_entry(
198219
name_pair == (@name(), @name()) && return entry
199220
S = eltype(eltype(entry))
200221
T = eltype(parent(entry))
201-
first_name = extract_first(append_internal_name(name_pair...))
202-
if !hasfield(S, first_name) &&
203-
!(hasfield(S, :parent) && hasfield(fieldtype(S, :parent), first_name)) &&
204-
S <: Geometry.SingleValue
205-
if name_pair[1] == name_pair[2]
206-
return entry # multiplication case 3 or 4, first arg
207-
elseif is_overlapping_name(name_pair[1], name_pair[2])
208-
throw(key_error)
209-
else
210-
throw(key_error)
211-
return Base.broadcasted(entry) do matrix_row # off diagonal
212-
map(matrix_row) do matrix_row_entry
213-
zero(S)
214-
end
215-
end
216-
end
217-
end
218-
# multiplication case 2 or 4, second arg
219-
(index_offset, target_type) =
222+
(start_offset, target_type, apply_zero) =
220223
field_offset_and_type(name_pair, T, S, key_error)
221-
if target_type <: eltype(parent(entry))
224+
if target_type <: eltype(parent(entry)) && !apply_zero
222225
band_element_size =
223226
DataLayouts.typesize(eltype(parent(entry)), eltype(eltype(entry)))
224227
singleton_datalayout = DataLayouts.singleton(Fields.field_values(entry))
@@ -227,7 +230,7 @@ function get_internal_entry(
227230
field_dim_size = DataLayouts.ncomponents(Fields.field_values(entry))
228231
parent_indices = DataLayouts.to_data_specific_field(
229232
singleton_datalayout,
230-
(:, :, (index_offset + 1):band_element_size:field_dim_size, :, :),
233+
(:, :, (start_offset + 1):band_element_size:field_dim_size, :, :),
231234
)
232235
scalar_data = view(parent(entry), parent_indices...)
233236
values = DataLayouts.union_all(singleton_datalayout){
@@ -237,6 +240,12 @@ function get_internal_entry(
237240
scalar_data,
238241
)
239242
return Fields.Field(values, axes(entry))
243+
elseif apply_zero
244+
return Base.broadcasted(entry) do matrix_row
245+
map(matrix_row) do matrix_row_entry
246+
zero(target_type)
247+
end
248+
end
240249
else
241250
return Base.broadcasted(entry) do matrix_row
242251
map(matrix_row) do matrix_row_entry
@@ -248,6 +257,12 @@ function get_internal_entry(
248257
end
249258
end
250259
end
260+
if hasfield(Method, :recursion_relation)
261+
dont_limit = (args...) -> true
262+
for m in methods(get_internal_entry)
263+
m.recursion_relation = dont_limit
264+
end
265+
end
251266

252267
# Similar behavior to indexing an array with a slice.
253268
function Base.getindex(dict::FieldNameDict, new_keys::FieldNameSet)
@@ -313,40 +328,68 @@ function field_offset_and_type(
313328
::Type{S},
314329
key_error,
315330
) where {S, T}
316-
name_pair == (@name(), @name()) && return (0, S) # base case
317-
if S <: Geometry.Axis2Tensor{T} # special case to calculate index
331+
name_pair == (@name(), @name()) && return (0, S, false) # base case
332+
if S <: Geometry.Axis2Tensor # special case to calculate index
318333
(name_pair[1] == @name() || name_pair[2] == @name()) && throw(key_error)
319-
row_index = extract_first(
320-
extract_internal_name(name_pair[1], @name(components.data)),
321-
)
322-
col_index = extract_first(
323-
extract_internal_name(name_pair[2], @name(components.data)),
324-
)
334+
internal_row_name =
335+
extract_internal_name(name_pair[1], @name(components.data))
336+
internal_col_name =
337+
extract_internal_name(name_pair[2], @name(components.data))
338+
row_index = extract_first(internal_row_name)
339+
col_index = extract_first(internal_col_name)
325340
((row_index isa Number) && (col_index isa Number)) || throw(key_error) # slicing not supported
326341
(n_rows, n_cols) = map(length, axes(S))
342+
(remaining_offset, end_type, apply_zero) = field_offset_and_type(
343+
(drop_first(internal_row_name), drop_first(internal_col_name)),
344+
T,
345+
eltype(S),
346+
key_error,
347+
)
327348
(row_index <= n_rows && col_index <= n_cols) || throw(key_error)
328-
return (n_rows * (col_index - 1) + row_index - 1, T)
349+
return (
350+
(n_rows * (col_index - 1) + row_index - 1) + remaining_offset,
351+
end_type,
352+
apply_zero,
353+
)
354+
elseif S <: Geometry.AdjointAxisVector
355+
return field_offset_and_type(name_pair, T, fieldtype(S, 1), key_error)
329356
else
330-
child_name, remaining_field_chain = if name_pair[1] != @name()
331-
extract_first(name_pair[1]), (drop_first(name_pair[1]), name_pair[2])
332-
else
333-
extract_first(name_pair[2]), (@name(), drop_first(name_pair[2]))
334-
end
335-
# indexing adjoint of axis vector is the same as indexing its parent
336-
(S <: Geometry.AdjointAxisVector && child_name != :parent) &&
337-
return field_offset_and_type(
338-
name_pair,
339-
T,
340-
fieldtype(S, 1),
341-
key_error,
342-
)
343-
(child_name in fieldnames(S)) || throw(key_error)
357+
child_name, remaining_field_chain =
358+
if name_pair[1] != @name() &&
359+
extract_first(name_pair[1]) in fieldnames(S)
360+
@inline (
361+
extract_first(name_pair[1]),
362+
(drop_first(name_pair[1]), name_pair[2]),
363+
)
364+
elseif name_pair[2] != @name() &&
365+
extract_first(name_pair[2]) in fieldnames(S)
366+
@inline (
367+
extract_first(name_pair[2]),
368+
(name_pair[1], drop_first(name_pair[2])),
369+
)
370+
elseif !any(isequal(@name()), name_pair) # implicit tensor structure
371+
(remaining_offset, end_type, apply_zero) =
372+
field_offset_and_type(
373+
(drop_first(name_pair[1]), drop_first(name_pair[2])),
374+
T,
375+
fieldtype(S, 1),
376+
key_error,
377+
)
378+
return (
379+
remaining_offset,
380+
end_type,
381+
extract_first(name_pair[1]) == extract_first(name_pair[2]) ?
382+
apply_zero : true,
383+
)
384+
else
385+
throw(key_error)
386+
end
344387
child_type = fieldtype(S, child_name)
345388
field_index = unrolled_filter(
346389
i -> fieldname(S, i) == child_name,
347390
1:fieldcount(S),
348391
)[1]
349-
(remaining_offset, end_type) = field_offset_and_type(
392+
(remaining_offset, end_type, apply_zero) = field_offset_and_type(
350393
remaining_field_chain,
351394
T,
352395
child_type,
@@ -355,7 +398,9 @@ function field_offset_and_type(
355398
return (
356399
DataLayouts.fieldtypeoffset(T, S, field_index) + remaining_offset,
357400
end_type,
401+
apply_zero,
358402
)
403+
359404
end
360405
end
361406
if hasfield(Method, :recursion_relation)
@@ -366,16 +411,16 @@ if hasfield(Method, :recursion_relation)
366411
end
367412

368413
"""
369-
get_scalar_keys(dict::FieldMatrix, FT)
414+
get_scalar_keys(dict::FieldMatrix)
370415
371416
Returns a `FieldMatrixKeys` object that contains the keys that result in
372-
a `ScalingFieldMatrixEntry{FT}` or a `ColumnwiseBandMatrixField` with bands of eltype `FT`
417+
a `ScalingFieldMatrixEntry{<:Number}` or a `ColumnwiseBandMatrixField` with bands of eltype `< :Number`
373418
when indexing `dict`.
374419
"""
375-
function get_scalar_keys(dict::FieldMatrix, ::Type{FT}) where {FT}
420+
function get_scalar_keys(dict::FieldMatrix)
376421
keys_tuple = unrolled_flatmap(keys(dict).values) do outer_key
377-
unrolled_map(get_scalar_keys(eltype(dict[outer_key]), FT)) do inner_key
378-
(
422+
@inline unrolled_map(get_scalar_keys(eltype(dict[outer_key]))) do inner_key
423+
@inline (
379424
append_internal_name(outer_key[1], inner_key[1]),
380425
append_internal_name(outer_key[2], inner_key[2]),
381426
)
@@ -385,16 +430,16 @@ function get_scalar_keys(dict::FieldMatrix, ::Type{FT}) where {FT}
385430
end
386431

387432
"""
388-
get_scalar_keys(T::Type, FT::Type)
433+
get_scalar_keys(T::Type)
389434
390435
Returns a tuple of `FieldNamePair` objects that correspond to any children
391436
of `T` that are of type `FT`.
392437
"""
393-
function get_scalar_keys(::Type{T}, ::Type{FT}) where {T, FT}
394-
if T <: FT || T <: Bool # identity has eltype Bool
438+
function get_scalar_keys(::Type{T}) where {T}
439+
if T <: Number # TODO: is this tight enough of a Type? what about complex and duals and plushalfs
395440
return ((@name(), @name()),)
396441
elseif T <: BandMatrixRow
397-
return get_scalar_keys(eltype(T), FT)
442+
return get_scalar_keys(eltype(T))
398443
elseif T <: Geometry.Axis2Tensor
399444
return unrolled_flatmap(1:length(axes(T)[1])) do row_component
400445
unrolled_map(1:length(axes(T)[2])) do col_component
@@ -405,28 +450,35 @@ function get_scalar_keys(::Type{T}, ::Type{FT}) where {T, FT}
405450
end
406451
end
407452
elseif T <: Geometry.AdjointAxisVector
408-
return unrolled_map(
409-
get_scalar_keys(fieldtype(T, :parent), FT),
410-
) do inner_key
453+
return unrolled_map(get_scalar_keys(fieldtype(T, :parent))) do inner_key
411454
(inner_key[2], inner_key[1]) # assumes that adjoints only appear with d/dvec
412455
end
413456
elseif T <: Geometry.AxisVector # special case to avoid recursing into the axis field
414457
# TODO: this should be able to be combined with the else case, but it causes runtime dispatch
415458
return unrolled_map(
416-
get_scalar_keys(fieldtype(T, :components), FT),
459+
get_scalar_keys(fieldtype(T, :components)),
417460
) do inner_key
418461
(
419462
append_internal_name(@name(components), inner_key[1]),
420463
inner_key[2],
421464
)
422465
end
466+
# return unrolled_flatmap((:components,)) do inner_field
467+
# @inline unrolled_map(
468+
# # get_scalar_keys(fieldtype(T, inner_field)),
469+
# get_scalar_keys(fieldtype(T, :components)),
470+
# ) do inner_key
471+
# @inline (
472+
# append_internal_name(FieldName(inner_field), inner_key[1]),
473+
# inner_key[2],
474+
# )
475+
# end
476+
# end
423477
else
424-
return unrolled_flatmap(fieldnames(T)) do inner_field
425-
unrolled_map(
426-
get_scalar_keys(fieldtype(T, inner_field), FT),
427-
) do inner_key
478+
return unrolled_flatmap(fieldnames(T)) do inner_name
479+
unrolled_map(get_scalar_keys(fieldtype(T, inner_name))) do inner_key
428480
(
429-
append_internal_name(FieldName(inner_field), inner_key[1]),
481+
append_internal_name(FieldName(inner_name), inner_key[1]),
430482
inner_key[2],
431483
)
432484
end
@@ -442,7 +494,7 @@ end
442494

443495

444496
"""
445-
scalar_fieldmatrix(field_matrix::FieldMatrix, FT)
497+
scalar_fieldmatrix(field_matrix::FieldMatrix)
446498
447499
Constructs a `FieldNameDict` where the keys and entries are views
448500
of the entries of `field_matrix`, which corresponding to the
@@ -464,7 +516,7 @@ A = MatrixFields.FieldMatrix(
464516
(@name(c.uₕ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_uₕ_scalar,
465517
)
466518
467-
A_scalar = MatrixFields.scalar_fieldmatrix(A, Float64)
519+
A_scalar = MatrixFields.scalar_fieldmatrix(A)
468520
keys(A_scalar)
469521
# Output:
470522
# (@name(c.ρχ.ρq_liq), @name(f.u₃.:(1)))
@@ -473,8 +525,8 @@ keys(A_scalar)
473525
# (@name(c.uₕ.:(2)), @name(c.sgsʲs.:(1).ρa))
474526
```
475527
"""
476-
function scalar_fieldmatrix(field_matrix::FieldMatrix, ::Type{FT}) where {FT}
477-
scalar_keys = get_scalar_keys(field_matrix, FT)
528+
function scalar_fieldmatrix(field_matrix::FieldMatrix)
529+
scalar_keys = get_scalar_keys(field_matrix)
478530
entries = unrolled_map(scalar_keys.values) do key
479531
field_matrix[key]
480532
end

0 commit comments

Comments
 (0)