Skip to content

Commit c7a4dc8

Browse files
committed
Try to resolve style conflict and extend similar
1 parent 9b9d8b2 commit c7a4dc8

File tree

2 files changed

+97
-25
lines changed

2 files changed

+97
-25
lines changed

src/structarray.jl

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -443,13 +443,31 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
443443
end
444444

445445
# broadcast
446-
import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle
446+
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted
447447

448448
struct StructArrayStyle{S,N} <: AbstractArrayStyle{N} end
449-
# If `S` also track input's dimensionality, we'd better also update it.
450-
StructArrayStyle{S,M}(::Val{N}) where {M,S<:AbstractArrayStyle{M},N} =
451-
StructArrayStyle{typeof(S(Val(N))),N}()
452-
StructArrayStyle{S,M}(::Val{N}) where {M,S,N} = StructArrayStyle{S,N}()
449+
450+
# Here we define the dimension tracking behaviour of StructArrayStyle
451+
function StructArrayStyle{S,M}(::Val{N}) where {S,M,N}
452+
if S <: AbstractArrayStyle{M}
453+
return StructArrayStyle{typeof(S(Val(N))),N}()
454+
end
455+
return StructArrayStyle{S,N}()
456+
end
457+
458+
_dimmax(a::Integer, b::Integer) = max(a, b)
459+
_dimmax(::Type{Any}, ::Integer) = Any
460+
_dimmax(::Integer ,::Type{Any}) = Any
461+
462+
# StructArrayStyle is a wrapped style.
463+
# Here we try our best to resolve style conflict.
464+
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S,N}) where {S,N,M}
465+
S′ = Broadcast.result_style(S(), b)
466+
if S′ isa StructArrayStyle # avoid double wrap
467+
return typeof(S′)(Val(_dimmax(N,M)))
468+
end
469+
StructArrayStyle{typeof(S′),_dimmax(N,M)}()
470+
end
453471

454472
@inline combine_style_types(::Type{A}, args...) where A<:AbstractArray =
455473
combine_style_types(BroadcastStyle(A), args...)
@@ -461,8 +479,19 @@ Base.@pure cst(::Type{SA}) where SA = combine_style_types(array_types(SA).parame
461479

462480
BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA)),ndims(SA)}()
463481

464-
Base.similar(bc::Broadcasted{<:StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,N,ElType} =
465-
isstructtype(ElType) ? similar(StructArray{ElType}, axes(bc)) : similar(Array{ElType}, axes(bc))
482+
# Here we use `similar` defined for `S` to build the dest Array.
483+
function Base.similar(bc::Broadcasted{<:StructArrayStyle{S}}, ::Type{ElType}) where {S,ElType}
484+
bc′ = convert(Broadcasted{S}, bc)
485+
if isstructtype(ElType)
486+
return buildfromschema(T -> similar(bc′, T), ElType)
487+
end
488+
return similar(bc′, ElType)
489+
end
490+
491+
# Unwrapper the style to recover the behaviour defined by style.
492+
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:StructArrayStyle{S}}) where {S}
493+
return copyto!(dest, convert(Broadcasted{S}, bc))
494+
end
466495

467496
# for aliasing analysis during broadcast
468497
Base.dataids(u::StructArray) = mapreduce(Base.dataids, (a, b) -> (a..., b...), values(components(u)), init=())

test/runtests.jl

Lines changed: 61 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -898,17 +898,25 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
898898
@test t.b.d isa Array
899899
end
900900

901-
struct MyArray{T,N} <: AbstractArray{T,N}
902-
A::Array{T,N}
903-
end
904-
MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz))
905-
Base.IndexStyle(::Type{<:MyArray}) = IndexLinear()
906-
Base.getindex(A::MyArray, i::Int) = A.A[i]
907-
Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val
908-
Base.size(A::MyArray) = Base.size(A.A)
909-
Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}()
910-
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType =
911-
MyArray{ElType}(undef, size(bc))
901+
for S in (1, 2, 3)
902+
MyArray = Symbol(:MyArray, S)
903+
@eval begin
904+
struct $MyArray{T,N} <: AbstractArray{T,N}
905+
A::Array{T,N}
906+
end
907+
$MyArray{T}(::UndefInitializer, sz::Dims) where T = $MyArray(Array{T}(undef, sz))
908+
Base.IndexStyle(::Type{<:$MyArray}) = IndexLinear()
909+
Base.getindex(A::$MyArray, i::Int) = A.A[i]
910+
Base.setindex!(A::$MyArray, val, i::Int) = A.A[i] = val
911+
Base.size(A::$MyArray) = Base.size(A.A)
912+
Base.BroadcastStyle(::Type{<:$MyArray}) = Broadcast.ArrayStyle{$MyArray}()
913+
end
914+
end
915+
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{ElType}) where ElType =
916+
MyArray1{ElType}(undef, size(bc))
917+
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType =
918+
MyArray2{ElType}(undef, size(bc))
919+
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}()
912920

913921
@testset "broadcast" begin
914922
s = StructArray{ComplexF64}((rand(2,2), rand(2,2)))
@@ -926,24 +934,59 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
926934
# used inside of broadcast but we also test it here explicitly
927935
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})
928936

929-
s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2))))
930-
@test_throws MethodError s .+ s
937+
# Make sure we can handle style with similar defined
938+
# s1 and s2 has similar defined, but s3 not
939+
# s2 are conflict with s1 and s3.
940+
s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
941+
s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2))))
942+
s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2))))
943+
944+
function _test_similar(a, b)
945+
flag = false
946+
try
947+
c = StructArray{ComplexF64}((a.re .+ b.re, a.im .+ b.im))
948+
flag = true
949+
catch
950+
end
951+
if flag
952+
@test typeof(@inferred(a .+ b)) == typeof(c)
953+
else
954+
@test_throws MethodError a .+ b
955+
end
956+
end
957+
for s in (s1,s2,s3), s′ in (s1,s2,s3)
958+
_test_similar(s, s′)
959+
end
931960

932961
# test for dimensionality track
962+
s = s1
933963
@test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
934-
@test Base.broadcasted(+, s, [1,2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
964+
@test Base.broadcasted(+, s, [1,2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
935965
@test Base.broadcasted(+, s, [1;;2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}}
936966
@test Base.broadcasted(+, [1;;;2], s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}}
937-
938-
a = StructArray([1;2+im])
939-
b = StructArray([1;;2+im])
940-
@test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b)
967+
@test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}
941968

942969
# issue #185
943970
A = StructArray(randn(ComplexF64, 3, 3))
944971
B = randn(ComplexF64, 3, 3)
945972
c = StructArray(randn(ComplexF64, 3))
946973
@test (A .= B .* c) === A
974+
975+
# ambiguity check (can we do this better?)
976+
function _test(a, b)
977+
if a isa StructArray || b isa StructArray
978+
d = @inferred a .+ b
979+
@test d == collect(a) .+ collect(b)
980+
@test d isa StructArray
981+
end
982+
end
983+
testset = StructArray([1;2+im]), StructArray([1 2+im]), 1:2, (1,2), (@SArray [1 2])
984+
for aa in testset, bb in testset
985+
_test(aa, bb)
986+
end
987+
a = StructArray([1;2+im])
988+
b = StructArray([1 2+im])
989+
@test @inferred(a .+ b .+ a .* a' .+ (1,2) .+ (1:2) .- b') isa StructArray
947990
end
948991

949992
@testset "staticarrays" begin

0 commit comments

Comments
 (0)