Skip to content

Commit dc42880

Browse files
authored
Revert "Avoid cartesian iteration where possible. (#454)" (#463)
This reverts commit 36661e3.
1 parent 61943b5 commit dc42880

File tree

2 files changed

+9
-40
lines changed

2 files changed

+9
-40
lines changed

src/host/broadcast.jl

Lines changed: 8 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -51,53 +51,22 @@ end
5151
bc′ = Broadcast.preprocess(dest, bc)
5252

5353
# grid-stride kernel
54-
function broadcast_kernel(ctx, dest, ::Val{Is}, bc′, nelem) where Is
55-
j = 0
56-
while j < nelem
57-
j += 1
58-
59-
i = @linearidx(dest, j)
60-
61-
# cartesian indexing is slow, so avoid it if possible
62-
if isa(IndexStyle(dest), IndexCartesian) || isa(IndexStyle(bc′), IndexCartesian)
63-
# this performs an integer division, which is expensive. to make it possible
64-
# for the compiler to optimize it away, we put the iterator in the type
65-
# domain so that the indices are available at compile time. note that LLVM
66-
# only seems to replace pow2 divisions (with bitshifts), but other back-ends
67-
# may be smarter and replace arbitrary divisions by bit operations.
68-
#
69-
# also see maleadt/StaticCartesian.jl, which implements this in Julia,
70-
# but does not result in an additional speed-up on tested back-ends.
71-
#
72-
# in addition, we use @inbounds to avoid bounds checks, but we also need to
73-
# inform the compiler about the bounds that we are assuming. this is done
74-
# using the assume intrinsic, and in case of Metal yields a 8x speed-up.
75-
assume(1 <= i <= length(Is))
76-
I = @inbounds Is[i]
77-
end
78-
79-
val = if isa(IndexStyle(bc′), IndexCartesian)
80-
@inbounds bc′[I]
81-
else
82-
@inbounds bc′[i]
83-
end
84-
85-
if isa(IndexStyle(dest), IndexCartesian)
86-
@inbounds dest[I] = val
87-
else
88-
@inbounds dest[i] = val
89-
end
54+
function broadcast_kernel(ctx, dest, bc′, nelem)
55+
i = 0
56+
while i < nelem
57+
i += 1
58+
I = @cartesianidx(dest, i)
59+
@inbounds dest[I] = bc′[I]
9060
end
9161
return
9262
end
9363
elements = length(dest)
9464
elements_per_thread = typemax(Int)
95-
Is = CartesianIndices(dest)
96-
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, Val(Is), bc′, 1;
65+
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc′, 1;
9766
elements, elements_per_thread)
9867
config = launch_configuration(backend(dest), heuristic;
9968
elements, elements_per_thread)
100-
gpu_call(broadcast_kernel, dest, Val(Is), bc′, config.elements_per_thread;
69+
gpu_call(broadcast_kernel, dest, bc′, config.elements_per_thread;
10170
threads=config.threads, blocks=config.blocks)
10271

10372
return dest

src/host/math.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
function Base.clamp!(A::AnyGPUArray, low, high)
44
gpu_call(A, low, high) do ctx, A, low, high
5-
I = @linearidx A
5+
I = @cartesianidx A
66
A[I] = clamp(A[I], low, high)
77
return
88
end

0 commit comments

Comments
 (0)