Skip to content

Commit fef311e

Browse files
committed
Try to resolve style conflict.
1 parent 4056c71 commit fef311e

File tree

2 files changed

+82
-21
lines changed

2 files changed

+82
-21
lines changed

src/structarray.jl

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
486486
end
487487

488488
# broadcast
489-
import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle
489+
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown
490490

491491
struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
492492

@@ -496,19 +496,39 @@ function StructArrayStyle{S, M}(::Val{N}) where {S, M, N}
496496
return StructArrayStyle{T, N}()
497497
end
498498

499+
# StructArrayStyle is a wrapped style.
500+
# Here we try our best to resolve style conflict.
501+
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S, N}) where {S, N, M}
502+
N′ = M === Any || N === Any ? Any : max(M, N)
503+
S′ = Broadcast.result_style(S(), b)
504+
return S′ isa StructArrayStyle ? typeof(S′)(Val{N′}()) : StructArrayStyle{typeof(S′), N′}()
505+
end
506+
BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown()
507+
499508
@inline combine_style_types(::Type{A}, args...) where {A<:AbstractArray} =
500509
combine_style_types(BroadcastStyle(A), args...)
501510
@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where {A<:AbstractArray} =
502511
combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...)
512+
combine_style_types(::StructArrayStyle{S}) where {S} = S() # avoid nested StructArrayStyle
503513
combine_style_types(s::BroadcastStyle) = s
504514

505515
Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).parameters...)
506516

507517
BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}()
508518

509-
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S<:Union{DefaultArrayStyle,StructArrayStyle}, N, ElType}
510-
ContainerType = isnonemptystructtype(ElType) ? StructArray{ElType} : Array{ElType}
511-
return similar(ContainerType, axes(bc))
519+
# Here we use `similar` defined for `S` to build the dest Array.
520+
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S, N, ElType}
521+
bc′ = convert(Broadcasted{S}, bc)
522+
return isnonemptystructtype(ElType) ? buildfromschema(T -> similar(bc′, T), ElType) : similar(bc′, ElType)
523+
end
524+
525+
# Unwrapper to recover the behaviour defined by parent style.
526+
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
527+
return copyto!(dest, convert(Broadcasted{S}, bc))
528+
end
529+
530+
@inline function Broadcast.materialize!(::StructArrayStyle{S}, dest, bc::Broadcasted) where {S}
531+
return Broadcast.materialize!(S(), dest, bc)
512532
end
513533

514534
# for aliasing analysis during broadcast

test/runtests.jl

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,17 +1100,26 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
11001100
@test t.b.d isa Array
11011101
end
11021102

1103-
struct MyArray{T,N} <: AbstractArray{T,N}
1104-
A::Array{T,N}
1103+
for S in (1, 2, 3)
1104+
MyArray = Symbol(:MyArray, S)
1105+
@eval begin
1106+
struct $MyArray{T,N} <: AbstractArray{T,N}
1107+
A::Array{T,N}
1108+
end
1109+
$MyArray{T}(::UndefInitializer, sz::Dims) where T = $MyArray(Array{T}(undef, sz))
1110+
Base.IndexStyle(::Type{<:$MyArray}) = IndexLinear()
1111+
Base.getindex(A::$MyArray, i::Int) = A.A[i]
1112+
Base.setindex!(A::$MyArray, val, i::Int) = A.A[i] = val
1113+
Base.size(A::$MyArray) = Base.size(A.A)
1114+
Base.BroadcastStyle(::Type{<:$MyArray}) = Broadcast.ArrayStyle{$MyArray}()
1115+
end
11051116
end
1106-
MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz))
1107-
Base.IndexStyle(::Type{<:MyArray}) = IndexLinear()
1108-
Base.getindex(A::MyArray, i::Int) = A.A[i]
1109-
Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val
1110-
Base.size(A::MyArray) = Base.size(A.A)
1111-
Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}()
1112-
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType =
1113-
MyArray{ElType}(undef, size(bc))
1117+
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{ElType}) where ElType =
1118+
MyArray1{ElType}(undef, size(bc))
1119+
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType =
1120+
MyArray2{ElType}(undef, size(bc))
1121+
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}()
1122+
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayStyle) = S
11141123

11151124
@testset "broadcast" begin
11161125
s = StructArray{ComplexF64}((rand(2,2), rand(2,2)))
@@ -1128,19 +1137,34 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
11281137
# used inside of broadcast but we also test it here explicitly
11291138
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})
11301139

1131-
s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2))))
1132-
@test_throws MethodError s .+ s
1140+
# Make sure we can handle style with similar defined
1141+
# And we can handle most conflict
1142+
# s1 and s2 has similar defined, but s3 not
1143+
# s2 are conflict with s1 and s3. (And it's weaker than DefaultArrayStyle)
1144+
s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
1145+
s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2))))
1146+
s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2))))
1147+
s4 = StructArray{ComplexF64}((rand(2), rand(2)))
1148+
1149+
function _test_similar(a, b, c)
1150+
try
1151+
d = StructArray{ComplexF64}((a.re .+ b.re .- c.re, a.im .+ b.im .- c.im))
1152+
@test typeof(a .+ b .- c) == typeof(d)
1153+
catch
1154+
@test_throws MethodError a .+ b .- c
1155+
end
1156+
end
1157+
for s in (s1,s2,s3,s4), s′ in (s1,s2,s3,s4), s″ in (s1,s2,s3,s4)
1158+
_test_similar(s, s′, s″)
1159+
end
11331160

11341161
# test for dimensionality track
1162+
s = s1
11351163
@test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
11361164
@test Base.broadcasted(+, s, 1:2) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
11371165
@test Base.broadcasted(+, s, reshape(1:2,1,2)) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}}
11381166
@test Base.broadcasted(+, reshape(1:2,1,1,2), s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}}
1139-
1140-
a = StructArray([1;2+im])
1141-
b = StructArray([1;;2+im])
1142-
@test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b)
1143-
@test a .+ Any[1] isa StructArray
1167+
@test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}
11441168

11451169
# issue #185
11461170
A = StructArray(randn(ComplexF64, 3, 3))
@@ -1155,6 +1179,23 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
11551179

11561180
@test identity.(StructArray(x=StructArray(a=1:3)))::StructArray == [(x=(a=1,),), (x=(a=2,),), (x=(a=3,),)]
11571181
@test (x -> x.x.a).(StructArray(x=StructArray(a=1:3))) == [1, 2, 3]
1182+
1183+
@testset "ambiguity check" begin
1184+
function _test(a, b, c)
1185+
if a isa StructArray || b isa StructArray || c isa StructArray
1186+
d = @inferred a .+ b .- c
1187+
@test d == collect(a) .+ collect(b) .- collect(c)
1188+
@test d isa StructArray
1189+
end
1190+
end
1191+
testset = Any[StructArray([1;2+im]),
1192+
1:2,
1193+
(1,2),
1194+
]
1195+
for aa in testset, bb in testset, cc in testset
1196+
_test(aa, bb, cc)
1197+
end
1198+
end
11581199
end
11591200

11601201
@testset "map" begin

0 commit comments

Comments
 (0)