|
2 | 2 |
|
3 | 3 | using Base.Broadcast
|
4 | 4 |
|
5 |
| -import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate |
| 5 | +using Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate |
6 | 6 |
|
7 | 7 | # but make sure we don't dispatch to the optimized copy method that directly indexes
|
8 | 8 | function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle{0}})
|
|
32 | 32 | return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest))))
|
33 | 33 | end
|
34 | 34 |
|
35 |
| -@inline Base.copyto!(dest::AnyGPUArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict |
| 35 | +@inline Base.copyto!(dest::AnyGPUArray, bc::Broadcasted{Nothing}) = |
| 36 | + _copyto!(dest, bc) # Keep it for ArrayConflict |
36 | 37 |
|
37 |
| -@inline Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractGPUArrayStyle}) = _copyto!(dest, bc) |
| 38 | +@inline Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractGPUArrayStyle}) = |
| 39 | + _copyto!(dest, bc) |
38 | 40 |
|
39 | 41 | @inline function _copyto!(dest::AbstractArray, bc::Broadcasted)
|
40 | 42 | axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
|
41 | 43 | isempty(dest) && return dest
|
42 |
| - bc′ = Broadcast.preprocess(dest, bc) |
43 |
| - |
44 |
| - # grid-stride kernel |
45 |
| - function broadcast_kernel(ctx, dest, bc′, nelem) |
46 |
| - i = 0 |
47 |
| - while i < nelem |
48 |
| - i += 1 |
49 |
| - I = @cartesianidx(dest, i) |
50 |
| - @inbounds dest[I] = bc′[I] |
| 44 | + bc = Broadcast.preprocess(dest, bc) |
| 45 | + |
| 46 | + broadcast_kernel = if ndims(dest) == 1 || |
| 47 | + (isa(IndexStyle(dest), IndexLinear) && |
| 48 | + isa(IndexStyle(bc), IndexLinear)) |
| 49 | + function (ctx, dest, bc, nelem) |
| 50 | + i = 1 |
| 51 | + while i <= nelem |
| 52 | + I = @linearidx(dest, i) |
| 53 | + @inbounds dest[I] = bc[I] |
| 54 | + i += 1 |
| 55 | + end |
| 56 | + return |
| 57 | + end |
| 58 | + else |
| 59 | + function (ctx, dest, bc, nelem) |
| 60 | + i = 0 |
| 61 | + while i < nelem |
| 62 | + i += 1 |
| 63 | + I = @cartesianidx(dest, i) |
| 64 | + @inbounds dest[I] = bc[I] |
| 65 | + end |
| 66 | + return |
51 | 67 | end
|
52 |
| - return |
53 | 68 | end
|
| 69 | + |
54 | 70 | elements = length(dest)
|
55 | 71 | elements_per_thread = typemax(Int)
|
56 |
| - heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc′, 1; |
| 72 | + heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc, 1; |
57 | 73 | elements, elements_per_thread)
|
58 | 74 | config = launch_configuration(backend(dest), heuristic;
|
59 | 75 | elements, elements_per_thread)
|
60 |
| - gpu_call(broadcast_kernel, dest, bc′, config.elements_per_thread; |
| 76 | + gpu_call(broadcast_kernel, dest, bc, config.elements_per_thread; |
61 | 77 | threads=config.threads, blocks=config.blocks)
|
62 | 78 |
|
63 | 79 | return dest
|
@@ -101,12 +117,15 @@ function Base.map!(f, dest::AnyGPUArray, xs::AbstractArray...)
|
101 | 117 |
|
102 | 118 | # grid-stride kernel
|
103 | 119 | function map_kernel(ctx, dest, bc, nelem)
|
104 |
| - for i in 1:nelem |
| 120 | + i = 1 |
| 121 | + while i <= nelem |
105 | 122 | j = linear_index(ctx, i)
|
106 | 123 | j > common_length && return
|
107 | 124 |
|
108 | 125 | J = CartesianIndices(axes(bc))[j]
|
109 | 126 | @inbounds dest[j] = bc[J]
|
| 127 | + |
| 128 | + i += 1 |
110 | 129 | end
|
111 | 130 | return
|
112 | 131 | end
|
|
0 commit comments