Skip to content

Commit 3359521

Browse files
author
Pietro Vertechi
authored
Cleanup and remove generated functions as much as possible (JuliaArrays#23)
1 parent 392e5ba commit 3359521

File tree

6 files changed

+88
-70
lines changed

6 files changed

+88
-70
lines changed

src/StructArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module StructArrays
33
import Requires
44
export StructArray
55

6+
include("interface.jl")
67
include("structarray.jl")
78
include("utils.jl")
89
include("collect.jl")

src/interface.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
@generated function staticschema(::Type{T}) where {T}
2+
name_tuple = Expr(:tuple, [QuoteNode(f) for f in fieldnames(T)]...)
3+
type_tuple = Expr(:curly, :Tuple, [Expr(:call, :fieldtype, :T, i) for i in 1:fieldcount(T)]...)
4+
Expr(:curly, :NamedTuple, name_tuple, type_tuple)
5+
end
6+
7+
@generated function staticschema(::Type{T}) where {T<:Tuple}
8+
name_tuple = Expr(:tuple, [QuoteNode(Symbol("x$f")) for f in fieldnames(T)]...)
9+
type_tuple = Expr(:curly, :Tuple, [Expr(:call, :fieldtype, :T, i) for i in 1:fieldcount(T)]...)
10+
Expr(:curly, :NamedTuple, name_tuple, type_tuple)
11+
end
12+
13+
staticschema(::Type{T}) where {T<:NamedTuple} = T
14+
15+
getnames(::Type{NamedTuple{names, types}}) where {names, types} = names
16+
gettypes(::Type{NamedTuple{names, types}}) where {names, types} = types
17+
18+
function fields(::Type{T}) where {T}
19+
getnames(staticschema(T))
20+
end

src/structarray.jl

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -28,29 +28,25 @@ StructArray(; kwargs...) = StructArray(values(kwargs))
2828
StructArray{T}(args...) where {T} = StructArray{T}(NamedTuple{fields(T)}(args))
2929

3030
_undef_array(::Type{T}, sz; unwrap = t -> false) where {T} = unwrap(T) ? StructArray{T}(undef, sz; unwrap = unwrap) : Array{T}(undef, sz)
31-
function _similar(v::S, ::Type{Z}; unwrap = t -> false) where {S <: AbstractArray{T, N}, Z} where {T, N}
32-
unwrap(Z) ? StructArray{Z}(map(t -> _similar(v, fieldtype(Z, t); unwrap = unwrap), fields(Z))) : similar(v, Z)
33-
end
3431

32+
_similar(v::AbstractArray, ::Type{Z}; unwrap = t -> false) where {Z} =
33+
unwrap(Z) ? buildfromschema(typ -> _similar(v, typ; unwrap = unwrap), Z) : similar(v, Z)
34+
35+
function StructArray{T}(::Base.UndefInitializer, sz::Dims; unwrap = t -> false) where {T}
36+
buildfromschema(typ -> _undef_array(typ, sz; unwrap = unwrap), T)
37+
end
3538
StructArray{T}(u::Base.UndefInitializer, d::Integer...; unwrap = t -> false) where {T} = StructArray{T}(u, convert(Dims, d); unwrap = unwrap)
36-
@generated function StructArray{T}(::Base.UndefInitializer, sz::Dims; unwrap = t -> false) where {T}
37-
ex = Expr(:tuple, [:(_undef_array($(fieldtype(T, i)), sz; unwrap = unwrap)) for i in 1:fieldcount(T)]...)
38-
return quote
39-
StructArray{T}(NamedTuple{fields(T)}($ex))
40-
end
39+
40+
function similar_structarray(v::AbstractArray, ::Type{Z}; unwrap = t -> false) where {Z}
41+
buildfromschema(typ -> _similar(v, typ; unwrap = unwrap), Z)
4142
end
4243

43-
@generated function StructArray(v::AbstractArray{T, N}; unwrap = t -> false) where {T, N}
44-
syms = [gensym() for i in 1:fieldcount(T)]
45-
init = Expr(:block, [:($(syms[i]) = _similar(v, $(fieldtype(T, i)); unwrap = unwrap)) for i in 1:fieldcount(T)]...)
46-
push = Expr(:block, [:($(syms[i])[j] = getfield(f, $i)) for i in 1:fieldcount(T)]...)
47-
quote
48-
$init
49-
for (j, f) in enumerate(v)
50-
@inbounds $push
51-
end
52-
return StructArray{T}($(syms...))
44+
function StructArray(v::AbstractArray{T}; unwrap = t -> false) where {T}
45+
s = similar_structarray(v, T; unwrap = unwrap)
46+
for i in eachindex(v)
47+
@inbounds s[i] = v[i]
5348
end
49+
s
5450
end
5551
StructArray(s::StructArray) = copy(s)
5652

@@ -63,7 +59,15 @@ Base.propertynames(s::StructArray) = fieldnames(typeof(columns(s)))
6359

6460
Base.size(s::StructArray) = size(columns(s)[1])
6561

66-
Base.@propagate_inbounds Base.getindex(s::StructArray, I::Int...) = get_ith(s, I...)
62+
@generated function Base.getindex(x::StructArray{T, N, NamedTuple{names, types}}, I::Int...) where {T, N, names, types}
63+
args = [:(getfield(cols, $i)[I...]) for i in 1:length(names)]
64+
return quote
65+
cols = columns(x)
66+
@boundscheck checkbounds(x, I...)
67+
@inbounds $(Expr(:call, :createinstance, :T, args...))
68+
end
69+
end
70+
6771
function Base.getindex(s::StructArray{T, N, C}, I::Union{Int, AbstractArray, Colon}...) where {T, N, C}
6872
StructArray{T}(map(v -> getindex(v, I...), columns(s)))
6973
end
@@ -72,29 +76,21 @@ function Base.view(s::StructArray{T, N, C}, I...) where {T, N, C}
7276
StructArray{T}(map(v -> view(v, I...), columns(s)))
7377
end
7478

75-
Base.@propagate_inbounds Base.setindex!(s::StructArray, val, I::Int...) = set_ith!(s, val, I...)
76-
77-
fields(::Type{<:NamedTuple{K}}) where {K} = K
78-
@generated function fields(t::Type{T}) where {T}
79-
return :($(Expr(:tuple, [QuoteNode(f) for f in fieldnames(T)]...)))
80-
end
81-
@generated function fields(t::Type{T}) where {T<:Tuple}
82-
return :($(Expr(:tuple, [QuoteNode(Symbol("x$f")) for f in fieldnames(T)]...)))
79+
function Base.setindex!(s::StructArray, vals, I::Int...)
80+
@boundscheck checkbounds(s, I...)
81+
@inbounds foreachcolumn((col, val) -> (col[I...] = val), s, vals)
82+
s
8383
end
8484

8585
@inline getfieldindex(v::Tuple, field::Symbol, index::Integer) = getfield(v, index)
8686
@inline getfieldindex(v, field::Symbol, index::Integer) = getproperty(v, field)
8787

88-
@generated function Base.push!(s::StructArray{T, 1}, vals) where {T}
89-
exprs = foreach_expr((args...) -> Expr(:call, :push!, args...), T, :s, :vals)
90-
push!(exprs, :s)
91-
Expr(:block, exprs...)
88+
function Base.push!(s::StructArray, vals)
89+
foreachcolumn(push!, s, vals)
9290
end
9391

94-
@generated function Base.append!(s::StructArray{T, 1}, vals) where {T}
95-
exprs = foreach_expr((args...) -> Expr(:call, :append!, args...), T, :s, :vals)
96-
push!(exprs, :s)
97-
Expr(:block, exprs...)
92+
function Base.append!(s::StructArray, vals)
93+
foreachcolumn(append!, s, vals)
9894
end
9995

10096
function Base.cat(args::StructArray...; dims)
@@ -125,4 +121,3 @@ Base.copy(s::StructArray{T,N,C}) where {T,N,C} = StructArray{T,N,C}(C(copy(x) fo
125121
function Base.reshape(s::StructArray{T}, d::Dims) where {T}
126122
StructArray{T}(map(x -> reshape(x, d), columns(s)))
127123
end
128-

src/tables.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ Tables.columnaccess(::Type{<:StructArray}) = true
55
Tables.rows(s::StructArray) = s
66
Tables.columns(s::StructArray) = columns(s)
77

8-
@generated function Tables.schema(s::StructArray{T}) where {T}
9-
names = fieldnames(T)
10-
types = map(sym -> fieldtype(T, sym), names)
11-
:(Tables.Schema($names, $types))
8+
function Tables.schema(s::StructArray{T}) where {T}
9+
NT = staticschema(T)
10+
names = getnames(NT)
11+
types = gettypes(NT).parameters
12+
Tables.Schema(names, types)
1213
end

src/utils.jl

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,39 @@
11
import Base: tuple_type_cons, tuple_type_head, tuple_type_tail, tail
22

3-
eltypes(::Type{Tuple{}}) = Tuple{}
4-
eltypes(::Type{T}) where {T<:Tuple} =
5-
tuple_type_cons(eltype(tuple_type_head(T)), eltypes(tuple_type_tail(T)))
6-
eltypes(::Type{NamedTuple{K, V}}) where {K, V} = NamedTuple{K, eltypes(V)}
3+
eltypes(::Type{T}) where {T} = map_types(eltype, T)
74

8-
Base.@pure SkipConstructor(::Type) = false
5+
map_types(f, ::Type{Tuple{}}) = Tuple{}
6+
function map_types(f, ::Type{T}) where {T<:Tuple}
7+
tuple_type_cons(f(tuple_type_head(T)), map_types(f, tuple_type_tail(T)))
8+
end
9+
map_types(f, ::Type{NamedTuple{names, types}}) where {names, types} =
10+
NamedTuple{names, map_types(f, types)}
911

10-
function foreach_expr(f, T, args...)
11-
exprs = []
12-
for (ind, key) in enumerate(fields(T))
13-
new_args = (Expr(:call, :getfieldindex, arg, Expr(:quote, key), ind) for arg in args)
14-
push!(exprs, f(new_args...))
15-
end
16-
exprs
12+
map_params(f, ::Type{Tuple{}}) = ()
13+
function map_params(f, ::Type{T}) where {T<:Tuple}
14+
(f(tuple_type_head(T)), map_params(f, tuple_type_tail(T))...)
1715
end
16+
map_params(f, ::Type{NamedTuple{names, types}}) where {names, types} =
17+
NamedTuple{names}(map_params(f, types))
1818

19-
@generated function get_ith(s::StructArray{T}, I...) where {T}
20-
exprs = foreach_expr(field -> :($field[I...]), T, :s)
21-
return quote
22-
@boundscheck checkbounds(s, I...)
23-
@inbounds $(Expr(:call, :createinstance, :T, exprs...))
24-
end
19+
buildfromschema(initializer, ::Type{T}) where {T} = buildfromschema(initializer, T, staticschema(T))
20+
21+
function buildfromschema(initializer, ::Type{T}, ::Type{NT}) where {T, NT<:NamedTuple}
22+
nt = map_params(initializer, NT)
23+
StructArray{T}(nt)
2524
end
2625

27-
@generated function set_ith!(s::StructArray{T}, vals, I...) where {T}
28-
exprs = foreach_expr((field, val) -> :($field[I...] = $val), T, :s, :vals)
29-
push!(exprs, :s)
30-
return quote
31-
@boundscheck checkbounds(s, I...)
32-
@inbounds $(Expr(:block, exprs...))
26+
Base.@pure SkipConstructor(::Type) = false
27+
28+
@generated function foreachcolumn(f, x::StructArray{T, N, NamedTuple{names, types}}, xs...) where {T, N, names, types}
29+
exprs = Expr[]
30+
for (i, field) in enumerate(names)
31+
sym = QuoteNode(field)
32+
args = [Expr(:call, :getfieldindex, :(getfield(xs, $j)), sym, i) for j in 1:length(xs)]
33+
push!(exprs, Expr(:call, :f, Expr(:., :x, sym), args...))
3334
end
35+
push!(exprs, :(return nothing))
36+
Expr(:block, exprs...)
3437
end
3538

3639
function createinstance(::Type{T}, args...) where {T}
@@ -46,7 +49,7 @@ createinstance(::Type{T}, args...) where {T<:Union{Tuple, NamedTuple}} = T(args)
4649
Expr(:block, new_tup, construct)
4750
end
4851

49-
createtype(::Type{T}, ::Type{NamedTuple{names, types}}) where {T, names, types} = createtype(T, names, eltypes(types))
52+
createtype(::Type{T}, ::Type{NamedTuple{names, types}}) where {T, names, types} = createtype(T, names, eltypes(types))
5053

5154
createtype(::Type{T}, names, types) where {T} = T
5255
createtype(::Type{T}, names, types) where {T<:Tuple} = types

test/runtests.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,16 +102,14 @@ end
102102
end
103103

104104
f_infer() = StructArray{ComplexF64}(rand(2,2), rand(2,2))
105-
unwrap(::Type) = false
106-
unwrap(::Type{<:NamedTuple}) = true
107105

108-
g_infer() = StructArray([(a=(b=1,), c=2)], unwrap = unwrap)
106+
g_infer() = StructArray([(a=(b="1",), c=2)], unwrap = t -> t <: NamedTuple)
109107
tup_infer() = StructArray([(1, 2), (3, 4)])
110108

111109
@testset "inferrability" begin
112110
@inferred f_infer()
113111
@inferred g_infer()
114-
@test g_infer().a.b == [1]
112+
@test g_infer().a.b == ["1"]
115113
s = @inferred tup_infer()
116114
@test Tables.columns(s) == (x1 = [1, 3], x2 = [2, 4])
117115
@test s[1] == (1, 2)
@@ -150,7 +148,7 @@ StructArrays.SkipConstructor(::Type{<:S}) = true
150148
end
151149

152150
const initializer = StructArrays.ArrayInitializer(t -> t <: Union{Tuple, NamedTuple, Pair})
153-
collect_columns(t) = StructArrays.collect_columns(t, initializer = initializer)
151+
collect_columns(t) = StructArrays.collect_columns(t, initializer = initializer)
154152

155153
@testset "collectnamedtuples" begin
156154
v = [(a = 1, b = 2), (a = 1, b = 3)]

0 commit comments

Comments
 (0)