Skip to content

Commit 95baaa6

Browse files
committed
fix broken tests
1 parent c3dacd4 commit 95baaa6

File tree

5 files changed

+57
-91
lines changed

5 files changed

+57
-91
lines changed

docs/src/matrix_fields.md

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,7 @@ check_preconditioner
9090
lazy_or_concrete_preconditioner
9191
apply_preconditioner
9292
get_scalar_keys
93-
get_field_first_index_offset
94-
broadcasted_get_field_type
95-
inner_type_ignore_adjoint
93+
field_offset_and_type
9694
```
9795

9896
## Utilities
@@ -160,15 +158,16 @@ If the key `(@name(name1), @name(name2))` corresponds to an entry, then
160158
`(@name(name1.foo.bar.buz), @name(name2.biz.bop.fud))`.
161159

162160
Currently, internal values cannot be extracted in all situations. Extracting interal values
163-
works when:
161+
works when indexing an object of type `eltype(entry)` with the
162+
second key of the internal key pair appended to the first results in a scalar.
163+
If the internal keys index to a non-scalar `Field`, a broadcasted object is returned.
164164

165-
- The second name in the internal key is empty, and the first name in the internal key accesses internal values for the type of element contained in each row of the entry. This does not work when the element type of each row is a 2d tensor.
165+
When the entry is a `Field` of `Axis2Tensor`s, and both internal names are numbers that would index
166+
an `Axis2Tensor` with the same axis.
166167

167-
- The first name in the internal key is empty, and the type of element contained in each row of the entry is an `AxisVector` or the adjoint of an `AxisVector`. In this case, the second name must access inernal values for the type of `AxisVector` contained in each row.
168+
This does not work when the internal keys index to a `Field` of sliced tensors.
168169

169-
- The element type of each row in the entry is a 2d tensor, and the internal key is of the form `(@name(components.data.:(1)), @name(components.data.:(2)))`, but possibly with different numbers to index into the 2d tensor
170-
171-
- The element type of each row in the entry is some number of nested `Tuple`s and `NamedTuple`s, and the first name in the internal key accesses an `AxisVector` or the adjoint of an `AxisVector` from the outer `Tuple`/`NamedTuple`, and the second name in the inernal key accesses a component of the `AxisVector`
170+
Extracting internal values from a `DiagonalMatrixRow` works in all cases, except when
172171

173172
If the `FieldMatrix` represents a Jacobian, then extracting internal values works when an entry represents:
174173

src/MatrixFields/field_name.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,18 +62,6 @@ get_field(x, ::FieldName{()}) = x
6262
get_field(x, name::FieldName) =
6363
get_field(getproperty(x, extract_first(name)), drop_first(name))
6464

65-
"""
66-
broadcasted_get_field_type(::Type{X}, name::FieldName)
67-
68-
Returns the type of the field accessed by `name` in the type `X`.
69-
"""
70-
broadcasted_get_field_type(::Type{X}, ::FieldName{()}) where {X} = X
71-
broadcasted_get_field_type(::Type{X}, name::FieldName) where {X} =
72-
broadcasted_get_field_type(
73-
fieldtype(X, extract_first(name)),
74-
drop_first(name),
75-
)
76-
7765
broadcasted_has_field(::Type{X}, ::FieldName{()}) where {X} = true
7866
broadcasted_has_field(::Type{X}, name::FieldName) where {X} =
7967
extract_first(name) in fieldnames(X) &&
@@ -214,7 +202,4 @@ if hasfield(Method, :recursion_relation)
214202
for m in methods(get_subtree_at_name)
215203
m.recursion_relation = dont_limit
216204
end
217-
for m in methods(broadcasted_get_field_type)
218-
m.recursion_relation = dont_limit
219-
end
220205
end

src/MatrixFields/field_name_dict.jl

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,15 @@ function get_internal_entry(
168168
if name_pair == (@name(), @name())
169169
return entry
170170
elseif T <: Geometry.Axis2Tensor
171-
(name_pair[1] == @name() || name_pair[2] == @name()) &&
172-
error("Cannot slice a 2D tensor")
171+
(name_pair[1] == @name() || name_pair[2] == @name()) && throw(key_error) # Cannot slice a 2D tensor
173172
row_index = extract_first(name_pair[1])
174173
col_index = extract_first(name_pair[2])
174+
if (!(row_index isa Number) || !(col_index isa Number))
175+
name_pair[1] == name_pair[2] && return entry # multiplication case 3 or 4, first argument
176+
is_overlapping_name(name_pair[1], name_pair[2]) ||
177+
return zero(entry)
178+
throw(key_error)
179+
end
175180
(n_rows, n_cols) = map(length, axes(scaling_value(entry)))
176181
@assert row_index <= n_rows && col_index <= n_cols
177182
return DiagonalMatrixRow(
@@ -215,7 +220,9 @@ function get_internal_entry(
215220
name_pair,
216221
eltype(parent(entry)),
217222
eltype(eltype(entry)),
223+
key_error,
218224
)
225+
target_type <: eltype(eltype(entry)) && return entry # multiplication case 3 or 4, first argument
219226
if target_type <: eltype(parent(entry))
220227
band_element_size =
221228
DataLayouts.typesize(eltype(parent(entry)), eltype(eltype(entry)))
@@ -295,7 +302,7 @@ function Base.one(matrix::FieldMatrix)
295302
end
296303

297304
"""
298-
field_offset_and_type(name_pair::FieldNamePair, ::Type{T}, ::Type{S})
305+
field_offset_and_type(name_pair::FieldNamePair, ::Type{T}, ::Type{S}, key_error)
299306
300307
Returns the offset of the field with name `name_pair` in an object of type `S` in
301308
multiples of `sizeof(T)` and the type of the field with name `name_pair`.
@@ -310,13 +317,17 @@ function field_offset_and_type(
310317
name_pair::FieldNamePair,
311318
::Type{T},
312319
::Type{S},
320+
key_error,
313321
) where {S, T}
314322
name_pair == (@name(), @name()) && return (0, S) # base case
315323
if S <: Geometry.Axis2Tensor{T}
316-
(name_pair[1] == @name() || name_pair[2] == @name()) &&
317-
error("Cannot slice a 2D tensor")
324+
(name_pair[1] == @name() || name_pair[2] == @name()) && throw(key_error) # Cannot slice a 2D tensor
318325
row_index = extract_first(name_pair[1])
319326
col_index = extract_first(name_pair[2])
327+
if (!(row_index isa Number) || !(col_index isa Number))
328+
name_pair[1] == name_pair[2] && return (0, S) # multiplication case 3 or 4, first argument
329+
throw(key_error)
330+
end
320331
(n_rows, n_cols) = map(length, axes(S))
321332
@assert row_index <= n_rows && col_index <= n_cols
322333
return (n_rows * (col_index - 1) + row_index - 1, T)
@@ -325,19 +336,24 @@ function field_offset_and_type(
325336
)
326337
return (0, S)
327338
elseif name_pair[1] == @name()
328-
return field_offset_and_type(name_pair[2], T, S)
339+
return field_offset_and_type(name_pair[2], T, S, key_error)
329340
elseif name_pair[2] == @name()
330-
return field_offset_and_type(name_pair[1], T, S)
341+
return field_offset_and_type(name_pair[1], T, S, key_error)
331342
else
332343
child_name = extract_first(name_pair[1])
344+
(child_name in fieldnames(S)) || throw(key_error)
333345
child_type = fieldtype(S, child_name)
334346
remaining_field_chain = (drop_first(name_pair[1]), name_pair[2])
335347
field_index = unrolled_filter(
336348
i -> fieldname(S, i) == child_name,
337349
1:fieldcount(S),
338350
)[1]
339-
(remaining_offset, end_type) =
340-
field_offset_and_type(remaining_field_chain, T, child_type)
351+
(remaining_offset, end_type) = field_offset_and_type(
352+
remaining_field_chain,
353+
T,
354+
child_type,
355+
key_error,
356+
)
341357
return (
342358
DataLayouts.fieldtypeoffset(T, S, field_index) + remaining_offset,
343359
end_type,
@@ -346,7 +362,7 @@ function field_offset_and_type(
346362
end
347363

348364
"""
349-
field_offset_and_type(name::FieldName, ::Type{T}, ::Type{S})
365+
field_offset_and_type(name::FieldName, ::Type{T}, ::Type{S}, key_error)
350366
351367
Returns the offset of the the field with name `name` in an object of type `S`
352368
in multiples of `sizeof(T)` and the type of the field with name `name` in an object of type `S`
@@ -356,27 +372,34 @@ function field_offset_and_type(
356372
name::FieldName,
357373
::Type{T},
358374
::Type{S},
375+
key_error,
359376
) where {T, S}
360377
name == @name() && return (0, S) # base case
361378
if S <: Geometry.AdjointAxisVector
362-
return field_offset_and_type(name, T, fieldtype(S, :parent))
379+
return field_offset_and_type(name, T, fieldtype(S, :parent), key_error)
363380
elseif S <: Geometry.AxisVector
364381
(remaining_offset, end_type) = field_offset_and_type(
365382
name,
366383
T,
367384
fieldtype(fieldtype(S, :components), :data),
385+
key_error,
368386
)
369387
return (remaining_offset, end_type)
370388
else
371389
child_name = extract_first(name)
390+
(child_name in fieldnames(S)) || throw(key_error)
372391
child_type = fieldtype(S, child_name)
373392
remaining_field_chain = drop_first(name)
374393
field_index = unrolled_filter(
375394
i -> fieldname(S, i) == child_name,
376395
1:fieldcount(S),
377396
)[1]
378-
(remaining_offset, end_type) =
379-
field_offset_and_type(remaining_field_chain, T, child_type)
397+
(remaining_offset, end_type) = field_offset_and_type(
398+
remaining_field_chain,
399+
T,
400+
child_type,
401+
key_error,
402+
)
380403
return (
381404
DataLayouts.fieldtypeoffset(T, S, field_index) + remaining_offset,
382405
end_type,

test/MatrixFields/field_names.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -834,12 +834,8 @@ end
834834

835835
@test_throws KeyError matrix[@name(a), @name(a.c)]
836836
@test_throws KeyError matrix[@name(a.c), @name(a)]
837-
@test_throws AssertionError matrix[@name(foo), @name(foo._value)]
838-
if is_scalar_test
839-
@test_throws KeyError matrix[@name(foo._value), @name(foo)]
840-
else
841-
@test_throws AssertionError matrix[@name(foo._value), @name(foo)]
842-
end
837+
@test_throws KeyError matrix[@name(foo), @name(foo._value)]
838+
@test_throws KeyError matrix[@name(foo._value), @name(foo)]
843839

844840
@test_all matrix[@name(a), @name(a)] == -I_a
845841
@test_all matrix[@name(a.c), @name(a.c)] == -I_a

test/MatrixFields/scalar_fieldmatrix.jl

Lines changed: 11 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ include("matrix_field_test_utils.jl")
2424
::Type{S},
2525
expected_offset,
2626
::Type{E},
27+
key_error,
2728
) where {T, S, E}
28-
@test_all MatrixFields.field_offset_and_type(name, T, S) ==
29+
@test_all MatrixFields.field_offset_and_type(name, T, S, key_error) ==
2930
(expected_offset, E)
3031
end
3132
test_field_offset_and_type(
@@ -34,20 +35,23 @@ include("matrix_field_test_utils.jl")
3435
Singleton{Singleton{Singleton{Singleton{FT}}}},
3536
0,
3637
Singleton{Singleton{Singleton{FT}}},
38+
KeyError(@name(x.x.x.x)),
3739
)
3840
test_field_offset_and_type(
3941
@name(x.x.x.x),
4042
FT,
4143
Singleton{Singleton{Singleton{Singleton{FT}}}},
4244
0,
4345
FT,
46+
KeyError(@name(x.x.x.x)),
4447
)
4548
test_field_offset_and_type(
4649
@name(y.x),
4750
FT,
4851
TwoFields{TwoFields{FT, FT}, TwoFields{FT, FT}},
4952
2,
5053
FT,
54+
KeyError(@name(y.x)),
5155
)
5256
test_field_offset_and_type(
5357
@name(y.y),
@@ -58,78 +62,37 @@ include("matrix_field_test_utils.jl")
5862
},
5963
3,
6064
TwoFields{FT, Singleton{FT}},
65+
KeyError(@name(y.y.x)),
6166
)
6267
test_field_offset_and_type(
6368
@name(y.y),
6469
Float32,
6570
TwoFields{TwoFields{FT, FT}, TwoFields{FT, FT}},
6671
6,
6772
FT,
73+
KeyError(@name(y.y.x)),
6874
)
6975
test_field_offset_and_type(
70-
@name(y.y.x),
76+
(@name(y.y), @name(x)),
7177
FT,
7278
TwoFields{
7379
TwoFields{FT, FT},
7480
TwoFields{FT, TwoFields{FT, Singleton{FT}}},
7581
},
7682
3,
7783
FT,
84+
KeyError(@name(y.y.x.x)),
7885
)
7986
test_field_offset_and_type(
80-
@name(y.y.y.x),
87+
(@name(y.y.y), @name(y.x)),
8188
FT,
8289
TwoFields{
8390
TwoFields{FT, FT},
8491
TwoFields{FT, TwoFields{FT, Singleton{FT}}},
8592
},
8693
4,
8794
FT,
88-
)
89-
end
90-
91-
@testset "broadcasted_get_field_type" begin
92-
FT = Float64
93-
struct Singleton{T}
94-
x::T
95-
end
96-
struct TwoFields{T1, T2}
97-
x::T1
98-
y::T2
99-
end
100-
function test_broadcasted_get_field_type(
101-
name,
102-
::Type{T},
103-
expected_type,
104-
) where {T}
105-
@test_all MatrixFields.broadcasted_get_field_type(T, name) ==
106-
expected_type
107-
end
108-
test_broadcasted_get_field_type(
109-
@name(x),
110-
Singleton{Singleton{Singleton{Singleton{FT}}}},
111-
Singleton{Singleton{Singleton{FT}}},
112-
)
113-
test_broadcasted_get_field_type(
114-
@name(x.x.x),
115-
Singleton{Singleton{Singleton{Singleton{FT}}}},
116-
Singleton{FT},
117-
)
118-
test_broadcasted_get_field_type(
119-
@name(y.x),
120-
TwoFields{
121-
TwoFields{FT, FT},
122-
TwoFields{FT, TwoFields{FT, Singleton{FT}}},
123-
},
124-
FT,
125-
)
126-
test_broadcasted_get_field_type(
127-
@name(y.y.y),
128-
TwoFields{
129-
TwoFields{FT, FT},
130-
TwoFields{FT, TwoFields{FT, Singleton{FT}}},
131-
},
132-
Singleton{FT},
95+
KeyError(@name(y.y.y.x.x)),
13396
)
13497
end
13598

0 commit comments

Comments
 (0)