Skip to content

Commit 0b6be8d

Browse files
author
Pietro Vertechi
authored
Use more specific check on widening (#81)
1 parent b0d5220 commit 0b6be8d

File tree

4 files changed

+50
-22
lines changed

4 files changed

+50
-22
lines changed

src/collect.jl

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -109,27 +109,22 @@ function grow_to_structarray!(dest::AbstractArray, itr, elem = iterate(itr))
109109
return dest
110110
end
111111

112-
widenstructarray(dest::AbstractArray, i, el::S) where {S} = widenstructarray(dest, i, S)
113-
114-
function widenstructarray(dest::StructArray{T}, i, ::Type{S}) where {T, S}
115-
sch = staticschema(S)
116-
names = fieldnames(sch)
117-
types = ntuple(i -> fieldtype(sch, i), fieldcount(sch))
118-
cols = fieldarrays(dest)
119-
if names == propertynames(cols)
120-
nt = map((a, b) -> widenstructarray(a, i, b), cols, strip_params(sch)(types))
121-
ST = _promote_typejoin(S, T)
122-
return StructArray{ST}(nt)
123-
else
124-
return widenarray(dest, i, S)
125-
end
112+
widenstructarray(dest::AbstractArray{S}, i, el::T) where {S, T} = widenstructarray(dest, i, _promote_typejoin(S, T))
113+
114+
function widenstructarray(dest::StructArray, i, ::Type{T}) where {T}
115+
sch = hasfields(T) ? staticschema(T) : nothing
116+
sch !== nothing && fieldnames(sch) == propertynames(dest) || return widenarray(dest, i, T)
117+
types = ntuple(x -> fieldtype(sch, x), fieldcount(sch))
118+
cols = Tuple(fieldarrays(dest))
119+
newcols = map((a, b) -> widenstructarray(a, i, b), cols, types)
120+
return StructArray{T}(newcols)
126121
end
127122

128-
widenstructarray(dest::AbstractArray, i, ::Type{S}) where {S} = widenarray(dest, i, S)
123+
widenstructarray(dest::AbstractArray, i, ::Type{T}) where {T} = widenarray(dest, i, T)
129124

130-
widenarray(dest::AbstractArray{S}, i, ::Type{S}) where {S} = dest
131-
function widenarray(dest::AbstractArray{S}, i, ::Type{T}) where {S, T}
132-
new = similar(dest, Base.promote_typejoin(S, T), length(dest))
125+
widenarray(dest::AbstractArray{T}, i, ::Type{T}) where {T} = dest
126+
function widenarray(dest::AbstractArray, i, ::Type{T}) where T
127+
new = similar(dest, T, length(dest))
133128
copyto!(new, 1, dest, 1, i-1)
134129
new
135130
end

src/structarray.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,21 @@ StructVector(args...; kwargs...) = StructArray(args...; kwargs...)
5151

5252
Base.IndexStyle(::Type{S}) where {S<:StructArray} = _indexstyle(_best_index(S))
5353

54-
_undef_array(::Type{T}, sz; unwrap = t -> false) where {T} = unwrap(T) ? StructArray{T}(undef, sz; unwrap = unwrap) : Array{T}(undef, sz)
54+
function _undef_array(::Type{T}, sz; unwrap = t -> false) where {T}
55+
if unwrap(T)
56+
return StructArray{T}(undef, sz; unwrap = unwrap)
57+
else
58+
return Array{T}(undef, sz)
59+
end
60+
end
5561

56-
_similar(v::AbstractArray, ::Type{Z}; unwrap = t -> false) where {Z} =
57-
unwrap(Z) ? buildfromschema(typ -> _similar(v, typ; unwrap = unwrap), Z) : similar(v, Z)
62+
function _similar(v::AbstractArray, ::Type{Z}; unwrap = t -> false) where {Z}
63+
if unwrap(Z)
64+
return buildfromschema(typ -> _similar(v, typ; unwrap = unwrap), Z)
65+
else
66+
return similar(v, Z)
67+
end
68+
end
5869

5970
function StructArray{T}(::Base.UndefInitializer, sz::Dims; unwrap = t -> false) where {T}
6071
buildfromschema(typ -> _undef_array(typ, sz; unwrap = unwrap), T)
@@ -65,7 +76,10 @@ function similar_structarray(v::AbstractArray, ::Type{Z}; unwrap = t -> false) w
6576
buildfromschema(typ -> _similar(v, typ; unwrap = unwrap), Z)
6677
end
6778

68-
StructArray(v; unwrap = t -> false) = collect_structarray(v; initializer = StructArrayInitializer(unwrap))
79+
function StructArray(v; unwrap = t -> false)::StructArray
80+
collect_structarray(v; initializer = StructArrayInitializer(unwrap))
81+
end
82+
6983
function StructArray(v::AbstractArray{T}; unwrap = t -> false) where {T}
7084
s = similar_structarray(v, T; unwrap = unwrap)
7185
for i in eachindex(v)

src/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,9 @@ astuple(::Type{T}) where {T<:Tuple} = T
130130

131131
strip_params(::Type{<:Tuple}) = Tuple
132132
strip_params(::Type{<:NamedTuple{names}}) where {names} = NamedTuple{names}
133+
134+
hasfields(::Type{<:Tup}) = false
135+
hasfields(::Type{<:NTuple{N, Any}}) where {N} = true
136+
hasfields(::Type{<:NamedTuple{names}}) where {names} = true
137+
hasfields(::Type{T}) where {T} = !isabstracttype(T)
138+
hasfields(::Union) = false

test/runtests.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,19 @@ end
530530
@test sa.a == fill(1, -2:7)
531531
end
532532

533+
@testset "hasfields" begin
534+
@test StructArrays.hasfields(ComplexF64)
535+
@test !StructArrays.hasfields(Any)
536+
@test StructArrays.hasfields(Tuple{Union{Int, Missing}})
537+
@test StructArrays.hasfields(typeof((a=1,)))
538+
@test !StructArrays.hasfields(NamedTuple)
539+
@test !StructArrays.hasfields(Tuple{Int, Vararg{Int, N}} where {N})
540+
@test StructArrays.hasfields(Missing)
541+
@test !StructArrays.hasfields(Union{Tuple{Int}, Missing})
542+
@test StructArrays.hasfields(Nothing)
543+
@test !StructArrays.hasfields(Union{Tuple{Int}, Nothing})
544+
end
545+
533546
@testset "reshape" begin
534547
s = StructArray(a=[1,2,3,4], b=["a","b","c","d"])
535548
rs = reshape(s, (2, 2))

0 commit comments

Comments
 (0)