Skip to content

Commit 0c4adf6

Browse files
timholyKeno
andauthored
Make StructArrays broadcast aware (#136)
* Make StructArrays broadcast aware Fixes #89 * Limit circumstances in which broadcasting returns a StructArray While allowing broadcasting to return a StructArray, this limits it to cases where: - no other arrays in the broadcast operation, including those wrapped by the StructArray, have non-default BroadcastStyle - the eltype returned from the function is a struct type It should be straightforward to define precedence rules to handle other cases, e.g., StructArrays of CuArrays. * Add a test for custom-broadcasting internal arrays * Embrace the MethodError Co-authored-by: Keno Fischer <keno@alumni.harvard.edu>
1 parent 7b77672 commit 0c4adf6

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

src/structarray.jl

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,18 @@ index_type(::Type{NamedTuple{names, types}}) where {names, types} = index_type(t
2222
index_type(::Type{Tuple{}}) = Int
2323
function index_type(::Type{T}) where {T<:Tuple}
2424
S, U = tuple_type_head(T), tuple_type_tail(T)
25-
IndexStyle(S) isa IndexCartesian ? CartesianIndex{ndims(S)} : index_type(U)
25+
IndexStyle(S) isa IndexCartesian ? CartesianIndex{ndims(S)} : index_type(U)
2626
end
2727

2828
index_type(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = I
2929

30+
array_types(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = array_types(C)
31+
array_types(::Type{NamedTuple{names, types}}) where {names, types} = types
32+
array_types(::Type{TT}) where {TT<:Tuple} = TT
33+
3034
function StructArray{T}(c::C) where {T, C<:Tup}
3135
cols = strip_params(staticschema(T))(c)
32-
N = isempty(cols) ? 1 : ndims(cols[1])
36+
N = isempty(cols) ? 1 : ndims(cols[1])
3337
StructArray{T, N, typeof(cols)}(cols)
3438
end
3539

@@ -225,3 +229,21 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
225229
showfields(io, Tuple(fieldarrays(s)))
226230
toplevel && print(io, " with eltype ", T)
227231
end
232+
233+
# broadcast
234+
import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle
235+
236+
struct StructArrayStyle{Style} <: AbstractArrayStyle{Any} end
237+
238+
@inline combine_style_types(::Type{A}, args...) where A<:AbstractArray =
239+
combine_style_types(BroadcastStyle(A), args...)
240+
@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where A<:AbstractArray =
241+
combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...)
242+
combine_style_types(s::BroadcastStyle) = s
243+
244+
Base.@pure cst(::Type{SA}) where SA = combine_style_types(array_types(SA).parameters...)
245+
246+
BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA))}()
247+
248+
Base.similar(bc::Broadcasted{StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,N,ElType} =
249+
isstructtype(ElType) ? similar(StructArray{ElType}, axes(bc)) : similar(Array{ElType}, axes(bc))

test/runtests.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,3 +714,32 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
714714
@test t.b.c isa Array
715715
@test t.b.d isa Array
716716
end
717+
718+
struct MyArray{T,N} <: AbstractArray{T,N}
719+
A::Array{T,N}
720+
end
721+
MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz))
722+
Base.IndexStyle(::Type{<:MyArray}) = IndexLinear()
723+
Base.getindex(A::MyArray, i::Int) = A.A[i]
724+
Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val
725+
Base.size(A::MyArray) = Base.size(A.A)
726+
Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}()
727+
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType =
728+
MyArray{ElType}(undef, size(bc))
729+
730+
@testset "broadcast" begin
731+
s = StructArray{ComplexF64}((rand(2,2), rand(2,2)))
732+
@test isa(@inferred(s .+ s), StructArray)
733+
@test (s .+ s).re == 2*s.re
734+
@test (s .+ s).im == 2*s.im
735+
@test isa(@inferred(broadcast(t->1, s)), Array)
736+
@test all(x->x==1, broadcast(t->1, s))
737+
@test isa(@inferred(s .+ 1), StructArray)
738+
@test s .+ 1 == StructArray{ComplexF64}((s.re .+ 1, s.im))
739+
r = rand(2,2)
740+
@test isa(@inferred(s .+ r), StructArray)
741+
@test s .+ r == StructArray{ComplexF64}((s.re .+ r, s.im))
742+
743+
s = StructArray{ComplexF64}((MyArray(rand(2,2)), MyArray(rand(2,2))))
744+
@test_throws MethodError s .+ s
745+
end

0 commit comments

Comments
 (0)