Skip to content

Commit 8f92898

Browse files
committed
fix implicit tensor rep tests
1 parent 12b4c7b commit 8f92898

File tree

3 files changed

+98
-149
lines changed

3 files changed

+98
-149
lines changed

src/MatrixFields/field_name.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,6 @@ extract_first(::FieldName{name_chain}) where {name_chain} = first(name_chain)
5050
drop_first(::FieldName{name_chain}) where {name_chain} =
5151
FieldName(Base.tail(name_chain)...)
5252

53-
extract_last(::FieldName{name_chain}) where {name_chain} =
54-
name_chain[length(name_chain)]
55-
5653
has_field(x, ::FieldName{()}) = true
5754
has_field(x, name::FieldName) =
5855
extract_first(name) in propertynames(x) &&

src/MatrixFields/field_name_dict.jl

Lines changed: 94 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -152,62 +152,42 @@ get_internal_key(child_name_pair::FieldNamePair, name_pair::FieldNamePair) = (
152152
get_internal_entry(entry, name::FieldName, key_error) = get_field(entry, name)
153153
get_internal_entry(entry, name_pair::FieldNamePair, key_error) =
154154
name_pair == (@name(), @name()) ? entry : throw(key_error)
155-
get_internal_entry(entry::UniformScaling, name_pair::FieldNamePair, key_error) =
156-
if name_pair[1] == name_pair[2]
157-
entry
158-
elseif is_overlapping_name(name_pair[1], name_pair[2])
159-
throw(key_error)
160-
else
161-
zero(entry)
162-
end
163155
function get_internal_entry(
164-
entry::DiagonalMatrixRow{T},
156+
entry::ScalingFieldMatrixEntry,
165157
name_pair::FieldNamePair,
166158
key_error,
167-
) where {T}
159+
)
160+
T = eltype(entry)
168161
if name_pair == (@name(), @name())
169162
return entry
170-
elseif T <: Geometry.Axis2Tensor
171-
(name_pair[1] == @name() || name_pair[2] == @name()) && throw(key_error) # Cannot slice a 2D tensor
172-
row_index = extract_first(name_pair[1])
173-
col_index = extract_first(name_pair[2])
174-
if (!(row_index isa Number) || !(col_index isa Number))
175-
name_pair[1] == name_pair[2] && return entry # multiplication case 3 or 4, first argument
176-
is_overlapping_name(name_pair[1], name_pair[2]) ||
177-
return zero(entry)
178-
throw(key_error)
179-
end
180-
(n_rows, n_cols) = map(length, axes(scaling_value(entry)))
181-
@assert row_index <= n_rows && col_index <= n_cols
182-
return DiagonalMatrixRow(
183-
broadcasted_get_field(
184-
scaling_value(entry).components.data,
185-
FieldName(n_rows * (col_index - 1) + row_index),
186-
),
163+
elseif T <: Geometry.Axis2Tensor &&
164+
all(n -> is_child_name(n, @name(components.data)), name_pair)
165+
# two indices needed to index into a 2d tensor (one can be Colon())
166+
row_index = extract_first(
167+
extract_internal_name(name_pair[1], @name(components.data)),
187168
)
188-
elseif T <: Geometry.AxisVectorOrAdj
189-
return get_internal_entry(
190-
DiagonalMatrixRow(
191-
getfield(
192-
getfield(
193-
T <: Geometry.AxisVector ? scaling_value(entry) :
194-
scaling_value(entry).parent,
195-
:components,
196-
),
197-
:data,
198-
),
199-
),
200-
name_pair,
201-
key_error,
169+
col_index = extract_first(
170+
extract_internal_name(name_pair[2], @name(components.data)),
202171
)
203-
elseif name_pair[1] == name_pair[2] &&
204-
!broadcasted_has_field(eltype(entry), name_pair[1])
205-
return entry
172+
return DiagonalMatrixRow(scaling_value(entry)[row_index, col_index])
206173
else
207-
non_empty_chain = name_pair[1] != @name() ? name_pair[1] : name_pair[2]
208-
return DiagonalMatrixRow(
209-
broadcasted_get_field(scaling_value(entry), non_empty_chain),
210-
)
174+
combined_chain = append_internal_name(name_pair[1], name_pair[2])
175+
modified_chain =
176+
T <: Geometry.AdjointAxisVector && # indexing adjoint of axis vector
177+
extract_first(combined_chain) != :parent ? # is the same as its parent
178+
extract_internal_name(combined_chain, @name(components.data)) :
179+
combined_chain
180+
if !broadcasted_has_field(T, modified_chain) # implicit tensor structure case
181+
T <: Geometry.SingleValue || throw(key_error) # optimization only works with scalars
182+
hasfield(T, extract_first(combined_chain)) && throw(key_error)
183+
name_pair[1] == name_pair[2] && return entry # multiplication case 3, 4 first arg
184+
is_overlapping_name(name_pair[1], name_pair[2]) && throw(key_error)
185+
return zero(entry) # off diagonal
186+
else
187+
return DiagonalMatrixRow(
188+
broadcasted_get_field(scaling_value(entry), modified_chain),
189+
) # multiplication case 2, 4 second arg
190+
end
211191
end
212192
end
213193
function get_internal_entry(
@@ -216,13 +196,28 @@ function get_internal_entry(
216196
key_error,
217197
)
218198
name_pair == (@name(), @name()) && return entry
219-
(index_offset, target_type) = field_offset_and_type(
220-
name_pair,
221-
eltype(parent(entry)),
222-
eltype(eltype(entry)),
223-
key_error,
224-
)
225-
target_type <: eltype(eltype(entry)) && return entry # multiplication case 3 or 4, first argument
199+
S = eltype(eltype(entry))
200+
T = eltype(parent(entry))
201+
first_name = extract_first(append_internal_name(name_pair...))
202+
if !hasfield(S, first_name) &&
203+
!(hasfield(S, :parent) && hasfield(fieldtype(S, :parent), first_name)) &&
204+
S <: Geometry.SingleValue
205+
if name_pair[1] == name_pair[2]
206+
return entry # multiplication case 3 or 4, first arg
207+
elseif is_overlapping_name(name_pair[1], name_pair[2])
208+
throw(key_error)
209+
else
210+
throw(key_error)
211+
return Base.broadcasted(entry) do matrix_row # off diagonal
212+
map(matrix_row) do matrix_row_entry
213+
zero(S)
214+
end
215+
end
216+
end
217+
end
218+
# multiplication case 2 or 4, second arg
219+
(index_offset, target_type) =
220+
field_offset_and_type(name_pair, T, S, key_error)
226221
if target_type <: eltype(parent(entry))
227222
band_element_size =
228223
DataLayouts.typesize(eltype(parent(entry)), eltype(eltype(entry)))
@@ -254,7 +249,6 @@ function get_internal_entry(
254249
end
255250
end
256251

257-
258252
# Similar behavior to indexing an array with a slice.
259253
function Base.getindex(dict::FieldNameDict, new_keys::FieldNameSet)
260254
common_keys = intersect(keys(dict), new_keys)
@@ -320,76 +314,34 @@ function field_offset_and_type(
320314
key_error,
321315
) where {S, T}
322316
name_pair == (@name(), @name()) && return (0, S) # base case
323-
if S <: Geometry.Axis2Tensor{T}
324-
(name_pair[1] == @name() || name_pair[2] == @name()) && throw(key_error) # Cannot slice a 2D tensor
325-
row_index = extract_first(name_pair[1])
326-
col_index = extract_first(name_pair[2])
327-
if (!(row_index isa Number) || !(col_index isa Number))
328-
name_pair[1] == name_pair[2] && return (0, S) # multiplication case 3 or 4, first argument
329-
throw(key_error)
330-
end
331-
(n_rows, n_cols) = map(length, axes(S))
332-
@assert row_index <= n_rows && col_index <= n_cols
333-
return (n_rows * (col_index - 1) + row_index - 1, T)
334-
elseif (
335-
name_pair[1] == name_pair[2] && !broadcasted_has_field(S, name_pair[1])
336-
)
337-
return (0, S)
338-
elseif name_pair[1] == @name()
339-
return field_offset_and_type(name_pair[2], T, S, key_error)
340-
elseif name_pair[2] == @name()
341-
return field_offset_and_type(name_pair[1], T, S, key_error)
342-
else
343-
child_name = extract_first(name_pair[1])
344-
(child_name in fieldnames(S)) || throw(key_error)
345-
child_type = fieldtype(S, child_name)
346-
remaining_field_chain = (drop_first(name_pair[1]), name_pair[2])
347-
field_index = unrolled_filter(
348-
i -> fieldname(S, i) == child_name,
349-
1:fieldcount(S),
350-
)[1]
351-
(remaining_offset, end_type) = field_offset_and_type(
352-
remaining_field_chain,
353-
T,
354-
child_type,
355-
key_error,
317+
if S <: Geometry.Axis2Tensor{T} # special case to calculate index
318+
(name_pair[1] == @name() || name_pair[2] == @name()) && throw(key_error)
319+
row_index = extract_first(
320+
extract_internal_name(name_pair[1], @name(components.data)),
356321
)
357-
return (
358-
DataLayouts.fieldtypeoffset(T, S, field_index) + remaining_offset,
359-
end_type,
360-
)
361-
end
362-
end
363-
364-
"""
365-
field_offset_and_type(name::FieldName, ::Type{T}, ::Type{S}, key_error)
366-
367-
Returns the offset of the the field with name `name` in an object of type `S`
368-
in multiples of `sizeof(T)` and the type of the field with name `name` in an object of type `S`
369-
in multiples of `sizeof(T)`
370-
"""
371-
function field_offset_and_type(
372-
name::FieldName,
373-
::Type{T},
374-
::Type{S},
375-
key_error,
376-
) where {T, S}
377-
name == @name() && return (0, S) # base case
378-
if S <: Geometry.AdjointAxisVector
379-
return field_offset_and_type(name, T, fieldtype(S, :parent), key_error)
380-
elseif S <: Geometry.AxisVector
381-
(remaining_offset, end_type) = field_offset_and_type(
382-
name,
383-
T,
384-
fieldtype(fieldtype(S, :components), :data),
385-
key_error,
322+
col_index = extract_first(
323+
extract_internal_name(name_pair[2], @name(components.data)),
386324
)
387-
return (remaining_offset, end_type)
325+
((row_index isa Number) && (col_index isa Number)) || throw(key_error) # slicing not supported
326+
(n_rows, n_cols) = map(length, axes(S))
327+
(row_index <= n_rows && col_index <= n_cols) || throw(key_error)
328+
return (n_rows * (col_index - 1) + row_index - 1, T)
388329
else
389-
child_name = extract_first(name)
330+
child_name, remaining_field_chain = if name_pair[1] != @name()
331+
extract_first(name_pair[1]), (drop_first(name_pair[1]), name_pair[2])
332+
else
333+
extract_first(name_pair[2]), (@name(), drop_first(name_pair[2]))
334+
end
335+
# indexing adjoint of axis vector is the same as indexing its parent
336+
(S <: Geometry.AdjointAxisVector && child_name != :parent) &&
337+
return field_offset_and_type(
338+
name_pair,
339+
T,
340+
fieldtype(S, 1),
341+
key_error,
342+
)
390343
(child_name in fieldnames(S)) || throw(key_error)
391344
child_type = fieldtype(S, child_name)
392-
remaining_field_chain = drop_first(name)
393345
field_index = unrolled_filter(
394346
i -> fieldname(S, i) == child_name,
395347
1:fieldcount(S),
@@ -422,7 +374,6 @@ when indexing `dict`.
422374
"""
423375
function get_scalar_keys(dict::FieldMatrix, ::Type{FT}) where {FT}
424376
keys_tuple = unrolled_flatmap(keys(dict).values) do outer_key
425-
# target_eltype = eltype(Y)
426377
unrolled_map(get_scalar_keys(eltype(dict[outer_key]), FT)) do inner_key
427378
(
428379
append_internal_name(outer_key[1], inner_key[1]),
@@ -444,24 +395,33 @@ function get_scalar_keys(::Type{T}, ::Type{FT}) where {T, FT}
444395
return ((@name(), @name()),)
445396
elseif T <: BandMatrixRow
446397
return get_scalar_keys(eltype(T), FT)
447-
elseif T <: Geometry.AxisVector
448-
return unrolled_map(1:length(axes(T)[1])) do component
449-
(FieldName(component), @name())
398+
elseif T <: Geometry.Axis2Tensor
399+
return unrolled_flatmap(1:length(axes(T)[1])) do row_component
400+
unrolled_map(1:length(axes(T)[2])) do col_component
401+
append_internal_name.(
402+
Ref(@name(components.data)),
403+
(FieldName(row_component), FieldName(col_component)),
404+
)
405+
end
450406
end
451407
elseif T <: Geometry.AdjointAxisVector
452408
return unrolled_map(
453-
1:length(axes(fieldtype(T, :parent))[1]),
454-
) do component
455-
(@name(), FieldName(component))
409+
get_scalar_keys(fieldtype(T, :parent), FT),
410+
) do inner_key
411+
(inner_key[2], inner_key[1]) # assumes that adjoints only appear with d/dvec
456412
end
457-
elseif T <: Geometry.Axis2Tensor
458-
unrolled_flatmap(1:length(axes(T)[1])) do row_component
459-
unrolled_map(1:length(axes(T)[2])) do col_component
460-
(FieldName(row_component), FieldName(col_component))
461-
end
413+
elseif T <: Geometry.AxisVector # special case to avoid recursing into the axis field
414+
# TODO: this should be able to be combined with the else case, but it causes runtime dispatch
415+
return unrolled_map(
416+
get_scalar_keys(fieldtype(T, :components), FT),
417+
) do inner_key
418+
(
419+
append_internal_name(@name(components), inner_key[1]),
420+
inner_key[2],
421+
)
462422
end
463423
else
464-
unrolled_flatmap(fieldnames(T)) do inner_field
424+
return unrolled_flatmap(fieldnames(T)) do inner_field
465425
unrolled_map(
466426
get_scalar_keys(fieldtype(T, inner_field), FT),
467427
) do inner_key

test/MatrixFields/scalar_fieldmatrix.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,31 +30,31 @@ include("matrix_field_test_utils.jl")
3030
(expected_offset, E)
3131
end
3232
test_field_offset_and_type(
33-
@name(x),
33+
(@name(x), @name()),
3434
FT,
3535
Singleton{Singleton{Singleton{Singleton{FT}}}},
3636
0,
3737
Singleton{Singleton{Singleton{FT}}},
3838
KeyError(@name(x.x.x.x)),
3939
)
4040
test_field_offset_and_type(
41-
@name(x.x.x.x),
41+
(@name(), @name(x.x.x.x)),
4242
FT,
4343
Singleton{Singleton{Singleton{Singleton{FT}}}},
4444
0,
4545
FT,
4646
KeyError(@name(x.x.x.x)),
4747
)
4848
test_field_offset_and_type(
49-
@name(y.x),
49+
(@name(), @name(y.x)),
5050
FT,
5151
TwoFields{TwoFields{FT, FT}, TwoFields{FT, FT}},
5252
2,
5353
FT,
5454
KeyError(@name(y.x)),
5555
)
5656
test_field_offset_and_type(
57-
@name(y.y),
57+
(@name(y), @name(y)),
5858
FT,
5959
TwoFields{
6060
TwoFields{FT, FT},
@@ -64,14 +64,6 @@ include("matrix_field_test_utils.jl")
6464
TwoFields{FT, Singleton{FT}},
6565
KeyError(@name(y.y.x)),
6666
)
67-
test_field_offset_and_type(
68-
@name(y.y),
69-
Float32,
70-
TwoFields{TwoFields{FT, FT}, TwoFields{FT, FT}},
71-
6,
72-
FT,
73-
KeyError(@name(y.y.x)),
74-
)
7567
test_field_offset_and_type(
7668
(@name(y.y), @name(x)),
7769
FT,

0 commit comments

Comments
 (0)