Skip to content

Commit 47f122c

Browse files
committed
Try to resolve style conflict.
1 parent efda46e commit 47f122c

File tree

2 files changed

+82
-21
lines changed

2 files changed

+82
-21
lines changed

src/structarray.jl

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
485485
end
486486

487487
# broadcast
488-
import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle
488+
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown
489489

490490
struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
491491

@@ -495,19 +495,39 @@ function StructArrayStyle{S, M}(::Val{N}) where {S, M, N}
495495
return StructArrayStyle{T, N}()
496496
end
497497

498+
# StructArrayStyle is a wrapped style.
499+
# Here we try our best to resolve style conflict.
500+
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S, N}) where {S, N, M}
501+
N′ = M === Any || N === Any ? Any : max(M, N)
502+
S′ = Broadcast.result_style(S(), b)
503+
return S′ isa StructArrayStyle ? typeof(S′)(Val{N′}()) : StructArrayStyle{typeof(S′), N′}()
504+
end
505+
BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown()
506+
498507
@inline combine_style_types(::Type{A}, args...) where {A<:AbstractArray} =
499508
combine_style_types(BroadcastStyle(A), args...)
500509
@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where {A<:AbstractArray} =
501510
combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...)
511+
combine_style_types(::StructArrayStyle{S}) where {S} = S() # avoid nested StructArrayStyle
502512
combine_style_types(s::BroadcastStyle) = s
503513

504514
Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).parameters...)
505515

506516
BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}()
507517

508-
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S<:Union{DefaultArrayStyle,StructArrayStyle}, N, ElType}
509-
ContainerType = isnonemptystructtype(ElType) ? StructArray{ElType} : Array{ElType}
510-
return similar(ContainerType, axes(bc))
518+
# Here we use `similar` defined for `S` to build the dest Array.
519+
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S, N, ElType}
520+
bc′ = convert(Broadcasted{S}, bc)
521+
return isnonemptystructtype(ElType) ? buildfromschema(T -> similar(bc′, T), ElType) : similar(bc′, ElType)
522+
end
523+
524+
# Unwrapper to recover the behaviour defined by parent style.
525+
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
526+
return copyto!(dest, convert(Broadcasted{S}, bc))
527+
end
528+
529+
@inline function Broadcast.materialize!(::StructArrayStyle{S}, dest, bc::Broadcasted) where {S}
530+
return Broadcast.materialize!(S(), dest, bc)
511531
end
512532

513533
# for aliasing analysis during broadcast

test/runtests.jl

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,17 +1090,26 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
10901090
@test t.b.d isa Array
10911091
end
10921092

1093-
struct MyArray{T,N} <: AbstractArray{T,N}
1094-
A::Array{T,N}
1093+
for S in (1, 2, 3)
1094+
MyArray = Symbol(:MyArray, S)
1095+
@eval begin
1096+
struct $MyArray{T,N} <: AbstractArray{T,N}
1097+
A::Array{T,N}
1098+
end
1099+
$MyArray{T}(::UndefInitializer, sz::Dims) where T = $MyArray(Array{T}(undef, sz))
1100+
Base.IndexStyle(::Type{<:$MyArray}) = IndexLinear()
1101+
Base.getindex(A::$MyArray, i::Int) = A.A[i]
1102+
Base.setindex!(A::$MyArray, val, i::Int) = A.A[i] = val
1103+
Base.size(A::$MyArray) = Base.size(A.A)
1104+
Base.BroadcastStyle(::Type{<:$MyArray}) = Broadcast.ArrayStyle{$MyArray}()
1105+
end
10951106
end
1096-
MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz))
1097-
Base.IndexStyle(::Type{<:MyArray}) = IndexLinear()
1098-
Base.getindex(A::MyArray, i::Int) = A.A[i]
1099-
Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val
1100-
Base.size(A::MyArray) = Base.size(A.A)
1101-
Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}()
1102-
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType =
1103-
MyArray{ElType}(undef, size(bc))
1107+
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{ElType}) where ElType =
1108+
MyArray1{ElType}(undef, size(bc))
1109+
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType =
1110+
MyArray2{ElType}(undef, size(bc))
1111+
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}()
1112+
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayStyle) = S
11041113

11051114
@testset "broadcast" begin
11061115
s = StructArray{ComplexF64}((rand(2,2), rand(2,2)))
@@ -1118,19 +1127,34 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
11181127
# used inside of broadcast but we also test it here explicitly
11191128
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})
11201129

1121-
s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2))))
1122-
@test_throws MethodError s .+ s
1130+
# Make sure we can handle style with similar defined
1131+
# And we can handle most conflict
1132+
# s1 and s2 has similar defined, but s3 not
1133+
# s2 are conflict with s1 and s3. (And it's weaker than DefaultArrayStyle)
1134+
s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
1135+
s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2))))
1136+
s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2))))
1137+
s4 = StructArray{ComplexF64}((rand(2), rand(2)))
1138+
1139+
function _test_similar(a, b, c)
1140+
try
1141+
d = StructArray{ComplexF64}((a.re .+ b.re .- c.re, a.im .+ b.im .- c.im))
1142+
@test typeof(a .+ b .- c) == typeof(d)
1143+
catch
1144+
@test_throws MethodError a .+ b .- c
1145+
end
1146+
end
1147+
for s in (s1,s2,s3,s4), s′ in (s1,s2,s3,s4), s″ in (s1,s2,s3,s4)
1148+
_test_similar(s, s′, s″)
1149+
end
11231150

11241151
# test for dimensionality track
1152+
s = s1
11251153
@test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
11261154
@test Base.broadcasted(+, s, 1:2) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
11271155
@test Base.broadcasted(+, s, reshape(1:2,1,2)) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}}
11281156
@test Base.broadcasted(+, reshape(1:2,1,1,2), s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}}
1129-
1130-
a = StructArray([1;2+im])
1131-
b = StructArray([1;;2+im])
1132-
@test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b)
1133-
@test a .+ Any[1] isa StructArray
1157+
@test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}
11341158

11351159
# issue #185
11361160
A = StructArray(randn(ComplexF64, 3, 3))
@@ -1145,6 +1169,23 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
11451169

11461170
@test identity.(StructArray(x=StructArray(a=1:3)))::StructArray == [(x=(a=1,),), (x=(a=2,),), (x=(a=3,),)]
11471171
@test (x -> x.x.a).(StructArray(x=StructArray(a=1:3))) == [1, 2, 3]
1172+
1173+
@testset "ambiguity check" begin
1174+
function _test(a, b, c)
1175+
if a isa StructArray || b isa StructArray || c isa StructArray
1176+
d = @inferred a .+ b .- c
1177+
@test d == collect(a) .+ collect(b) .- collect(c)
1178+
@test d isa StructArray
1179+
end
1180+
end
1181+
testset = Any[StructArray([1;2+im]),
1182+
1:2,
1183+
(1,2),
1184+
]
1185+
for aa in testset, bb in testset, cc in testset
1186+
_test(aa, bb, cc)
1187+
end
1188+
end
11481189
end
11491190

11501191
@testset "map" begin

0 commit comments

Comments
 (0)