Skip to content

Commit 3783f6d

Browse files
authored
implement undef init and various other fixes (JuliaArrays#7)
1 parent e6b46a7 commit 3783f6d

File tree

3 files changed

+34
-25
lines changed

3 files changed

+34
-25
lines changed

src/StructArrays.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
11
module StructArrays
22

3-
import Base:
4-
getindex, setindex!, size, push!, view, getproperty, append!, cat, vcat, hcat
5-
# linearindexing, push!, size, sort, sort!, permute!, issorted, sortperm,
6-
# summary, resize!, vcat, serialize, deserialize, append!, copy!, view
7-
83
export StructArray
94

105
const Tup = Union{Tuple, NamedTuple}

src/structarray.jl

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,34 +23,39 @@ StructArray{T}(c::C) where {T, C<:NamedTuple} =
2323
StructArray(c::C) where {C<:NamedTuple} = StructArray{C}(c)
2424

2525
StructArray{T}(args...) where {T} = StructArray{T}(NamedTuple{fields(T)}(args))
26+
@generated function StructArray{T}(::Base.UndefInitializer, d::Integer...) where {T}
27+
ex = Expr(:tuple, [:(Array{$(fieldtype(T, i))}(undef, sz)) for i in 1:fieldcount(T)]...)
28+
return quote
29+
sz = convert(Tuple{Vararg{Int}}, d)
30+
StructArray{T}(NamedTuple{fields(T)}($ex))
31+
end
32+
end
2633

2734
columns(s::StructArray) = getfield(s, :columns)
28-
getproperty(s::StructArray, key::Symbol) = getfield(columns(s), key)
29-
getproperty(s::StructArray, key::Int) = getfield(columns(s), key)
35+
Base.getproperty(s::StructArray, key::Symbol) = getfield(columns(s), key)
36+
Base.getproperty(s::StructArray, key::Int) = getfield(columns(s), key)
37+
Base.propertynames(s::StructArray) = fieldnames(typeof(columns(s)))
3038

31-
size(s::StructArray) = size(columns(s)[1])
39+
Base.size(s::StructArray) = size(columns(s)[1])
3240

33-
getindex(s::StructArray, I::Int...) = get_ith(s, I...)
34-
function getindex(s::StructArray{T, N, C}, I::Union{Int, AbstractArray, Colon}...) where {T, N, C}
41+
Base.getindex(s::StructArray, I::Int...) = get_ith(s, I...)
42+
function Base.getindex(s::StructArray{T, N, C}, I::Union{Int, AbstractArray, Colon}...) where {T, N, C}
3543
StructArray{T}(map(v -> getindex(v, I...), columns(s)))
3644
end
3745

38-
function view(s::StructArray{T, N, C}, I...) where {T, N, C}
46+
function Base.view(s::StructArray{T, N, C}, I...) where {T, N, C}
3947
StructArray{T}(map(v -> view(v, I...), columns(s)))
4048
end
4149

42-
setindex!(s::StructArray, val, I::Int...) = set_ith!(s, val, I...)
50+
Base.setindex!(s::StructArray, val, I::Int...) = set_ith!(s, val, I...)
4351

4452
fields(::Type{<:NamedTuple{K}}) where {K} = K
4553
fields(::Type{<:StructArray{T}}) where {T} = fields(T)
46-
47-
Base.propertynames(s::StructArray) = fieldnames(typeof(columns(s)))
48-
4954
@generated function fields(t::Type{T}) where {T}
5055
return :($(Expr(:tuple, [QuoteNode(f) for f in fieldnames(T)]...)))
5156
end
5257

53-
@generated function push!(s::StructArray{T, 1}, vals) where {T}
58+
@generated function Base.push!(s::StructArray{T, 1}, vals) where {T}
5459
args = []
5560
for key in fields(T)
5661
field = Expr(:., :s, Expr(:quote, key))
@@ -61,7 +66,7 @@ end
6166
Expr(:block, args...)
6267
end
6368

64-
@generated function append!(s::StructArray{T, 1}, vals) where {T}
69+
@generated function Base.append!(s::StructArray{T, 1}, vals) where {T}
6570
args = []
6671
for key in fields(T)
6772
field = Expr(:., :s, Expr(:quote, key))
@@ -72,8 +77,8 @@ end
7277
Expr(:block, args...)
7378
end
7479

75-
function cat(dims, args::StructArray...)
76-
f = key -> cat(dims, (getproperty(t, key) for t in args)...)
80+
function Base.cat(args::StructArray...; dims)
81+
f = key -> cat((getproperty(t, key) for t in args)...; dims=dims)
7782
T = mapreduce(eltype, promote_type, args)
7883
StructArray{T}(map(f, fields(eltype(args[1]))))
7984
end
@@ -87,7 +92,7 @@ end
8792

8893
for op in [:hcat, :vcat]
8994
@eval begin
90-
function $op(args::StructArray...)
95+
function Base.$op(args::StructArray...)
9196
f = key -> $op((getproperty(t, key) for t in args)...)
9297
T = mapreduce(eltype, promote_type, args)
9398
StructArray{T}(map(f, fields(eltype(args[1]))))

test/runtests.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ using Test
55
@testset "index" begin
66
a, b = [1 2; 3 4], [4 5; 6 7]
77
t = StructArray((a = a, b = b))
8-
@test t[2,2] == (a = 4, b = 7)
9-
@test t[2,1:2] == StructArray((a = [3, 4], b = [6, 7]))
10-
@test view(t, 2, 1:2) == StructArray((a = view(a, 2, 1:2), b = view(b, 2, 1:2)))
8+
@test (@inferred t[2,2]) == (a = 4, b = 7)
9+
@test (@inferred t[2,1:2]) == StructArray((a = [3, 4], b = [6, 7]))
10+
@test (@inferred view(t, 2, 1:2)) == StructArray((a = view(a, 2, 1:2), b = view(b, 2, 1:2)))
1111
end
1212

1313
@testset "complex" begin
@@ -18,6 +18,15 @@ end
1818
@test view(t, 2, 1:2) == StructArray{ComplexF64}(view(a, 2, 1:2), view(b, 2, 1:2))
1919
end
2020

21+
@testset "undef initializer" begin
22+
t = @inferred StructArray{ComplexF64}(undef, 5, 5)
23+
@test eltype(t) == ComplexF64
24+
@test size(t) == (5,5)
25+
c = 2 + im
26+
t[1,1] = c
27+
@test t[1,1] == c
28+
end
29+
2130
@testset "resize!" begin
2231
t = StructArray{Pair}([3, 5], ["a", "b"])
2332
resize!(t, 5)
@@ -35,9 +44,9 @@ end
3544
@test t == StructArray{Pair}([3, 5, 2, 3, 5, 2], ["a", "b", "c", "a", "b", "c"])
3645
t = StructArray{Pair}([3, 5], ["a", "b"])
3746
t2 = StructArray{Pair}([1, 6], ["a", "b"])
38-
@test cat(t, t2; dims=1) == StructArray{Pair}([3, 5, 1, 6], ["a", "b", "a", "b"]) == vcat(t, t2)
47+
@test cat(t, t2; dims=1)::StructArray == StructArray{Pair}([3, 5, 1, 6], ["a", "b", "a", "b"]) == vcat(t, t2)
3948
@test vcat(t, t2) isa StructArray
40-
@test cat(t, t2; dims=2) == StructArray{Pair}([3 1; 5 6], ["a" "a"; "b" "b"]) == hcat(t, t2)
49+
@test cat(t, t2; dims=2)::StructArray == StructArray{Pair}([3 1; 5 6], ["a" "a"; "b" "b"]) == hcat(t, t2)
4150
@test hcat(t, t2) isa StructArray
4251
end
4352

0 commit comments

Comments
 (0)