Skip to content

Commit 392e5ba

Browse files
author
Pietro Vertechi
authored
Add collection mechanism from iterable (JuliaArrays#22)
* general collection mechanism * cleanup * fix promoted_eltype * cleanup * fix recursive * cleanup * respect user type * add tests * test pass no inf * fix inference * multidimensional case
1 parent 8417b55 commit 392e5ba

File tree

5 files changed

+278
-6
lines changed

5 files changed

+278
-6
lines changed

src/StructArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ export StructArray
55

66
include("structarray.jl")
77
include("utils.jl")
8+
include("collect.jl")
89

910
function __init__()
1011
Requires.@require Tables="bd369af6-aec1-5ad0-b16a-f7cc5008161c" include("tables.jl")

src/collect.jl

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
"""
    StructArrayInitializer(unwrap)
    StructArrayInitializer()

Callable that allocates an uninitialized `StructArray` destination for the collection
functions. The stored `unwrap` predicate is forwarded as the `unwrap` keyword of the
`StructArray{S}(undef, d; unwrap = ...)` constructor. The zero-argument form installs a
predicate that returns `false` for every input.
"""
struct StructArrayInitializer{F}
    unwrap::F
end

StructArrayInitializer() = StructArrayInitializer(_ -> false)

# Initializer used by the collection functions when the caller does not supply one.
const default_initializer = StructArrayInitializer()

# Allocate a destination with element type `S` and dimensions `d`.
function (init::StructArrayInitializer)(S, d)
    return StructArray{S}(undef, d; unwrap = init.unwrap)
end
9+
10+
"""
    ArrayInitializer(unwrap)
    ArrayInitializer()

Callable that allocates a destination through `_undef_array(S, d; unwrap = ...)`
(defined outside this file), forwarding the stored `unwrap` predicate. The zero-argument
form installs a predicate that returns `false` for every input.
"""
struct ArrayInitializer{F}
    unwrap::F
end

ArrayInitializer() = ArrayInitializer(_ -> false)

# Allocate a destination with element type `S` and dimensions `d`.
function (init::ArrayInitializer)(S, d)
    return _undef_array(S, d; unwrap = init.unwrap)
end
16+
17+
# Give the collected container `v` the axes of `itr` when the iterator advertises a
# shape; for iterators that only have a length, or whose size is unknown, `v` is
# returned untouched.
function _reshape(v, itr, ::Base.HasShape)
    return reshape(v, axes(itr))
end

function _reshape(v, itr, ::Union{Base.HasLength, Base.SizeUnknown})
    return v
end
19+
20+
"""
    collect_columns(itr; initializer = default_initializer)

Collect the elements of `itr` into a container allocated by `initializer`.

Dispatches on `Base.IteratorSize(itr)` to the size-specific method, then passes the
result through `_reshape` so that iterators with a shape yield a container with the
same axes.
"""
function collect_columns(itr; initializer = default_initializer)
    size_trait = Base.IteratorSize(itr)
    collected = collect_columns(itr, size_trait; initializer = initializer)
    return _reshape(collected, itr, size_trait)
end
25+
26+
# Result for an iterator that produced no elements: ask inference which type `first`
# would return for an iterator of type `T`, and let `initializer` allocate a
# zero-length destination with that element type.
function collect_empty_columns(itr::T; initializer = default_initializer) where {T}
    inferred_eltype = Core.Compiler.return_type(first, Tuple{T})
    return initializer(inferred_eltype, (0,))
end
30+
31+
# Collection for iterators with a known length or shape: the destination is allocated
# once with `length(itr)` slots, typed after the first element, and the remaining
# elements are written by `collect_to_columns!` (which widens the destination if a
# later element does not fit).
function collect_columns(@nospecialize(itr), ::Union{Base.HasShape, Base.HasLength};
                         initializer = default_initializer)
    first_step = iterate(itr)
    if first_step === nothing
        return collect_empty_columns(itr, initializer = initializer)
    end
    head, state = first_step
    dest = initializer(typeof(head), (length(itr),))
    dest[1] = head
    # Slot 1 is filled, so writing resumes at index 2 from the iteration state `state`.
    return collect_to_columns!(dest, itr, 2, state)
end
41+
42+
# Fill `dest` starting at index `offs` with the elements of `itr` that follow the
# iteration state `st`. Every element is checked against the element type `T`; the
# first one that does not fit triggers `widencolumns`, and collection restarts on the
# widened destination from the next slot.
function collect_to_columns!(dest::AbstractArray{T}, itr, offs, st) where {T}
    pos = offs
    next = iterate(itr, st)
    while next !== nothing
        val, st = next
        if !(val isa T)
            widened = widencolumns(dest, pos, val)
            @inbounds widened[pos] = val
            return collect_to_columns!(widened, itr, pos + 1, st)
        end
        @inbounds dest[pos] = val
        pos += 1
        next = iterate(itr, st)
    end
    return dest
end
61+
62+
# Collection for iterators of unknown size: start from a one-element destination typed
# after the first element and let `grow_to_columns!` push the remaining elements.
function collect_columns(itr, ::Base.SizeUnknown; initializer = default_initializer)
    first_step = iterate(itr)
    if first_step === nothing
        return collect_empty_columns(itr; initializer = initializer)
    end
    head, state = first_step
    dest = initializer(typeof(head), (1,))
    dest[1] = head
    return grow_to_columns!(dest, itr, iterate(itr, state))
end
70+
71+
# Append to `dest` the elements of `itr`, starting from the iteration result `elem`
# (by default the first one). Every element is checked against the element type `T`;
# the first one that does not fit triggers `widencolumns`, and growth continues on the
# widened destination.
function grow_to_columns!(dest::AbstractArray{T}, itr, elem = iterate(itr)) where {T}
    next = elem
    while next !== nothing
        val, state = next
        if !(val isa T)
            # `length(dest) + 1` is the position `val` takes once pushed.
            widened = widencolumns(dest, length(dest) + 1, val)
            push!(widened, val)
            return grow_to_columns!(widened, itr, iterate(itr, state))
        end
        push!(dest, val)
        next = iterate(itr, state)
    end
    return dest
end
89+
90+
# Wrap the columns `nt` in a `StructArray` whose element type is what `createtype`
# derives from the requested type `T` and the type `C` of the columns.
function to_structarray(::Type{T}, nt::C) where {T, C}
    return StructArray{createtype(T, C)}(nt)
end
94+
95+
# Widen a `StructArray` so that it can hold `el` at position `i`.
# When the field names reported by `fields` agree for `el`'s type and the element
# type `T`, each column is widened independently against the matching field of `el`
# and the columns are reassembled with `to_structarray`. Otherwise the whole array is
# handed to `widenarray`.
function widencolumns(dest::StructArray{T}, i, el::S) where {T, S}
    names = fields(S)
    names === fields(T) || return widenarray(dest, i, el)
    widened = (
        widencolumns(columns(dest)[k], i, getfieldindex(el, name, k))
        for (k, name) in enumerate(names)
    )
    return to_structarray(T, NamedTuple{names}(Tuple(widened)))
end

# Any other destination has no columns to recurse into: widen it as a whole.
function widencolumns(dest::AbstractArray, i, el)
    return widenarray(dest, i, el)
end
107+
108+
# Return an array able to hold `el` in addition to the first `i - 1` elements of
# `dest`. If `el`'s type already fits the element type, `dest` itself is returned.
# Otherwise a new `Array` with element type `promote_type(S, T)` and `length(dest)`
# slots is allocated and the first `i - 1` elements are copied over; the remaining
# slots are left uninitialized for the caller to fill.
function widenarray(dest::AbstractArray{T}, i, el::S) where {S, T}
    if S <: T
        return dest
    end
    W = promote_type(S, T)
    widened = Array{W}(undef, length(dest))
    copyto!(widened, 1, dest, 1, i - 1)
    return widened
end

src/structarray.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@ struct StructArray{T, N, C<:NamedTuple} <: AbstractArray{T, N}
1818
end
1919

2020
StructArray{T}(c::C) where {T, C<:Tuple} = StructArray{T}(NamedTuple{fields(T)}(c))
21-
StructArray{T}(c::C) where {T, C<:NamedTuple} =
22-
StructArray{createtype(T, eltypes(C)), length(size(c[1])), C}(c)
23-
StructArray(c::C) where {C<:NamedTuple} = StructArray{C}(c)
21+
# Wrap a `NamedTuple` of columns with the element type `T` given by the caller; the
# dimensionality is read from the first column.
function StructArray{T}(c::C) where {T, C<:NamedTuple}
    return StructArray{T, length(size(c[1])), C}(c)
end

# Without an explicit element type, derive it from the column types via `eltypes`.
function StructArray(c::C) where {C<:NamedTuple}
    return StructArray{eltypes(C)}(c)
end
2423
StructArray(c::C) where {C<:Tuple} = StructArray{eltypes(C)}(c)
2524

2625
StructArray{T}(; kwargs...) where {T} = StructArray{T}(values(kwargs))
@@ -122,3 +121,8 @@ for op in [:hcat, :vcat]
122121
end
123122

124123
Base.copy(s::StructArray{T,N,C}) where {T,N,C} = StructArray{T,N,C}(C(copy(x) for x in columns(s)))
124+
125+
# Reshape a `StructArray` by reshaping every column to the dimensions `d` and wrapping
# the reshaped columns with the same element type `T`.
function Base.reshape(s::StructArray{T}, d::Dims) where {T}
    reshaped = map(col -> reshape(col, d), columns(s))
    return StructArray{T}(reshaped)
end
128+

src/utils.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import Base: tuple_type_cons, tuple_type_head, tuple_type_tail, tail
33
eltypes(::Type{Tuple{}}) = Tuple{}
44
eltypes(::Type{T}) where {T<:Tuple} =
55
tuple_type_cons(eltype(tuple_type_head(T)), eltypes(tuple_type_tail(T)))
6-
eltypes(::Type{NamedTuple{K, V}}) where {K, V} = eltypes(V)
6+
# For a `NamedTuple` type, keep the names `K` and map the value types through
# `eltypes`, so the result is again a `NamedTuple` type.
function eltypes(::Type{NamedTuple{K, V}}) where {K, V}
    return NamedTuple{K, eltypes(V)}
end
77

88
Base.@pure SkipConstructor(::Type) = false
99

@@ -46,5 +46,12 @@ createinstance(::Type{T}, args...) where {T<:Union{Tuple, NamedTuple}} = T(args)
4646
Expr(:block, new_tup, construct)
4747
end
4848

49-
createtype(::Type{T}, ::Type{C}) where {T<:NamedTuple{N}, C} where {N} = NamedTuple{N, C}
50-
createtype(::Type{T}, ::Type{C}) where {T, C} = T
49+
# Entry point: given the requested element type `T` and the type of a `NamedTuple`
# of columns, split the column type into names and `eltypes` of the value types and
# dispatch on `T`.
function createtype(::Type{T}, ::Type{NamedTuple{names, types}}) where {T, names, types}
    return createtype(T, names, eltypes(types))
end

# Generic element types are returned unchanged.
function createtype(::Type{T}, names, types) where {T}
    return T
end

# Tuple element types are replaced by the tuple of column element types.
function createtype(::Type{T}, names, types) where {T<:Tuple}
    return types
end

# NamedTuple element types keep their own names `K` and take the column element types.
function createtype(::Type{<:NamedTuple{K}}, names, types) where {K}
    return NamedTuple{K, types}
end
54+
# Pair element types are rebuilt from the first two element types of the tuple type
# `types`. `fieldtype` is the documented reflection API for this, so the method does
# not depend on the internal `parameters` field of `DataType`.
function createtype(::Type{<:Pair}, names, types)
    return Pair{fieldtype(types, 1), fieldtype(types, 2)}
end

test/runtests.jl

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,150 @@ StructArrays.SkipConstructor(::Type{<:S}) = true
148148
@test v[1] == S(1)
149149
@test v[1].y isa Float64
150150
end
151+
152+
# Initializer shared by the collection tests: its `unwrap` predicate accepts tuple,
# named-tuple and pair element types.
const initializer = StructArrays.ArrayInitializer(
    T -> T <: Union{Tuple, NamedTuple, Pair},
)

# Test-local wrapper that always collects with the initializer above.
function collect_columns(t)
    return StructArrays.collect_columns(t, initializer = initializer)
end
154+
155+
@testset "collectnamedtuples" begin
    # homogeneous rows collect into concretely typed columns
    v = [(a = 1, b = 2), (a = 1, b = 3)]
    @test collect_columns(v) == StructArray((a = Int[1, 1], b = Int[2, 3]))

    # test inferrability with constant eltype
    itr = [(a = 1, b = 2), (a = 1, b = 2), (a = 1, b = 12)]
    el, st = iterate(itr)
    dest = initializer(typeof(el), (3,))
    dest[1] = el
    @inferred StructArrays.collect_to_columns!(dest, itr, 2, st)

    # a column widens when a later row carries a different type
    v = [(a = 1, b = 2), (a = 1.2, b = 3)]
    @test collect_columns(v) == StructArray((a = [1, 1.2], b = Int[2, 3]))
    @test typeof(collect_columns(v)) ==
          typeof(StructArray((a = [1, 1.2], b = Int[2, 3])))

    v = [(a = 1, b = 2), (a = 1.2, b = "3")]
    @test collect_columns(v) == StructArray((a = [1, 1.2], b = Any[2, "3"]))
    @test typeof(collect_columns(v)) ==
          typeof(StructArray((a = [1, 1.2], b = Any[2, "3"])))

    v = [(a = 1, b = 2), (a = 1.2, b = 2), (a = 1, b = "3")]
    @test collect_columns(v) == StructArray((a = [1, 1.2, 1], b = Any[2, 2, "3"]))
    @test typeof(collect_columns(v)) ==
          typeof(StructArray((a = [1, 1.2, 1], b = Any[2, 2, "3"])))

    # length unknown
    itr = Iterators.filter(isodd, 1:8)
    tuple_itr = ((a = i+1, b = i-1) for i in itr)
    @test collect_columns(tuple_itr) == StructArray((a = [2, 4, 6, 8], b = [0, 2, 4, 6]))
    tuple_itr_real = (i == 1 ? (a = 1.2, b =i-1) : (a = i+1, b = i-1) for i in itr)
    @test collect_columns(tuple_itr_real) ==
          StructArray((a = Real[1.2, 4, 6, 8], b = [0, 2, 4, 6]))

    # empty
    itr = Iterators.filter(t -> t > 10, 1:8)
    tuple_itr = ((a = i+1, b = i-1) for i in itr)
    @test collect_columns(tuple_itr) == StructArray((a = Int[], b = Int[]))

    itr = (i for i in 0:-1)
    tuple_itr = ((a = i+1, b = i-1) for i in itr)
    @test collect_columns(tuple_itr) == StructArray((a = Int[], b = Int[]))
end
194+
195+
@testset "collecttuples" begin
    # homogeneous tuples, through the test wrapper and through the package default
    rows = [(1, 2), (1, 3)]
    @test collect_columns(rows) == StructArray((Int[1, 1], Int[2, 3]))
    @inferred collect_columns(rows)

    @test StructArrays.collect_columns(rows) == StructArray((Int[1, 1], Int[2, 3]))
    @inferred StructArrays.collect_columns(rows)

    # columns widen when a later tuple carries a different type
    rows = [(1, 2), (1.2, 3)]
    @test collect_columns(rows) == StructArray(([1, 1.2], Int[2, 3]))

    rows = [(1, 2), (1.2, "3")]
    expected = StructArray(([1, 1.2], Any[2, "3"]))
    @test collect_columns(rows) == expected
    @test typeof(collect_columns(rows)) == typeof(expected)

    rows = [(1, 2), (1.2, 2), (1, "3")]
    @test collect_columns(rows) == StructArray(([1, 1.2, 1], Any[2, 2, "3"]))

    # length unknown
    odds = Iterators.filter(isodd, 1:8)
    int_tuples = ((i+1, i-1) for i in odds)
    @test collect_columns(int_tuples) == StructArray(([2, 4, 6, 8], [0, 2, 4, 6]))
    real_tuples = (i == 1 ? (1.2, i-1) : (i+1, i-1) for i in odds)
    expected = StructArray(([1.2, 4, 6, 8], [0, 2, 4, 6]))
    @test collect_columns(real_tuples) == expected
    @test typeof(collect_columns(real_tuples)) == typeof(expected)

    # empty
    filtered_out = Iterators.filter(t -> t > 10, 1:8)
    no_tuples = ((i+1, i-1) for i in filtered_out)
    @test collect_columns(no_tuples) == StructArray((Int[], Int[]))

    empty_gen = (i for i in 0:-1)
    no_tuples = ((i+1, i-1) for i in empty_gen)
    @test collect_columns(no_tuples) == StructArray((Int[], Int[]))
end
229+
230+
@testset "collectscalars" begin
    # scalar elements collect like `collect`
    ints = (i for i in 1:3)
    @test collect_columns(ints) == [1,2,3]
    @inferred collect_columns(ints)

    reals = (i == 1 ? 1.2 : i for i in 1:3)
    @test collect_columns(reals) == collect(reals)

    # length unknown
    odds = Iterators.filter(isodd, 1:100)
    @test collect_columns(odds) == collect(odds)
    real_odds = (i == 1 ? 1.5 : i for i in odds)
    @test collect_columns(real_odds) == collect(real_odds)
    @test eltype(collect_columns(real_odds)) == Float64

    # empty
    filtered_out = Iterators.filter(t -> t > 10, 1:8)
    no_floats = (exp(i) for i in filtered_out)
    @test collect_columns(no_floats) == Float64[]

    empty_gen = (i for i in 0:-1)
    no_floats = (exp(i) for i in empty_gen)
    @test collect_columns(no_floats) == Float64[]

    # a `missing` among the values widens the column to a Union with Missing
    t = collect_columns((a = i,) for i in (1, missing, 3))
    first_col = StructArrays.columns(t)[1]
    @test first_col isa Array{Union{Int, Missing}}
    @test isequal(first_col, [1, missing, 3])
end
257+
258+
@testset "collectpairs" begin
    # Shorthands for the named-tuple types spelled out in the expectations below.
    A64 = NamedTuple{(:a,), Tuple{Int64}}
    AInt = NamedTuple{(:a,), Tuple{Int}}
    AAny = NamedTuple{(:a,), Tuple{Any}}
    BStr = NamedTuple{(:b,), Tuple{String}}

    # scalar pairs
    scalar_pairs = (i=>i+1 for i in 1:3)
    @test collect_columns(scalar_pairs) == StructArray{Pair{Int, Int}}([1,2,3], [2,3,4])
    @test eltype(collect_columns(scalar_pairs)) == Pair{Int, Int}

    scalar_pairs = (i == 1 ? (1.2 => i+1) : (i => i+1) for i in 1:3)
    @test collect_columns(scalar_pairs) ==
          StructArray{Pair{Float64, Int}}([1.2,2,3], [2,3,4])
    @test eltype(collect_columns(scalar_pairs)) == Pair{Float64, Int}

    # pairs of named tuples
    nt_pairs = ((a=i,) => (b="a$i",) for i in 1:3)
    @test collect_columns(nt_pairs) == StructArray{Pair{A64, BStr}}(
        StructArray((a = [1,2,3],)), StructArray((b = ["a1","a2","a3"],)))
    @test eltype(collect_columns(nt_pairs)) == Pair{A64, BStr}

    nt_pairs = (i == 1 ? (a="1",) => (b="a$i",) : (a=i,) => (b="a$i",) for i in 1:3)
    @test collect_columns(nt_pairs) == StructArray{Pair{AAny, BStr}}(
        StructArray((a = ["1",2,3],)), StructArray((b = ["a1","a2","a3"],)))
    @test eltype(collect_columns(nt_pairs)) == Pair{AAny, BStr}

    # empty
    nt_pairs = ((a=i,) => (b="a$i",) for i in 0:-1)
    @test collect_columns(nt_pairs) == StructArray{Pair{A64, BStr}}(
        StructArray((a = Int[],)), StructArray((b = String[],)))
    @test eltype(collect_columns(nt_pairs)) == Pair{AInt, BStr}

    nt_pairs = Iterators.filter(t -> t.first.a == 4, ((a=i,) => (b="a$i",) for i in 1:3))
    @test collect_columns(nt_pairs) == StructArray{Pair{A64, BStr}}(
        StructArray((a = Int[],)), StructArray((b = String[],)))
    @test eltype(collect_columns(nt_pairs)) == Pair{AInt, BStr}

    # a `missing` on the value side
    t = collect_columns((b = 1,) => (a = i,) for i in (2, missing, 3))
    B64 = NamedTuple{(:b,), Tuple{Int64}}
    AMissing = NamedTuple{(:a,), Tuple{Union{Missing, Int64}}}
    s = StructArray{Pair{B64, AMissing}}(
        StructArray(b = [1,1,1]), StructArray(a = [2, missing, 3]))
    @test s[1] == t[1]
    @test ismissing(t[2].second.a)
    @test s[3] == t[3]
end
290+
291+
@testset "collect2D" begin
    # a generator over a matrix keeps the matrix shape after collection
    grid = [(a=i, b=j) for i in 1:3, j in 1:4]
    collected = StructArrays.collect_columns(l for l in grid)
    @test size(collected) == (3, 4)
    @test collected.a == [i for i in 1:3, j in 1:4]
    @test collected.b == [j for i in 1:3, j in 1:4]
end

0 commit comments

Comments
 (0)