Skip to content

Commit cb3b187

Browse files
committed
Make StructArrayStyle track inputs dimension
fix #185
1 parent 8e67e4e commit cb3b187

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

src/structarray.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,15 @@ end
445445
# broadcast
446446
import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle
447447

448-
struct StructArrayStyle{Style} <: AbstractArrayStyle{Any} end
448+
struct StructArrayStyle{S,N} <: AbstractArrayStyle{N} end
449+
450+
# Here we define the dimension tracking behaviour of StructArrayStyle
451+
function StructArrayStyle{S,M}(::Val{N}) where {S,M,N}
452+
if S <: AbstractArrayStyle{M}
453+
return StructArrayStyle{typeof(S(Val(N))),N}()
454+
end
455+
return StructArrayStyle{S,N}()
456+
end
449457

450458
@inline combine_style_types(::Type{A}, args...) where A<:AbstractArray =
451459
combine_style_types(BroadcastStyle(A), args...)
@@ -455,9 +463,9 @@ combine_style_types(s::BroadcastStyle) = s
455463

456464
Base.@pure cst(::Type{SA}) where SA = combine_style_types(array_types(SA).parameters...)
457465

458-
BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA))}()
466+
BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA)),ndims(SA)}()
459467

460-
Base.similar(bc::Broadcasted{StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,N,ElType} =
468+
Base.similar(bc::Broadcasted{<:StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,ElType} =
461469
isstructtype(ElType) ? similar(StructArray{ElType}, axes(bc)) : similar(Array{ElType}, axes(bc))
462470

463471
# for aliasing analysis during broadcast

test/runtests.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -926,8 +926,24 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
926926
# used inside of broadcast but we also test it here explicitly
927927
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})
928928

929-
s = StructArray{ComplexF64}((MyArray(rand(2,2)), MyArray(rand(2,2))))
929+
s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2))))
930930
@test_throws MethodError s .+ s
931+
932+
# test for dimensionality track
933+
@test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
934+
@test Base.broadcasted(+, s, [1,2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
935+
@test Base.broadcasted(+, s, [1;;2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}}
936+
@test Base.broadcasted(+, [1;;;2], s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}}
937+
938+
a = StructArray([1;2+im])
939+
b = StructArray([1;;2+im])
940+
@test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b)
941+
942+
# issue #185
943+
A = StructArray(randn(ComplexF64, 3, 3))
944+
B = randn(ComplexF64, 3, 3)
945+
c = StructArray(randn(ComplexF64, 3))
946+
@test (A .= B .* c) === A
931947
end
932948

933949
@testset "staticarrays" begin

0 commit comments

Comments
 (0)