Skip to content

Commit e4d40ea

Browse files
authored
Use linear indexing in broadcast kernel when possible (#520)
1 parent f3181d4 commit e4d40ea

File tree

4 files changed

+42
-21
lines changed

4 files changed

+42
-21
lines changed

.buildkite/pipeline.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ steps:
22
- label: "CUDA.jl"
33
plugins:
44
- JuliaCI/julia#v1:
5-
version: 1.8
5+
version: "1.10"
66
- JuliaCI/julia-coverage#v1:
77
codecov: true
88
command: |
@@ -23,7 +23,7 @@ steps:
2323
- label: "oneAPI.jl"
2424
plugins:
2525
- JuliaCI/julia#v1:
26-
version: 1.8
26+
version: "1.10"
2727
- JuliaCI/julia-coverage#v1:
2828
codecov: true
2929
command: |
@@ -48,7 +48,7 @@ steps:
4848
- label: "Metal.jl"
4949
plugins:
5050
- JuliaCI/julia#v1:
51-
version: 1.8
51+
version: "1.10"
5252
- JuliaCI/julia-coverage#v1:
5353
codecov: true
5454
command: |

src/device/indexing.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ macro linearidx(A, grididx=1, ctxsym=:ctx)
6464
quote
6565
x = $(esc(A))
6666
i = linear_index($(esc(ctxsym)), $(esc(grididx)))
67-
i > length(x) && return
67+
if !(1 <= i <= length(x))
68+
return
69+
end
6870
i
6971
end
7072
end

src/host/broadcast.jl

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
using Base.Broadcast
44

5-
import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate
5+
using Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate
66

77
# but make sure we don't dispatch to the optimized copy method that directly indexes
88
function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle{0}})
@@ -32,32 +32,48 @@ end
3232
return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest))))
3333
end
3434

35-
@inline Base.copyto!(dest::AnyGPUArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict
35+
@inline Base.copyto!(dest::AnyGPUArray, bc::Broadcasted{Nothing}) =
36+
_copyto!(dest, bc) # Keep it for ArrayConflict
3637

37-
@inline Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractGPUArrayStyle}) = _copyto!(dest, bc)
38+
@inline Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractGPUArrayStyle}) =
39+
_copyto!(dest, bc)
3840

3941
@inline function _copyto!(dest::AbstractArray, bc::Broadcasted)
4042
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
4143
isempty(dest) && return dest
42-
bc′ = Broadcast.preprocess(dest, bc)
43-
44-
# grid-stride kernel
45-
function broadcast_kernel(ctx, dest, bc′, nelem)
46-
i = 0
47-
while i < nelem
48-
i += 1
49-
I = @cartesianidx(dest, i)
50-
@inbounds dest[I] = bc′[I]
44+
bc = Broadcast.preprocess(dest, bc)
45+
46+
broadcast_kernel = if ndims(dest) == 1 ||
47+
(isa(IndexStyle(dest), IndexLinear) &&
48+
isa(IndexStyle(bc), IndexLinear))
49+
function (ctx, dest, bc, nelem)
50+
i = 1
51+
while i <= nelem
52+
I = @linearidx(dest, i)
53+
@inbounds dest[I] = bc[I]
54+
i += 1
55+
end
56+
return
57+
end
58+
else
59+
function (ctx, dest, bc, nelem)
60+
i = 0
61+
while i < nelem
62+
i += 1
63+
I = @cartesianidx(dest, i)
64+
@inbounds dest[I] = bc[I]
65+
end
66+
return
5167
end
52-
return
5368
end
69+
5470
elements = length(dest)
5571
elements_per_thread = typemax(Int)
56-
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc, 1;
72+
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc, 1;
5773
elements, elements_per_thread)
5874
config = launch_configuration(backend(dest), heuristic;
5975
elements, elements_per_thread)
60-
gpu_call(broadcast_kernel, dest, bc, config.elements_per_thread;
76+
gpu_call(broadcast_kernel, dest, bc, config.elements_per_thread;
6177
threads=config.threads, blocks=config.blocks)
6278

6379
return dest
@@ -101,12 +117,15 @@ function Base.map!(f, dest::AnyGPUArray, xs::AbstractArray...)
101117

102118
# grid-stride kernel
103119
function map_kernel(ctx, dest, bc, nelem)
104-
for i in 1:nelem
120+
i = 1
121+
while i <= nelem
105122
j = linear_index(ctx, i)
106123
j > common_length && return
107124

108125
J = CartesianIndices(axes(bc))[j]
109126
@inbounds dest[j] = bc[J]
127+
128+
i += 1
110129
end
111130
return
112131
end

src/host/math.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
function Base.clamp!(A::AnyGPUArray, low, high)
44
gpu_call(A, low, high) do ctx, A, low, high
5-
I = @cartesianidx A
5+
I = @linearidx A
66
A[I] = clamp(A[I], low, high)
77
return
88
end

0 commit comments

Comments
 (0)