Skip to content

Commit bee0d4f

Browse files
committed
Make sure dispatch with dest's style work.
1 parent b53148e commit bee0d4f

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

src/host/broadcast.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ export AbstractGPUArrayStyle
44

55
using Base.Broadcast
66

7-
import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle
7+
import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate
88

99
const BroadcastGPUArray{T} = Union{AnyGPUArray{T},
1010
Base.RefValue{<:AbstractGPUArray{T}}}
@@ -47,7 +47,7 @@ end
4747
copyto!(similar(bc, ElType), bc)
4848
end
4949

50-
@inline function materialize!(::Style, dest, bc::Broadcasted) where {Style<:AbstractGPUArrayStyle}
50+
@inline function Base.materialize!(::Style, dest, bc::Broadcasted) where {Style<:AbstractGPUArrayStyle}
5151
return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest))))
5252
end
5353

test/testsuite/broadcasting.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
@testsuite "broadcasting" (AT, eltypes)->begin
22
broadcasting(AT, eltypes)
33
vec3(AT, eltypes)
4+
unknown_wrapper(AT, eltypes)
45

56
@testset "type instabilities" begin
67
f(x) = x ? 1.0 : 0
@@ -205,3 +206,22 @@ function vec3(AT, eltypes)
205206
@test all(map((a,b)-> all((1,2,3) .≈ (1,2,3)), Array(res2), res2c))
206207
end
207208
end
209+
210+
struct WrapArray{T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N}
211+
data::P
212+
end
213+
Base.@propagate_inbounds Base.getindex(A::WrapArray, i::Integer...) = A.data[i...]
214+
Base.@propagate_inbounds Base.setindex!(A::WrapArray, v::Any, i::Integer...) = setindex!(A.data, v, i...)
215+
Base.size(A::WrapArray) = size(A.data)
216+
Broadcast.BroadcastStyle(::Type{WrapArray{T,N,P}}) where {T,N,P} = Broadcast.BroadcastStyle(P)
217+
function unknown_wrapper(AT, eltypes)
218+
@views for ET in eltypes
219+
A = AT(randn(ET, 10, 10))
220+
WA = WrapArray(A)
221+
@test Array(WA .+ WA) == Array(WA .+ A) == Array(A .+ A)
222+
@test Array(WA .+ A[:,1]) == Array(A .+ A[:,1])
223+
@test Array(WA .+ A[1,:]) == Array(A .+ A[1,:])
224+
WA .= ET(1) # test for dispatch with dest's BroadcastStyle.
225+
@test all(isequal(ET(1)), Array(A))
226+
end
227+
end

0 commit comments

Comments
 (0)