Skip to content

Commit bd70bc1

Browse files
author
Pietro Vertechi
authored
Merge pull request #35 from piever/pv/eltype
rename to collect_structarray
2 parents 1e5c581 + a56e6d0 commit bd70bc1

File tree

4 files changed

+75
-73
lines changed

4 files changed

+75
-73
lines changed

NEWS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@
88

99
### New features
1010

11-
- Added `collect_fieldarrays` function to collect an iterable of structs into a `StructArray` without having to allocate an array of structs
11+
- Added `collect_structarray` function to collect an iterable of structs into a `StructArray` without having to allocate an array of structs
1212
- `StructArray{T}(undef, dims)` and `StructArray(v::AbstractArray)` now support an `unwrap` keyword argument that selects the types on which to recursively unnest an array of structs into a struct of arrays
1313

src/StructArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module StructArrays
22

33
import Requires
44
export StructArray, StructVector
5+
export collect_structarray
56

67
include("interface.jl")
78
include("structarray.jl")

src/collect.jl

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,29 +24,30 @@ ArrayInitializer(unwrap = t->false) = ArrayInitializer(unwrap, arrayof)
2424
_reshape(v, itr, ::Base.HasShape) = reshape(v, axes(itr))
2525
_reshape(v, itr, ::Union{Base.HasLength, Base.SizeUnknown}) = v
2626

27-
function collect_fieldarrays(itr; initializer = default_initializer)
27+
function collect_structarray(itr; initializer = default_initializer)
2828
sz = Base.IteratorSize(itr)
29-
v = collect_fieldarrays(itr, sz, initializer = initializer)
29+
v = collect_structarray(itr, sz, initializer = initializer)
3030
_reshape(v, itr, sz)
3131
end
3232

33-
function collect_empty_fieldarrays(itr::T; initializer = default_initializer) where {T}
33+
function collect_empty_structarray(itr::T; initializer = default_initializer) where {T}
3434
S = Core.Compiler.return_type(first, Tuple{T})
3535
initializer(S, (0,))
3636
end
3737

38-
function collect_fieldarrays(@nospecialize(itr), ::Union{Base.HasShape, Base.HasLength};
39-
initializer = default_initializer)
38+
function collect_structarray(itr, ::Union{Base.HasShape, Base.HasLength};
39+
initializer = default_initializer)
4040

4141
st = iterate(itr)
42-
st === nothing && return collect_empty_fieldarrays(itr, initializer = initializer)
42+
st === nothing && return collect_empty_structarray(itr, initializer = initializer)
4343
el, i = st
44-
dest = initializer(typeof(el), (length(itr),))
44+
S = typeof(el)
45+
dest = initializer(S, (length(itr),))
4546
dest[1] = el
46-
collect_to_fieldarrays!(dest, itr, 2, i)
47+
collect_to_structarray!(dest, itr, 2, i)
4748
end
4849

