Skip to content

Commit 6732699

Browse files
committed
Update broadcast.jl
1 parent 6e47ff5 commit 6732699

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

src/host/broadcast.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,15 @@ end
4747
copyto!(similar(bc, ElType), bc)
4848
end
4949

50-
@inline function Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{Nothing})
50+
@inline function materialize!(::Style, dest, bc::Broadcasted) where {Style<:AbstractGPUArrayStyle}
51+
return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest))))
52+
end
53+
54+
@inline Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict
55+
56+
@inline Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractGPUArrayStyle}) = _copyto!(dest, bc)
57+
58+
@inline function _copyto!(dest::AbstractArray, bc::Broadcasted)
5159
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
5260
isempty(dest) && return dest
5361
bc′ = Broadcast.preprocess(dest, bc)
@@ -72,12 +80,6 @@ end
7280
return dest
7381
end
7482

75-
# Base defines this method as a performance optimization, but we don't know how to do
76-
# `fill!` in general for all `BroadcastGPUArray` so we just go straight to the fallback
77-
@inline Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{<:Broadcast.AbstractArrayStyle{0}}) =
78-
copyto!(dest, convert(Broadcasted{Nothing}, bc))
79-
80-
8183
## map
8284

8385
allequal(x) = true

0 commit comments

Comments
 (0)