Skip to content

Commit 3ab2609

Browse files
author
Pietro Vertechi
authored
Allow recursive unwrapping (JuliaArrays#19)
* wip allow recursive unwrapping * added unwrap keyword argument * Add tests * Use keyword argument anyway and define _similar * spacing * rename to _undef_array
1 parent 4d7a5a4 commit 3ab2609

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

src/structarray.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,23 @@ StructArray{T}(; kwargs...) where {T} = StructArray{T}(values(kwargs))
2727
StructArray(; kwargs...) = StructArray(values(kwargs))
2828

2929
StructArray{T}(args...) where {T} = StructArray{T}(NamedTuple{fields(T)}(args))
30-
@generated function StructArray{T}(::Base.UndefInitializer, d::Integer...) where {T}
31-
ex = Expr(:tuple, [:(Array{$(fieldtype(T, i))}(undef, sz)) for i in 1:fieldcount(T)]...)
30+
31+
_undef_array(::Type{T}, sz; unwrap = t -> false) where {T} = unwrap(T) ? StructArray{T}(undef, sz; unwrap = unwrap) : Array{T}(undef, sz)
32+
function _similar(v::S, ::Type{Z}; unwrap = t -> false) where {S <: AbstractArray{T, N}, Z} where {T, N}
33+
unwrap(Z) ? StructArray{Z}(map(t -> _similar(v, fieldtype(Z, t); unwrap = unwrap), fields(Z))) : similar(v, Z)
34+
end
35+
36+
StructArray{T}(u::Base.UndefInitializer, d::Integer...; unwrap = t -> false) where {T} = StructArray{T}(u, convert(Dims, d); unwrap = unwrap)
37+
@generated function StructArray{T}(::Base.UndefInitializer, sz::Dims; unwrap = t -> false) where {T}
38+
ex = Expr(:tuple, [:(_undef_array($(fieldtype(T, i)), sz; unwrap = unwrap)) for i in 1:fieldcount(T)]...)
3239
return quote
33-
sz = convert(Tuple{Vararg{Int}}, d)
3440
StructArray{T}(NamedTuple{fields(T)}($ex))
3541
end
3642
end
3743

38-
@generated function StructArray(v::AbstractArray{T, N}) where {T, N}
44+
@generated function StructArray(v::AbstractArray{T, N}; unwrap = t -> false) where {T, N}
3945
syms = [gensym() for i in 1:fieldcount(T)]
40-
init = Expr(:block, [:($(syms[i]) = similar(v, $(fieldtype(T, i)))) for i in 1:fieldcount(T)]...)
46+
init = Expr(:block, [:($(syms[i]) = _similar(v, $(fieldtype(T, i)); unwrap = unwrap)) for i in 1:fieldcount(T)]...)
4147
push = Expr(:block, [:($(syms[i])[j] = f.$(fieldname(T, i))) for i in 1:fieldcount(T)]...)
4248
quote
4349
$init

test/runtests.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,15 @@ end
9696
end
9797

9898
f_infer() = StructArray{ComplexF64}(rand(2,2), rand(2,2))
99+
unwrap(::Type) = false
100+
unwrap(::Type{<:NamedTuple}) = true
101+
102+
g_infer() = StructArray([(a=(b=1,), c=2)], unwrap = unwrap)
103+
99104
@testset "inferrability" begin
100105
@inferred f_infer()
106+
@inferred g_infer()
107+
@test g_infer().a.b == [1]
101108
end
102109

103110
@testset "propertynames" begin

0 commit comments

Comments
 (0)