49-
function collect_to_fieldarrays!(dest::AbstractArray{T}, itr, offs, st) where {T}
50+
function collect_to_structarray!(dest::AbstractArray{T}, itr, offs, st) where {T}
5051
# Collect into the dest array, checking the type of each result. If a result does not
5152
# match, widen the result type and re-dispatch.
5253
i = offs
@@ -58,24 +59,24 @@ function collect_to_fieldarrays!(dest::AbstractArray{T}, itr, offs, st) where {T
5859
@inbounds dest[i] = el
5960
i += 1
6061
else
61-
new = widenfieldarrays(dest, i, el)
62+
new = widenstructarray(dest, i, el)
6263
@inbounds new[i] = el
63-
return collect_to_fieldarrays!(new, itr, i+1, st)
64+
return collect_to_structarray!(new, itr, i+1, st)
6465
end
6566
end
6667
return dest
6768
end
6869

69-
function collect_fieldarrays(itr, ::Base.SizeUnknown; initializer = default_initializer)
70+
function collect_structarray(itr, ::Base.SizeUnknown; initializer = default_initializer)
7071
elem = iterate(itr)
71-
elem === nothing && return collect_empty_fieldarrays(itr; initializer = initializer)
72+
elem === nothing && return collect_empty_structarray(itr; initializer = initializer)
7273
el, st = elem
7374
dest = initializer(typeof(el), (1,))
7475
dest[1] = el
75-
grow_to_fieldarrays!(dest, itr, iterate(itr, st))
76+
grow_to_structarray!(dest, itr, iterate(itr, st))
7677
end
7778

78-
function grow_to_fieldarrays!(dest::AbstractArray{T}, itr, elem = iterate(itr)) where {T}
79+
function grow_to_structarray!(dest::AbstractArray{T}, itr, elem = iterate(itr)) where {T}
7980
# Grow the dest array, checking the type of each result. If a result does not
8081
# match, widen the result type and re-dispatch.
8182
i = length(dest)+1
@@ -86,9 +87,9 @@ function grow_to_fieldarrays!(dest::AbstractArray{T}, itr, elem = iterate(itr))
8687
elem = iterate(itr, st)
8788
i += 1
8889
else
89-
new = widenfieldarrays(dest, i, el)
90+
new = widenstructarray(dest, i, el)
9091
push!(new, el)
91-
return grow_to_fieldarrays!(new, itr, iterate(itr, st))
92+
return grow_to_structarray!(new, itr, iterate(itr, st))
9293
end
9394
end
9495
return dest
@@ -99,22 +100,22 @@ function to_structarray(::Type{T}, nt::C) where {T, C}
99100
StructArray{S}(nt)
100101
end
101102

102-
function widenfieldarrays(dest::StructArray{T}, i, el::S) where {T, S}
103+
function widenstructarray(dest::StructArray{T}, i, el::S) where {T, S}
103104
fs = fields(S)
104105
if fs === fields(T)
105-
new_cols = (widenfieldarrays(fieldarrays(dest)[ind], i, getfieldindex(el, f, ind)) for (ind, f) in enumerate(fs))
106+
new_cols = (widenstructarray(fieldarrays(dest)[ind], i, getfieldindex(el, f, ind)) for (ind, f) in enumerate(fs))
106107
nt = NamedTuple{fs}(Tuple(new_cols))
107108
v = to_structarray(T, nt)
108109
else
109110
widenarray(dest, i, el)
110111
end
111112
end
112113

113-
widenfieldarrays(dest::AbstractArray, i, el) = widenarray(dest, i, el)
114+
widenstructarray(dest::AbstractArray, i, el) = widenarray(dest, i, el)
114115

115116
function widenarray(dest::AbstractArray{T}, i, el::S) where {S, T}
116117
S <: T && return dest
117-
new = similar(dest, promote_type(S, T), length(dest))
118+
new = similar(dest, Base.promote_typejoin(S, T), length(dest))
118119
copyto!(new, 1, dest, 1, i-1)
119120
new
120121
end

test/runtests.jl

Lines changed: 51 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -237,142 +237,142 @@ StructArrays.SkipConstructor(::Type{<:S}) = true
237237
end
238238

239239
const initializer = StructArrays.ArrayInitializer(t -> t <: Union{Tuple, NamedTuple, Pair})
240-
collect_fieldarrays_rec(t) = StructArrays.collect_fieldarrays(t, initializer = initializer)
240+
collect_structarray_rec(t) = collect_structarray(t, initializer = initializer)
241241

242242
@testset "collectnamedtuples" begin
243243
v = [(a = 1, b = 2), (a = 1, b = 3)]
244-
collect_fieldarrays_rec(v) == StructArray((a = Int[1, 1], b = Int[2, 3]))
244+
collect_structarray_rec(v) == StructArray((a = Int[1, 1], b = Int[2, 3]))
245245

246246
# test inferrability with constant eltype
247247
itr = [(a = 1, b = 2), (a = 1, b = 2), (a = 1, b = 12)]
248248
el, st = iterate(itr)
249249
dest = initializer(typeof(el), (3,))
250250
dest[1] = el
251-
@inferred StructArrays.collect_to_fieldarrays!(dest, itr, 2, st)
251+
@inferred StructArrays.collect_to_structarray!(dest, itr, 2, st)
252252

253253
v = [(a = 1, b = 2), (a = 1.2, b = 3)]
254-
@test collect_fieldarrays_rec(v) == StructArray((a = [1, 1.2], b = Int[2, 3]))
255-
@test typeof(collect_fieldarrays_rec(v)) == typeof(StructArray((a = [1, 1.2], b = Int[2, 3])))
254+
@test collect_structarray_rec(v) == StructArray((a = [1, 1.2], b = Int[2, 3]))
255+
@test typeof(collect_structarray_rec(v)) == typeof(StructArray((a = Real[1, 1.2], b = Int[2, 3])))
256256

257257
s = StructArray(a = [1, 2], b = [3, 4])
258-
@test StructArrays.collect_fieldarrays(LazyRow(s, i) for i in eachindex(s)) == s
259-
@test collect_fieldarrays_rec(LazyRow(s, i) for i in eachindex(s)) == s
258+
@test collect_structarray(LazyRow(s, i) for i in eachindex(s)) == s
259+
@test collect_structarray_rec(LazyRow(s, i) for i in eachindex(s)) == s
260260

261261
v = [(a = 1, b = 2), (a = 1.2, b = "3")]
262-
@test collect_fieldarrays_rec(v) == StructArray((a = [1, 1.2], b = Any[2, "3"]))
263-
@test typeof(collect_fieldarrays_rec(v)) == typeof(StructArray((a = [1, 1.2], b = Any[2, "3"])))
262+
@test collect_structarray_rec(v) == StructArray((a = [1, 1.2], b = Any[2, "3"]))
263+
@test typeof(collect_structarray_rec(v)) == typeof(StructArray((a = Real[1, 1.2], b = Any[2, "3"])))
264264

265265
v = [(a = 1, b = 2), (a = 1.2, b = 2), (a = 1, b = "3")]
266-
@test collect_fieldarrays_rec(v) == StructArray((a = [1, 1.2, 1], b = Any[2, 2, "3"]))
267-
@test typeof(collect_fieldarrays_rec(v)) == typeof(StructArray((a = [1, 1.2, 1], b = Any[2, 2, "3"])))
266+
@test collect_structarray_rec(v) == StructArray((a = Real[1, 1.2, 1], b = Any[2, 2, "3"]))
267+
@test typeof(collect_structarray_rec(v)) == typeof(StructArray((a = Real[1, 1.2, 1], b = Any[2, 2, "3"])))
268268

269269
# length unknown
270270
itr = Iterators.filter(isodd, 1:8)
271271
tuple_itr = ((a = i+1, b = i-1) for i in itr)
272-
@test collect_fieldarrays_rec(tuple_itr) == StructArray((a = [2, 4, 6, 8], b = [0, 2, 4, 6]))
272+
@test collect_structarray_rec(tuple_itr) == StructArray((a = [2, 4, 6, 8], b = [0, 2, 4, 6]))
273273
tuple_itr_real = (i == 1 ? (a = 1.2, b =i-1) : (a = i+1, b = i-1) for i in itr)
274-
@test collect_fieldarrays_rec(tuple_itr_real) == StructArray((a = Real[1.2, 4, 6, 8], b = [0, 2, 4, 6]))
274+
@test collect_structarray_rec(tuple_itr_real) == StructArray((a = Real[1.2, 4, 6, 8], b = [0, 2, 4, 6]))
275275

276276
# empty
277277
itr = Iterators.filter(t -> t > 10, 1:8)
278278
tuple_itr = ((a = i+1, b = i-1) for i in itr)
279-
@test collect_fieldarrays_rec(tuple_itr) == StructArray((a = Int[], b = Int[]))
279+
@test collect_structarray_rec(tuple_itr) == StructArray((a = Int[], b = Int[]))
280280

281281
itr = (i for i in 0:-1)
282282
tuple_itr = ((a = i+1, b = i-1) for i in itr)
283-
@test collect_fieldarrays_rec(tuple_itr) == StructArray((a = Int[], b = Int[]))
283+
@test collect_structarray_rec(tuple_itr) == StructArray((a = Int[], b = Int[]))
284284
end
285285

286286
@testset "collecttuples" begin
287287
v = [(1, 2), (1, 3)]
288-
@test collect_fieldarrays_rec(v) == StructArray((Int[1, 1], Int[2, 3]))
289-
@inferred collect_fieldarrays_rec(v)
288+
@test collect_structarray_rec(v) == StructArray((Int[1, 1], Int[2, 3]))
289+
@inferred collect_structarray_rec(v)
290290

291-
@test StructArrays.collect_fieldarrays(v) == StructArray((Int[1, 1], Int[2, 3]))
292-
@inferred StructArrays.collect_fieldarrays(v)
291+
@test collect_structarray(v) == StructArray((Int[1, 1], Int[2, 3]))
292+
@inferred collect_structarray(v)
293293

294294
v = [(1, 2), (1.2, 3)]
295-
@test collect_fieldarrays_rec(v) == StructArray(([1, 1.2], Int[2, 3]))
295+
@test collect_structarray_rec(v) == StructArray((Real[1, 1.2], Int[2, 3]))
296296

297297
v = [(1, 2), (1.2, "3")]
298-
@test collect_fieldarrays_rec(v) == StructArray(([1, 1.2], Any[2, "3"]))
299-
@test typeof(collect_fieldarrays_rec(v)) == typeof(StructArray(([1, 1.2], Any[2, "3"])))
298+
@test collect_structarray_rec(v) == StructArray((Real[1, 1.2], Any[2, "3"]))
299+
@test typeof(collect_structarray_rec(v)) == typeof(StructArray((Real[1, 1.2], Any[2, "3"])))
300300

301301
v = [(1, 2), (1.2, 2), (1, "3")]
302-
@test collect_fieldarrays_rec(v) == StructArray(([1, 1.2, 1], Any[2, 2, "3"]))
302+
@test collect_structarray_rec(v) == StructArray((Real[1, 1.2, 1], Any[2, 2, "3"]))
303303
# length unknown
304304
itr = Iterators.filter(isodd, 1:8)
305305
tuple_itr = ((i+1, i-1) for i in itr)
306-
@test collect_fieldarrays_rec(tuple_itr) == StructArray(([2, 4, 6, 8], [0, 2, 4, 6]))
306+
@test collect_structarray_rec(tuple_itr) == StructArray(([2, 4, 6, 8], [0, 2, 4, 6]))
307307
tuple_itr_real = (i == 1 ? (1.2, i-1) : (i+1, i-1) for i in itr)
308-
@test collect_fieldarrays_rec(tuple_itr_real) == StructArray(([1.2, 4, 6, 8], [0, 2, 4, 6]))
309-
@test typeof(collect_fieldarrays_rec(tuple_itr_real)) == typeof(StructArray(([1.2, 4, 6, 8], [0, 2, 4, 6])))
308+
@test collect_structarray_rec(tuple_itr_real) == StructArray((Real[1.2, 4, 6, 8], [0, 2, 4, 6]))
309+
@test typeof(collect_structarray_rec(tuple_itr_real)) == typeof(StructArray((Real[1.2, 4, 6, 8], [0, 2, 4, 6])))
310310

311311
# empty
312312
itr = Iterators.filter(t -> t > 10, 1:8)
313313
tuple_itr = ((i+1, i-1) for i in itr)
314-
@test collect_fieldarrays_rec(tuple_itr) == StructArray((Int[], Int[]))
314+
@test collect_structarray_rec(tuple_itr) == StructArray((Int[], Int[]))
315315

316316
itr = (i for i in 0:-1)
317317
tuple_itr = ((i+1, i-1) for i in itr)
318-
@test collect_fieldarrays_rec(tuple_itr) == StructArray((Int[], Int[]))
318+
@test collect_structarray_rec(tuple_itr) == StructArray((Int[], Int[]))
319319
end
320320

321321
@testset "collectscalars" begin
322322
v = (i for i in 1:3)
323-
@test collect_fieldarrays_rec(v) == [1,2,3]
324-
@inferred collect_fieldarrays_rec(v)
323+
@test collect_structarray_rec(v) == [1,2,3]
324+
@inferred collect_structarray_rec(v)
325325

326326
v = (i == 1 ? 1.2 : i for i in 1:3)
327-
@test collect_fieldarrays_rec(v) == collect(v)
327+
@test collect_structarray_rec(v) == collect(v)
328328

329329
itr = Iterators.filter(isodd, 1:100)
330-
@test collect_fieldarrays_rec(itr) == collect(itr)
330+
@test collect_structarray_rec(itr) == collect(itr)
331331
real_itr = (i == 1 ? 1.5 : i for i in itr)
332-
@test collect_fieldarrays_rec(real_itr) == collect(real_itr)
333-
@test eltype(collect_fieldarrays_rec(real_itr)) == Float64
332+
@test collect_structarray_rec(real_itr) == collect(real_itr)
333+
@test eltype(collect_structarray_rec(real_itr)) == Real
334334

335335
#empty
336336
itr = Iterators.filter(t -> t > 10, 1:8)
337337
tuple_itr = (exp(i) for i in itr)
338-
@test collect_fieldarrays_rec(tuple_itr) == Float64[]
338+
@test collect_structarray_rec(tuple_itr) == Float64[]
339339

340340
itr = (i for i in 0:-1)
341341
tuple_itr = (exp(i) for i in itr)
342-
@test collect_fieldarrays_rec(tuple_itr) == Float64[]
342+
@test collect_structarray_rec(tuple_itr) == Float64[]
343343

344-
t = collect_fieldarrays_rec((a = i,) for i in (1, missing, 3))
344+
t = collect_structarray_rec((a = i,) for i in (1, missing, 3))
345345
@test StructArrays.fieldarrays(t)[1] isa Array{Union{Int, Missing}}
346346
@test isequal(StructArrays.fieldarrays(t)[1], [1, missing, 3])
347347
end
348348

349349
@testset "collectpairs" begin
350350
v = (i=>i+1 for i in 1:3)
351-
@test collect_fieldarrays_rec(v) == StructArray{Pair{Int, Int}}([1,2,3], [2,3,4])
352-
@test eltype(collect_fieldarrays_rec(v)) == Pair{Int, Int}
351+
@test collect_structarray_rec(v) == StructArray{Pair{Int, Int}}([1,2,3], [2,3,4])
352+
@test eltype(collect_structarray_rec(v)) == Pair{Int, Int}
353353

354354
v = (i == 1 ? (1.2 => i+1) : (i => i+1) for i in 1:3)
355-
@test collect_fieldarrays_rec(v) == StructArray{Pair{Float64, Int}}([1.2,2,3], [2,3,4])
356-
@test eltype(collect_fieldarrays_rec(v)) == Pair{Float64, Int}
355+
@test collect_structarray_rec(v) == StructArray{Pair{Real, Int}}([1.2,2,3], [2,3,4])
356+
@test eltype(collect_structarray_rec(v)) == Pair{Real, Int}
357357

358358
v = ((a=i,) => (b="a$i",) for i in 1:3)
359-
@test collect_fieldarrays_rec(v) == StructArray{Pair{NamedTuple{(:a,),Tuple{Int64}},NamedTuple{(:b,),Tuple{String}}}}(StructArray((a = [1,2,3],)), StructArray((b = ["a1","a2","a3"],)))
360-
@test eltype(collect_fieldarrays_rec(v)) == Pair{NamedTuple{(:a,), Tuple{Int64}}, NamedTuple{(:b,), Tuple{String}}}
359+
@test collect_structarray_rec(v) == StructArray{Pair{NamedTuple{(:a,),Tuple{Int64}},NamedTuple{(:b,),Tuple{String}}}}(StructArray((a = [1,2,3],)), StructArray((b = ["a1","a2","a3"],)))
360+
@test eltype(collect_structarray_rec(v)) == Pair{NamedTuple{(:a,), Tuple{Int64}}, NamedTuple{(:b,), Tuple{String}}}
361361

362362
v = (i == 1 ? (a="1",) => (b="a$i",) : (a=i,) => (b="a$i",) for i in 1:3)
363-
@test collect_fieldarrays_rec(v) == StructArray{Pair{NamedTuple{(:a,),Tuple{Any}},NamedTuple{(:b,),Tuple{String}}}}(StructArray((a = ["1",2,3],)), StructArray((b = ["a1","a2","a3"],)))
364-
@test eltype(collect_fieldarrays_rec(v)) == Pair{NamedTuple{(:a,), Tuple{Any}}, NamedTuple{(:b,), Tuple{String}}}
363+
@test collect_structarray_rec(v) == StructArray{Pair{NamedTuple{(:a,),Tuple{Any}},NamedTuple{(:b,),Tuple{String}}}}(StructArray((a = ["1",2,3],)), StructArray((b = ["a1","a2","a3"],)))
364+
@test eltype(collect_structarray_rec(v)) == Pair{NamedTuple{(:a,), Tuple{Any}}, NamedTuple{(:b,), Tuple{String}}}
365365

366366
# empty
367367
v = ((a=i,) => (b="a$i",) for i in 0:-1)
368-
@test collect_fieldarrays_rec(v) == StructArray{Pair{NamedTuple{(:a,),Tuple{Int64}},NamedTuple{(:b,),Tuple{String}}}}(StructArray((a = Int[],)), StructArray((b = String[],)))
369-
@test eltype(collect_fieldarrays_rec(v)) == Pair{NamedTuple{(:a,), Tuple{Int}}, NamedTuple{(:b,), Tuple{String}}}
368+
@test collect_structarray_rec(v) == StructArray{Pair{NamedTuple{(:a,),Tuple{Int64}},NamedTuple{(:b,),Tuple{String}}}}(StructArray((a = Int[],)), StructArray((b = String[],)))
369+
@test eltype(collect_structarray_rec(v)) == Pair{NamedTuple{(:a,), Tuple{Int}}, NamedTuple{(:b,), Tuple{String}}}
370370

371371
v = Iterators.filter(t -> t.first.a == 4, ((a=i,) => (b="a$i",) for i in 1:3))
372-
@test collect_fieldarrays_rec(v) == StructArray{Pair{NamedTuple{(:a,),Tuple{Int64}},NamedTuple{(:b,),Tuple{String}}}}(StructArray((a = Int[],)), StructArray((b = String[],)))
373-
@test eltype(collect_fieldarrays_rec(v)) == Pair{NamedTuple{(:a,), Tuple{Int}}, NamedTuple{(:b,), Tuple{String}}}
372+
@test collect_structarray_rec(v) == StructArray{Pair{NamedTuple{(:a,),Tuple{Int64}},NamedTuple{(:b,),Tuple{String}}}}(StructArray((a = Int[],)), StructArray((b = String[],)))
373+
@test eltype(collect_structarray_rec(v)) == Pair{NamedTuple{(:a,), Tuple{Int}}, NamedTuple{(:b,), Tuple{String}}}
374374

375-
t = collect_fieldarrays_rec((b = 1,) => (a = i,) for i in (2, missing, 3))
375+
t = collect_structarray_rec((b = 1,) => (a = i,) for i in (2, missing, 3))
376376
s = StructArray{Pair{NamedTuple{(:b,),Tuple{Int64}},NamedTuple{(:a,),Tuple{Union{Missing, Int64}}}}}(StructArray(b = [1,1,1]), StructArray(a = [2, missing, 3]))
377377
@test s[1] == t[1]
378378
@test ismissing(t[2].second.a)
@@ -381,7 +381,7 @@ end
381381

382382
@testset "collect2D" begin
383383
s = (l for l in [(a=i, b=j) for i in 1:3, j in 1:4])
384-
v = StructArrays.collect_fieldarrays(s)
384+
v = collect_structarray(s)
385385
@test size(v) == (3, 4)
386386
@test v.a == [i for i in 1:3, j in 1:4]
387387
@test v.b == [j for i in 1:3, j in 1:4]

0 commit comments

Comments
 (0)