Skip to content

Commit 9db8de0

Browse files
authored
improve type stability and runtime performance (#194)
* specialize foreachfield on the function * improve inference of index_type (necessary for creating views)
1 parent 30ae00e commit 9db8de0

File tree

3 files changed

+31
-5
lines changed

3 files changed

+31
-5
lines changed

src/structarray.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,17 @@ function index_type(::Type{T}) where {T<:Tuple}
3232
S, U = tuple_type_head(T), tuple_type_tail(T)
3333
IndexStyle(S) isa IndexCartesian ? CartesianIndex{ndims(S)} : index_type(U)
3434
end
35+
# Julia v1.7.0-beta3 doesn't seem to specialize `index_type` as defined above
36+
# for tuple types with "many" elements (three or four, depending on the concrete
37+
# types). However, we can help the compiler for homogeneous types by defining
38+
# the specialization below.
39+
function index_type(::Type{<:NTuple{N, S}}) where {N, S}
40+
if IndexStyle(S) isa IndexCartesian
41+
return CartesianIndex{ndims(S)}
42+
else
43+
return Int
44+
end
45+
end
3546

3647
index_type(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = I
3748

@@ -112,7 +123,7 @@ StructVector(args...; kwargs...) = StructArray(args...; kwargs...)
112123
"""
113124
StructArray{T}(A::AbstractArray; dims, unwrap=FT->FT!=eltype(A))
114125
115-
Construct a `StructArray` from slices of `A` along `dims`.
126+
Construct a `StructArray` from slices of `A` along `dims`.
116127
117128
The `unwrap` keyword argument is a function that determines whether to
118129
recursively convert fields of type `FT` to `StructArray`s.

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ end
8787
@generated foreachfield_gen(::S, f, xs::Vararg{Any, L}) where {S<:StructArray, L} =
8888
_foreachfield(array_names_types(S), L)
8989

90-
foreachfield(f, x::StructArray, xs...) = foreachfield_gen(x, f, x, xs...)
90+
foreachfield(f::F, x::StructArray, xs::Vararg{Any, N}) where {F, N} = foreachfield_gen(x, f, x, xs...)
9191

9292
"""
9393
StructArrays.iscompatible(::Type{S}, ::Type{V}) where {S, V<:AbstractArray}
@@ -149,7 +149,7 @@ julia> s = StructArray(a=1:3, b = fill("string", 3));
149149
julia> s_pooled = StructArrays.replace_storage(s) do v
150150
isbitstype(eltype(v)) ? v : convert(PooledArray, v)
151151
end
152-
$(if VERSION < v"1.6-"
152+
$(if VERSION < v"1.6-"
153153
"3-element StructArray(::UnitRange{Int64}, ::PooledArray{String,UInt32,1,Array{UInt32,1}}) with eltype NamedTuple{(:a, :b),Tuple{Int64,String}}:"
154154
else
155155
"3-element StructArray(::UnitRange{Int64}, ::PooledVector{String, UInt32, Vector{UInt32}}) with eltype NamedTuple{(:a, :b), Tuple{Int64, String}}:"

test/runtests.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,15 @@ end
7070
@test s[100] == s[10, 10] == (a=1, b=1)
7171
s[10, 10] = (a=0, b=0)
7272
@test s[100] == s[10, 10] == (a=0, b=0)
73+
74+
# inference for "many" types, both for linear ad Cartesian indexing
75+
@inferred StructArrays.index_type(NTuple{2, Vector{Float64}})
76+
@inferred StructArrays.index_type(NTuple{3, Matrix{Float64}})
77+
@inferred StructArrays.index_type(NTuple{4, Array{Float64, 3}})
78+
79+
@inferred StructArrays.index_type(NTuple{2, SubArray{Float64, 1, Array{Float64, 2}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}})
80+
@inferred StructArrays.index_type(NTuple{3, SubArray{Float64, 1, Array{Float64, 2}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}})
81+
@inferred StructArrays.index_type(NTuple{4, SubArray{Float64, 1, Array{Float64, 2}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}})
7382
end
7483

7584
@testset "replace_storage" begin
@@ -818,9 +827,8 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
818827
end
819828

820829
@testset "staticarrays" begin
821-
822830
# test that staticschema returns the right things
823-
for StaticVectorType = [SVector, MVector, SizedVector]
831+
for StaticVectorType = [SVector, MVector, SizedVector]
824832
@test StructArrays.staticschema(StaticVectorType{2,Float64}) == Tuple{Float64,Float64}
825833
end
826834

@@ -838,4 +846,11 @@ end
838846
@test StructArrays.components(x) == ([1., 2.], [1., 2.])
839847
@test x .+ y == StructArray([StaticArrayType{Tuple{1,2}}(3*ones(1,2) .+ 2*i) for i = 0:1])
840848
end
849+
850+
# test type stability of creating views with "many" homogeneous components
851+
for n in 1:10
852+
u = StructArray(randn(SVector{n, Float64}) for _ in 1:10, _ in 1:5)
853+
@inferred view(u, :, 1)
854+
@inferred view(u, 1, :)
855+
end
841856
end

0 commit comments

Comments
 (0)