@@ -28,29 +28,25 @@ StructArray(; kwargs...) = StructArray(values(kwargs))
28
28
StructArray {T} (args... ) where {T} = StructArray {T} (NamedTuple {fields(T)} (args))
29
29
30
30
_undef_array (:: Type{T} , sz; unwrap = t -> false ) where {T} = unwrap (T) ? StructArray {T} (undef, sz; unwrap = unwrap) : Array {T} (undef, sz)
31
- function _similar (v:: S , :: Type{Z} ; unwrap = t -> false ) where {S <: AbstractArray{T, N} , Z} where {T, N}
32
- unwrap (Z) ? StructArray {Z} (map (t -> _similar (v, fieldtype (Z, t); unwrap = unwrap), fields (Z))) : similar (v, Z)
33
- end
34
31
32
+ _similar (v:: AbstractArray , :: Type{Z} ; unwrap = t -> false ) where {Z} =
33
+ unwrap (Z) ? buildfromschema (typ -> _similar (v, typ; unwrap = unwrap), Z) : similar (v, Z)
34
+
35
+ function StructArray {T} (:: Base.UndefInitializer , sz:: Dims ; unwrap = t -> false ) where {T}
36
+ buildfromschema (typ -> _undef_array (typ, sz; unwrap = unwrap), T)
37
+ end
35
38
StructArray {T} (u:: Base.UndefInitializer , d:: Integer... ; unwrap = t -> false ) where {T} = StructArray {T} (u, convert (Dims, d); unwrap = unwrap)
36
- @generated function StructArray {T} (:: Base.UndefInitializer , sz:: Dims ; unwrap = t -> false ) where {T}
37
- ex = Expr (:tuple , [:(_undef_array ($ (fieldtype (T, i)), sz; unwrap = unwrap)) for i in 1 : fieldcount (T)]. .. )
38
- return quote
39
- StructArray {T} (NamedTuple {fields(T)} ($ ex))
40
- end
39
+
40
+ function similar_structarray (v:: AbstractArray , :: Type{Z} ; unwrap = t -> false ) where {Z}
41
+ buildfromschema (typ -> _similar (v, typ; unwrap = unwrap), Z)
41
42
end
42
43
43
- @generated function StructArray (v:: AbstractArray{T, N} ; unwrap = t -> false ) where {T, N}
44
- syms = [gensym () for i in 1 : fieldcount (T)]
45
- init = Expr (:block , [:($ (syms[i]) = _similar (v, $ (fieldtype (T, i)); unwrap = unwrap)) for i in 1 : fieldcount (T)]. .. )
46
- push = Expr (:block , [:($ (syms[i])[j] = getfield (f, $ i)) for i in 1 : fieldcount (T)]. .. )
47
- quote
48
- $ init
49
- for (j, f) in enumerate (v)
50
- @inbounds $ push
51
- end
52
- return StructArray {T} ($ (syms... ))
44
+ function StructArray (v:: AbstractArray{T} ; unwrap = t -> false ) where {T}
45
+ s = similar_structarray (v, T; unwrap = unwrap)
46
+ for i in eachindex (v)
47
+ @inbounds s[i] = v[i]
53
48
end
49
+ s
54
50
end
55
51
StructArray (s:: StructArray ) = copy (s)
56
52
@@ -63,7 +59,15 @@ Base.propertynames(s::StructArray) = fieldnames(typeof(columns(s)))
63
59
64
60
Base. size (s:: StructArray ) = size (columns (s)[1 ])
65
61
66
- Base. @propagate_inbounds Base. getindex (s:: StructArray , I:: Int... ) = get_ith (s, I... )
62
+ @generated function Base. getindex (x:: StructArray{T, N, NamedTuple{names, types}} , I:: Int... ) where {T, N, names, types}
63
+ args = [:(getfield (cols, $ i)[I... ]) for i in 1 : length (names)]
64
+ return quote
65
+ cols = columns (x)
66
+ @boundscheck checkbounds (x, I... )
67
+ @inbounds $ (Expr (:call , :createinstance , :T , args... ))
68
+ end
69
+ end
70
+
67
71
function Base. getindex (s:: StructArray{T, N, C} , I:: Union{Int, AbstractArray, Colon} ...) where {T, N, C}
68
72
StructArray {T} (map (v -> getindex (v, I... ), columns (s)))
69
73
end
@@ -72,29 +76,21 @@ function Base.view(s::StructArray{T, N, C}, I...) where {T, N, C}
72
76
StructArray {T} (map (v -> view (v, I... ), columns (s)))
73
77
end
74
78
75
- Base. @propagate_inbounds Base. setindex! (s:: StructArray , val, I:: Int... ) = set_ith! (s, val, I... )
76
-
77
- fields (:: Type{<:NamedTuple{K}} ) where {K} = K
78
- @generated function fields (t:: Type{T} ) where {T}
79
- return :($ (Expr (:tuple , [QuoteNode (f) for f in fieldnames (T)]. .. )))
80
- end
81
- @generated function fields (t:: Type{T} ) where {T<: Tuple }
82
- return :($ (Expr (:tuple , [QuoteNode (Symbol (" x$f " )) for f in fieldnames (T)]. .. )))
79
+ function Base. setindex! (s:: StructArray , vals, I:: Int... )
80
+ @boundscheck checkbounds (s, I... )
81
+ @inbounds foreachcolumn ((col, val) -> (col[I... ] = val), s, vals)
82
+ s
83
83
end
84
84
85
85
@inline getfieldindex (v:: Tuple , field:: Symbol , index:: Integer ) = getfield (v, index)
86
86
@inline getfieldindex (v, field:: Symbol , index:: Integer ) = getproperty (v, field)
87
87
88
- @generated function Base. push! (s:: StructArray{T, 1} , vals) where {T}
89
- exprs = foreach_expr ((args... ) -> Expr (:call , :push! , args... ), T, :s , :vals )
90
- push! (exprs, :s )
91
- Expr (:block , exprs... )
88
+ function Base. push! (s:: StructArray , vals)
89
+ foreachcolumn (push!, s, vals)
92
90
end
93
91
94
- @generated function Base. append! (s:: StructArray{T, 1} , vals) where {T}
95
- exprs = foreach_expr ((args... ) -> Expr (:call , :append! , args... ), T, :s , :vals )
96
- push! (exprs, :s )
97
- Expr (:block , exprs... )
92
+ function Base. append! (s:: StructArray , vals)
93
+ foreachcolumn (append!, s, vals)
98
94
end
99
95
100
96
function Base. cat (args:: StructArray... ; dims)
@@ -125,4 +121,3 @@ Base.copy(s::StructArray{T,N,C}) where {T,N,C} = StructArray{T,N,C}(C(copy(x) fo
125
121
function Base. reshape (s:: StructArray{T} , d:: Dims ) where {T}
126
122
StructArray {T} (map (x -> reshape (x, d), columns (s)))
127
123
end
128
-
0 commit comments