Skip to content

Commit 275ac0f

Browse files
jipolancopiever
andauthored
Fix inference issues of StructArray constructor on Julia 1.7+ (#209)
* Add fully inferred implementation of _map_params * Add new inference test * Add consistency test for _map_params * Add similar inference fix for map_params * Add map_params test * `_map_params_recursive` -> `_map_params_fallback` Co-authored-by: Pietro Vertechi <pietro.vertechi@protonmail.com> * `map_params_recursive` -> `map_params_fallback` Co-authored-by: Pietro Vertechi <pietro.vertechi@protonmail.com> * Use `fieldtypes` and remove `types_to_tuple` Co-authored-by: Pietro Vertechi <pietro.vertechi@protonmail.com>
1 parent 9057952 commit 275ac0f

File tree

2 files changed

+51
-8
lines changed

2 files changed

+51
-8
lines changed

src/utils.jl

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,24 @@ julia> StructArrays.map_params(T -> Complex{T}, Tuple{Int32,Float64})
1313
Tuple{Complex{Int32},Complex{Float64}}
1414
```
1515
"""
16-
map_params(f, ::Type{Tuple{}}) = Tuple{}
17-
function map_params(f, ::Type{T}) where {T<:Tuple}
18-
tuple_type_cons(f(tuple_type_head(T)), map_params(f, tuple_type_tail(T)))
19-
end
2016
map_params(f, ::Type{NamedTuple{names, types}}) where {names, types} =
2117
NamedTuple{names, map_params(f, types)}
2218

19+
function map_params(f, ::Type{T}) where {T<:Tuple}
20+
if @generated
21+
types = fieldtypes(T)
22+
ex = :(Tuple{})
23+
for t types
24+
push!(ex.args, :(f($t)))
25+
end
26+
ex
27+
else
28+
map_params_fallback(f, T)
29+
end
30+
end
31+
32+
map_params_fallback(f, ::Type{T}) where {T<:Tuple} = Tuple{map(f, fieldtypes(T))...}
33+
2334
"""
2435
StructArrays._map_params(f, T)
2536
@@ -31,13 +42,24 @@ julia> StructArrays._map_params(T -> Complex{T}, Tuple{Int32,Float64})
3142
(Complex{Int32}, Complex{Float64})
3243
```
3344
"""
34-
_map_params(f, ::Type{Tuple{}}) = ()
35-
function _map_params(f, ::Type{T}) where {T<:Tuple}
36-
(f(tuple_type_head(T)), _map_params(f, tuple_type_tail(T))...)
37-
end
3845
_map_params(f::F, ::Type{NamedTuple{names, types}}) where {names, types, F} =
3946
NamedTuple{names}(_map_params(f, types))
4047

48+
function _map_params(f::F, ::Type{T}) where {T<:Tuple, F}
49+
if @generated
50+
types = fieldtypes(T)
51+
ex = :()
52+
for t types
53+
push!(ex.args, :(f($t)))
54+
end
55+
ex
56+
else
57+
_map_params_fallback(f, T)
58+
end
59+
end
60+
61+
_map_params_fallback(f, ::Type{T}) where {T<:Tuple} = map(f, fieldtypes(T))
62+
4163
buildfromschema(initializer::F, ::Type{T}) where {T, F} = buildfromschema(initializer, T, staticschema(T))
4264

4365
"""

test/runtests.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ f_infer() = StructArray{ComplexF64}((rand(2,2), rand(2,2)))
342342
g_infer() = StructArray([(a=(b="1",), c=2)], unwrap = t -> t <: NamedTuple)
343343
tup_infer() = StructArray([(1, 2), (3, 4)])
344344
cols_infer() = StructArray(([1, 2], [1.2, 2.3]))
345+
nt_infer(nt) = StructArray{typeof(nt)}(undef, 4)
345346

346347
@testset "inferrability" begin
347348
@inferred f_infer()
@@ -352,6 +353,7 @@ cols_infer() = StructArray(([1, 2], [1.2, 2.3]))
352353
@test s[1] == (1, 2)
353354
@test s[2] == (3, 4)
354355
@inferred cols_infer()
356+
@inferred nt_infer((x = 3, y = :a, z = :b))
355357
end
356358

357359
@testset "propertynames" begin
@@ -879,3 +881,22 @@ end
879881
@inferred view(u, 1, :)
880882
end
881883
end
884+
885+
# Test fallback (non-@generated) variant of _map_params
886+
@testset "_map_params" begin
887+
v = StructArray(rand(ComplexF64, 2, 2))
888+
f(T) = similar(v, T)
889+
types = Tuple{Int, Float64, ComplexF32, String}
890+
A = @inferred StructArrays._map_params(f, types)
891+
B = StructArrays._map_params_fallback(f, types)
892+
@test typeof(A) === typeof(B)
893+
end
894+
895+
# Same for map_params
896+
@testset "map_params" begin
897+
types = Tuple{Int, Float64, Int32}
898+
f(T) = Complex{T}
899+
A = @inferred StructArrays.map_params(f, types)
900+
B = StructArrays.map_params_fallback(f, types)
901+
@test A === B
902+
end

0 commit comments

Comments
 (0)