Skip to content

Commit 7358701

Browse files
committed
Make broadcast extendable
1 parent 650729d commit 7358701

File tree

3 files changed

+44
-5
lines changed

3 files changed

+44
-5
lines changed

src/broadcast.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,22 @@ end
9797
scalar_getindex(x) = x
9898
scalar_getindex(x::Ref) = x[]
9999

100-
@generated function _broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
101-
first_staticarray = a[findfirst(ai -> ai <: Union{StaticArray, Transpose{<:Any, <:StaticArray}, Adjoint{<:Any, <:StaticArray}, Diagonal{<:Any, <:StaticArray}}, a)]
100+
isstatic(::StaticArray) = true
101+
isstatic(::Transpose{<:Any, <:StaticArray}) = true
102+
isstatic(::Adjoint{<:Any, <:StaticArray}) = true
103+
isstatic(::Diagonal{<:Any, <:StaticArray}) = true
104+
isstatic(_) = false
105+
106+
@inline first_statictype(x, y...) = isstatic(x) ? typeof(x) : first_statictype(y...)
107+
first_statictype() = error("unresolved dest type")
108+
109+
@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
110+
first_staticarray = first_statictype(a...)
111+
elements = __broadcast(f, sz, s, a...)
112+
@inbounds return similar_type(first_staticarray, eltype(elements), Size(newsize))(elements)
113+
end
102114

115+
@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
103116
if prod(newsize) == 0
104117
# Use inference to get eltype in empty case (see also comments in _map)
105118
eltys = [:(eltype(a[$i])) for i 1:length(a)]
@@ -123,8 +136,7 @@ scalar_getindex(x::Ref) = x[]
123136

124137
return quote
125138
@_inline_meta
126-
@inbounds elements = tuple($(exprs...))
127-
@inbounds return similar_type($first_staticarray, eltype(elements), Size(newsize))(elements)
139+
@inbounds return elements = tuple($(exprs...))
128140
end
129141
end
130142

src/precompile.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function _precompile_()
2222
end
2323

2424
# Some expensive generators
25-
@assert precompile(Tuple{typeof(which(_broadcast,(Any,Size,Tuple{Vararg{Size}},Vararg{Any},)).generator.gen),Any,Any,Any,Any,Any,Any})
25+
@assert precompile(Tuple{typeof(which(__broadcast,(Any,Size,Tuple{Vararg{Size}},Vararg{Any},)).generator.gen),Any,Any,Any,Any,Any,Any})
2626
@assert precompile(Tuple{typeof(which(_zeros,(Size,Type{<:StaticArray},)).generator.gen),Any,Any,Any,Type,Any})
2727
@assert precompile(Tuple{typeof(which(combine_sizes,(Tuple{Vararg{Size}},)).generator.gen),Any,Any})
2828
@assert precompile(Tuple{typeof(which(_mapfoldl,(Any,Any,Colon,Any,Size,Vararg{StaticArray},)).generator.gen),Any,Any,Any,Any,Any,Any,Any,Any})

test/broadcast.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,3 +282,30 @@ end
282282
end
283283

284284
end
285+
286+
# A help struct to test style-based broadcast dispatch with unknown array wrapper.
287+
# `WrapArray(A)` behaves like `A` during broadcast. But its not a `StaticArray`.
288+
struct WrapArray{T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N}
289+
data::P
290+
end
291+
Base.@propagate_inbounds Base.getindex(A::WrapArray, i::Integer...) = A.data[i...]
292+
Base.@propagate_inbounds Base.setindex!(A::WrapArray, v::Any, i::Integer...) = setindex!(A.data, v, i...)
293+
Base.size(A::WrapArray) = size(A.data)
294+
Broadcast.BroadcastStyle(::Type{WrapArray{T,N,P}}) where {T,N,P} = Broadcast.BroadcastStyle(P)
295+
StaticArrays.isstatic(A::WrapArray) = StaticArrays.isstatic(A.data)
296+
StaticArrays.Size(::Type{WrapArray{T,N,P}}) where {T,N,P} = StaticArrays.Size(P)
297+
function StaticArrays.similar_type(::Type{WrapArray{T,N,P}}, ::Type{t}, s::Size{S}) where {T,N,P,t,S}
298+
return StaticArrays.similar_type(P, t, s)
299+
end
300+
301+
@testset "Broadcast with unknown wrapper" begin
302+
data = (1, 2)
303+
for T in (SVector{2}, MVector{2})
304+
a = T(data)
305+
b = WrapArray(a)
306+
@test @inferred(b .+ a) isa T
307+
@test @inferred(b .+ b) isa T
308+
@test @inferred(b .+ (1, 2)) isa T
309+
@test b .+ a == b .+ b == a .+ a
310+
end
311+
end

0 commit comments

Comments
 (0)