|
51 | 51 | bc′ = Broadcast.preprocess(dest, bc)
|
52 | 52 |
|
53 | 53 | # grid-stride kernel
|
54 |
| - function broadcast_kernel(ctx, dest, ::Val{Is}, bc′, nelem) where Is |
55 |
| - j = 0 |
56 |
| - while j < nelem |
57 |
| - j += 1 |
58 |
| - |
59 |
| - i = @linearidx(dest, j) |
60 |
| - |
61 |
| - # cartesian indexing is slow, so avoid it if possible |
62 |
| - if isa(IndexStyle(dest), IndexCartesian) || isa(IndexStyle(bc′), IndexCartesian) |
63 |
| - # this performs an integer division, which is expensive. to make it possible |
64 |
| - # for the compiler to optimize it away, we put the iterator in the type |
65 |
| - # domain so that the indices are available at compile time. note that LLVM |
66 |
| - # only seems to replace pow2 divisions (with bitshifts), but other back-ends |
67 |
| - # may be smarter and replace arbitrary divisions by bit operations. |
68 |
| - # |
69 |
| - # also see maleadt/StaticCartesian.jl, which implements this in Julia, |
70 |
| - # but does not result in an additional speed-up on tested back-ends. |
71 |
| - # |
72 |
| - # in addition, we use @inbounds to avoid bounds checks, but we also need to |
73 |
| - # inform the compiler about the bounds that we are assuming. this is done |
74 |
| - # using the assume intrinsic, and in case of Metal yields an 8x speed-up. |
75 |
| - assume(1 <= i <= length(Is)) |
76 |
| - I = @inbounds Is[i] |
77 |
| - end |
78 |
| - |
79 |
| - val = if isa(IndexStyle(bc′), IndexCartesian) |
80 |
| - @inbounds bc′[I] |
81 |
| - else |
82 |
| - @inbounds bc′[i] |
83 |
| - end |
84 |
| - |
85 |
| - if isa(IndexStyle(dest), IndexCartesian) |
86 |
| - @inbounds dest[I] = val |
87 |
| - else |
88 |
| - @inbounds dest[i] = val |
89 |
| - end |
| 54 | + function broadcast_kernel(ctx, dest, bc′, nelem) |
| 55 | + i = 0 |
| 56 | + while i < nelem |
| 57 | + i += 1 |
| 58 | + I = @cartesianidx(dest, i) |
| 59 | + @inbounds dest[I] = bc′[I] |
90 | 60 | end
|
91 | 61 | return
|
92 | 62 | end
|
93 | 63 | elements = length(dest)
|
94 | 64 | elements_per_thread = typemax(Int)
|
95 |
| - Is = CartesianIndices(dest) |
96 |
| - heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, Val(Is), bc′, 1; |
| 65 | + heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc′, 1; |
97 | 66 | elements, elements_per_thread)
|
98 | 67 | config = launch_configuration(backend(dest), heuristic;
|
99 | 68 | elements, elements_per_thread)
|
100 |
| - gpu_call(broadcast_kernel, dest, Val(Is), bc′, config.elements_per_thread; |
| 69 | + gpu_call(broadcast_kernel, dest, bc′, config.elements_per_thread; |
101 | 70 | threads=config.threads, blocks=config.blocks)
|
102 | 71 |
|
103 | 72 | return dest
|
|
0 commit comments