Skip to content

Commit 3c4b7f2

Browse files
committed
Fix regression on StaticArray
Make sure `StructArrayStyle{<:StaticArrayStyle}` lose to `DefaultArrayStyle`
1 parent c7a4dc8 commit 3c4b7f2

File tree

3 files changed

+35
-25
lines changed

3 files changed

+35
-25
lines changed

src/staticarrays_support.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import StaticArrays: StaticArray, FieldArray, tuple_prod
1+
import StaticArrays: StaticArray, FieldArray, tuple_prod, StaticArrayStyle
22

33
"""
44
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
@@ -26,4 +26,9 @@ StructArrays.component(s::StaticArray, i) = getindex(s, i)
2626
invoke(StructArrays.staticschema, Tuple{Type{<:Any}}, T)
2727
end
2828
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
29-
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)
29+
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)
30+
31+
function Base.copy(bc::Broadcasted{StructArrayStyle{StaticArrayStyle{N},N}}) where {N}
32+
B = convert(Broadcasted{StructArrayStyle{Broadcast.DefaultArrayStyle{N},N}}, bc)
33+
copy(B)
34+
end

src/structarray.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
443443
end
444444

445445
# broadcast
446-
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted
446+
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown
447447

448448
struct StructArrayStyle{S,N} <: AbstractArrayStyle{N} end
449449

@@ -468,6 +468,7 @@ function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S,N}) wher
468468
end
469469
StructArrayStyle{typeof(S′),_dimmax(N,M)}()
470470
end
471+
BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown()
471472

472473
@inline combine_style_types(::Type{A}, args...) where A<:AbstractArray =
473474
combine_style_types(BroadcastStyle(A), args...)

test/runtests.jl

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -917,6 +917,7 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{E
917917
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType =
918918
MyArray2{ElType}(undef, size(bc))
919919
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}()
920+
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayStyle) = S
920921

921922
@testset "broadcast" begin
922923
s = StructArray{ComplexF64}((rand(2,2), rand(2,2)))
@@ -935,27 +936,24 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyA
935936
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})
936937

937938
# Make sure we can handle style with similar defined
939+
# And we can handle most conflict
938940
# s1 and s2 has similar defined, but s3 not
939-
# s2 are conflict with s1 and s3.
941+
# s2 are conflict with s1 and s3. (And it's weaker than DefaultArrayStyle)
940942
s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
941943
s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2))))
942944
s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2))))
945+
s4 = StructArray{ComplexF64}((rand(2), rand(2)))
943946

944-
function _test_similar(a, b)
945-
flag = false
947+
function _test_similar(a, b, c)
946948
try
947-
c = StructArray{ComplexF64}((a.re .+ b.re, a.im .+ b.im))
948-
flag = true
949+
d = StructArray{ComplexF64}((a.re .+ b.re .- c.re, a.im .+ b.im .- c.im))
950+
@test typeof(a .+ b .- c) == typeof(d)
949951
catch
950-
end
951-
if flag
952-
@test typeof(@inferred(a .+ b)) == typeof(c)
953-
else
954-
@test_throws MethodError a .+ b
952+
@test_throws MethodError a .+ b .- c
955953
end
956954
end
957-
for s in (s1,s2,s3), s′ in (s1,s2,s3)
958-
_test_similar(s, s′)
955+
for s in (s1,s2,s3,s4), s′ in (s1,s2,s3,s4), s″ in (s1,s2,s3,s4)
956+
_test_similar(s, s′, s″)
959957
end
960958

961959
# test for dimensionality track
@@ -973,20 +971,26 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyA
973971
@test (A .= B .* c) === A
974972

975973
# ambiguity check (can we do this better?)
976-
function _test(a, b)
977-
if a isa StructArray || b isa StructArray
978-
d = @inferred a .+ b
979-
@test d == collect(a) .+ collect(b)
974+
function _test(a, b, c)
975+
if a isa StructArray || b isa StructArray || c isa StructArray
976+
d = @inferred a .+ b .- c
977+
@test d == collect(a) .+ collect(b) .- collect(c)
980978
@test d isa StructArray
981979
end
982980
end
983-
testset = StructArray([1;2+im]), StructArray([1 2+im]), 1:2, (1,2), (@SArray [1 2])
984-
for aa in testset, bb in testset
985-
_test(aa, bb)
981+
testset = (StructArray([1;2+im]),
982+
StructArray([1 2+im]),
983+
1:2,
984+
(1,2),
985+
(@SArray [1 2]),
986+
StructArray(@SArray [1 1+2im]))
987+
for aa in testset, bb in testset, cc in testset
988+
_test(aa, bb, cc)
986989
end
987-
a = StructArray([1;2+im])
988-
b = StructArray([1 2+im])
989-
@test @inferred(a .+ b .+ a .* a' .+ (1,2) .+ (1:2) .- b') isa StructArray
990+
991+
a = @SArray randn(3,3);
992+
b = StructArray{ComplexF64}((a,a))
993+
@test a[:,1] .+ b isa StructArray && (a[:,1] .+ b).re isa SizedMatrix
990994
end
991995

992996
@testset "staticarrays" begin

0 commit comments

Comments
 (0)