Skip to content

Commit d8fae26

Browse files
committed
pre-switch to type space
1 parent f50ce58 commit d8fae26

File tree

1 file changed

+50
-42
lines changed

1 file changed

+50
-42
lines changed

src/MatrixFields/field_name_dict.jl

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -453,15 +453,10 @@ append_component_name(name::FieldName, component::Int) = append_internal_name(
453453
),
454454
)
455455

456-
# general strat is Recursive
457-
# first we start with dict and Y
458-
# 1. flatmap all key pairs into pairs plus appended stuffs
459-
# 1.1 if the eltype is not a scalar, then we recurse, and append.
460-
# 1.1 the resucrsion func should take in an entry, and a CC field and maybe a target type?
456+
461457
function gsk(dict::FieldMatrix, Y::Fields.FieldVector)
462458
target_eltype = eltype(Y)
463459
keys_tuple = unrolled_flatmap(keys(dict).values) do outer_key
464-
@show outer_key
465460
unrolled_map(
466461
gsk(
467462
dict[outer_key],
@@ -471,16 +466,15 @@ function gsk(dict::FieldMatrix, Y::Fields.FieldVector)
471466
),
472467
) do inner_key
473468
(
474-
append_internal_name(inner_key[1], outer_key[1]),
475-
append_internal_name(inner_key[2], outer_key[2]),
469+
append_internal_name(outer_key[1], inner_key[1]),
470+
append_internal_name(outer_key[2], inner_key[2]),
476471
)
477472
end
478473
end
474+
return FieldMatrixKeys(keys_tuple)
479475
end
480476

481-
# function gsk(entry, row_field, column_field, ::Type)
482-
# ((@name(), @name()),)
483-
# end
477+
484478

485479
gsk(entry::UniformScaling, row_field, column_field, ::Type) =
486480
((@name(), @name()),)
@@ -494,8 +488,8 @@ gsk(
494488
::Type{FT},
495489
) where {FT} = gsk(
496490
Fields.field_values(entry)[CartesianIndex(1, 1, 1, 1, 1)],
497-
row_field,
498-
column_field,
491+
Fields.field_values(row_field)[CartesianIndex(1, 1, 1, 1, 1)],
492+
Fields.field_values(column_field)[CartesianIndex(1, 1, 1, 1, 1)],
499493
FT,
500494
)
501495

@@ -512,59 +506,72 @@ gsk(
512506
gsk(value::FT, row_field, column_field, ::Type{FT}) where {FT} =
513507
((@name(), @name()),)
514508

509+
# generic fallback case
510+
function gsk(value::T, row_field, column_field, ::Type) where {T}
511+
return unrolled_map(fieldnames(T)) do inner_name
512+
(FieldName(inner_name), @name())
513+
end
514+
end
515+
515516
# dvec/dscalar
516517
function gsk(
517-
value::Geometry.AxisVectorOrAdj{FT},
518-
row_field::Fields.Field{
519-
<:DataLayouts.AbstractData{<:Geometry.AxisVectorOrAdj{FT}},
520-
},
521-
column_field::Fields.Field{<:DataLayouts.AbstractData{FT}},
518+
value::Geometry.AxisVector{FT},
519+
row_field::Geometry.AxisVector{FT},
520+
column_field::FT,
522521
::Type{FT},
523522
) where {FT}
524-
ncomponents = length(axes(value, 1))
525-
unrolled_map(propertynames(value)) do component_name
526-
(@name(component_name), @name())
523+
524+
unrolled_map(1:length(axes(value, 1))) do component
525+
(FieldName(component), @name())
527526
end
528527
end
529528

530529
# dscalar/dvec
531530
function gsk(
532-
value::Geometry.AxisVectorOrAdj{FT},
533-
row_field::Fields.Field{<:DataLayouts.AbstractData{FT}},
534-
column_field::Fields.Field{
535-
<:DataLayouts.AbstractData{<:Geometry.AxisVectorOrAdj{FT}},
536-
},
531+
value::Geometry.AdjointAxisVector{FT},
532+
row_field::FT,
533+
column_field::Geometry.AxisVector{FT},
537534
::Type{FT},
538535
) where {FT}
539-
Main.@infiltrate
540-
unrolled_map(propertynames(value)) do component_name
541-
(@name(), @name(component_name))
536+
unrolled_map(1:length(axes(value, 1))) do component
537+
(@name(), FieldName(component))
542538
end
543539
end
544540

545541
# dvec/dvec
546542
function gsk(
547543
value::Geometry.Axis2Tensor{FT},
548-
row_field::Fields.Field{
549-
<:DataLayouts.AbstractData{<:Geometry.AxisVectorOrAdj{FT}},
550-
},
551-
column_field::Fields.Field{
552-
<:DataLayouts.AbstractData{<:Geometry.AxisVectorOrAdj{FT}},
553-
},
544+
row_field::Geometry.AxisVectorOrAdj{FT},
545+
column_field::Geometry.AxisVectorOrAdj{FT},
554546
::Type{FT},
555547
) where {FT}
556-
# Main.@infiltrate
557-
((@name(), @name()),)
548+
unrolled_flatmap(1:length(axes(value, 1))) do row_component
549+
unrolled_map(1:length(axes(value, 2))) do col_component
550+
(FieldName(row_component), FieldName(col_component))
551+
end
552+
end
558553
end
559554

560555
# dtuple/dvec or dvec/dscalar
561556
function gsk(
562-
value::RT,
563-
row_field::Fields.Field{<:DataLayouts.AbstractData},
564-
column_field::Fields.Field{<:DataLayouts.AbstractData},
557+
value::TT,
558+
row_field::RT,
559+
column_field,
565560
::Type{FT},
566-
) where {FT, RT <: Union{NamedTuple, Tuple}}
567-
((@name(), @name()),)
561+
) where {FT, TT <: Union{NamedTuple, Tuple}, RT <: Union{NamedTuple, Tuple}}
562+
unrolled_flatmap(fieldnames(TT)) do tuple_key
563+
unrolled_map(gsk(
564+
getfield(value, tuple_key),
565+
getfield(row_field, tuple_key),
566+
column_field,
567+
FT,
568+
)) do inner_key
569+
(
570+
append_internal_name(FieldName(tuple_key), inner_key[1]),
571+
inner_key[2],
572+
)
573+
end
574+
end
568575
end
569576

570577
"""
@@ -723,6 +730,7 @@ keys(A_scalar)
723730
"""
724731
function scalar_fieldmatrix(field_matrix::FieldMatrix, Y::Fields.FieldVector)
725732
scalar_keys = get_scalar_keys(field_matrix, Y)
733+
# scalar_keys = gsk(field_matrix, Y)
726734
entries = unrolled_map(scalar_keys.values) do key
727735
field_matrix[key]
728736
end

0 commit comments

Comments
 (0)