Skip to content

Commit 9f14a19

Browse files
authored
Use style dispatch in broadcast(!) (#393)
Supersedes #295, relies on JuliaLang/julia#35620.
1 parent 9412fa1 commit 9f14a19

File tree

2 files changed

+42
-8
lines changed

2 files changed

+42
-8
lines changed

src/host/broadcast.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ export AbstractGPUArrayStyle
44

55
using Base.Broadcast
66

7-
import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle
7+
import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate
88

99
const BroadcastGPUArray{T} = Union{AnyGPUArray{T},
1010
Base.RefValue{<:AbstractGPUArray{T}}}
@@ -47,7 +47,15 @@ end
4747
copyto!(similar(bc, ElType), bc)
4848
end
4949

50-
@inline function Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{Nothing})
50+
@inline function Base.materialize!(::Style, dest, bc::Broadcasted) where {Style<:AbstractGPUArrayStyle}
51+
return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest))))
52+
end
53+
54+
@inline Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict
55+
56+
@inline Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractGPUArrayStyle}) = _copyto!(dest, bc)
57+
58+
@inline function _copyto!(dest::AbstractArray, bc::Broadcasted)
5159
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
5260
isempty(dest) && return dest
5361
bc′ = Broadcast.preprocess(dest, bc)
@@ -72,12 +80,6 @@ end
7280
return dest
7381
end
7482

75-
# Base defines this method as a performance optimization, but we don't know how to do
76-
# `fill!` in general for all `BroadcastGPUArray` so we just go straight to the fallback
77-
@inline Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{<:Broadcast.AbstractArrayStyle{0}}) =
78-
copyto!(dest, convert(Broadcasted{Nothing}, bc))
79-
80-
8183
## map
8284

8385
allequal(x) = true

test/testsuite/broadcasting.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
@testsuite "broadcasting" (AT, eltypes)->begin
22
broadcasting(AT, eltypes)
33
vec3(AT, eltypes)
4+
unknown_wrapper(AT, eltypes)
45

56
@testset "type instabilities" begin
67
f(x) = x ? 1.0 : 0
@@ -205,3 +206,34 @@ function vec3(AT, eltypes)
205206
@test all(map((a,b)-> all((1,2,3) .≈ (1,2,3)), Array(res2), res2c))
206207
end
207208
end
209+
210+
# A help struct to test style-based broadcast dispatch with unknown array wrapper.
211+
# `WrapArray(A)` behaves like `A` during broadcast. But its not a `BroadcastGPUArray`.
212+
struct WrapArray{T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N}
213+
data::P
214+
end
215+
Base.@propagate_inbounds Base.getindex(A::WrapArray, i::Integer...) = A.data[i...]
216+
Base.@propagate_inbounds Base.setindex!(A::WrapArray, v::Any, i::Integer...) = setindex!(A.data, v, i...)
217+
Base.size(A::WrapArray) = size(A.data)
218+
# For kernal support
219+
Adapt.adapt_structure(to, s::WrapArray) = WrapArray(Adapt.adapt(to, s.data))
220+
# For broadcast support
221+
GPUArrays.backend(::Type{WrapArray{T,N,P}}) where {T,N,P} = GPUArrays.backend(P)
222+
Broadcast.BroadcastStyle(::Type{WrapArray{T,N,P}}) where {T,N,P} = Broadcast.BroadcastStyle(P)
223+
224+
function unknown_wrapper(AT, eltypes)
225+
for ET in eltypes
226+
@views @testset "unknown wrapper $ET" begin
227+
A = AT(rand(ET, 10, 10))
228+
WA = WrapArray(A)
229+
# test for dispatch with src's BroadcastStyle.
230+
@test Array(WA .+ ET(1)) == Array(A .+ ET(1))
231+
@test Array(WA .+ WA) == Array(WA .+ A) == Array(A .+ A)
232+
@test Array(WA .+ A[:,1]) == Array(A .+ A[:,1])
233+
@test Array(WA .+ A[1,:]) == Array(A .+ A[1,:])
234+
# test for dispatch with dest's BroadcastStyle.
235+
WA .= ET(1)
236+
@test all(isequal(ET(1)), Array(A))
237+
end
238+
end
239+
end

0 commit comments

Comments
 (0)