Skip to content

Commit 52db290

Browse files
committed
Revert "remocing heuristic"
This reverts commit 0c7e26b.
1 parent 00c8dd4 commit 52db290

File tree

3 files changed

+60
-6
lines changed

3 files changed

+60
-6
lines changed

src/GPUArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ using Reexport
1616
@reexport using GPUArraysCore
1717

1818
## executed on-device
19+
include("device/execution.jl")
1920
include("device/abstractarray.jl")
2021

2122
using KernelAbstractions

src/device/execution.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# kernel execution

# How many threads and blocks `kernel` needs to be launched with, passing arguments
# `args...`, to fully saturate the GPU. `elements` is the number of elements that needs
# to be processed, while `elements_per_thread` is the number of elements this kernel can
# process per thread (i.e. if it's a grid-stride kernel, or 1 otherwise).
#
# This generic method is a conservative fallback; back-ends should specialize it,
# ideally using an API for maximizing occupancy of the launch configuration
# (like CUDA's occupancy API).
function launch_heuristic(backend::B, kernel, args...;
                          elements::Int,
                          elements_per_thread::Int) where {B<:Backend}
    # Fixed defaults that behave acceptably on most hardware.
    return (threads = 256, blocks = 32)
end
15+
16+
# Determine how many threads and blocks to actually launch, given the upper limits
# suggested by `launch_heuristic`. Returns a named tuple of `threads`, `blocks`, and
# `elements_per_thread` (which is always 1 unless the kernel was declared able to
# process multiple elements per thread).
function launch_configuration(backend::B, heuristic;
                              elements::Int,
                              elements_per_thread::Int) where {B<:Backend}
    threads = clamp(elements, 1, heuristic.threads)
    blocks = max(cld(elements, threads), 1)

    # When the kernel cannot do grid-stride processing, or the naive block count
    # already fits within the heuristic's budget, launch one element per thread.
    if elements_per_thread <= 1 || blocks <= heuristic.blocks
        return (; threads, blocks, elements_per_thread = 1)
    end

    # We would be launching more blocks than suggested, so prefer a grid-stride
    # loop instead, sticking to the block count the heuristic recommended.
    blocks = heuristic.blocks
    nelem = cld(elements, blocks * threads)
    if nelem > elements_per_thread
        # The kernel can't handle that many elements per thread;
        # only then bump the number of blocks.
        nelem = elements_per_thread
        blocks = cld(elements, nelem * threads)
    end
    return (; threads, blocks, elements_per_thread = nelem)
end

src/host/broadcast.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,28 @@ function Base.map!(f, dest::AnyGPUArray, xs::AbstractArray...)
117117
end
118118

119119
# grid-stride kernel
120-
@kernel function map_kernel(dest, bc)
121-
j = @index(Global, Linear)
122-
@inbounds dest[j] = bc[j]
120+
@kernel function map_kernel(dest, bc, nelem, common_length)
121+
122+
j = 0
123+
J = @index(Global, Linear)
124+
for i in 1:nelem
125+
j += 1
126+
if j <= common_length
127+
128+
J_c = CartesianIndices(axes(bc))[(J-1)*nelem + j]
129+
@inbounds dest[J_c] = bc[J_c]
130+
end
131+
end
123132
end
124-
133+
elements = common_length
134+
elements_per_thread = typemax(Int)
125135
kernel = map_kernel(get_backend(dest))
126-
config = KernelAbstractions.launch_config(kernel, common_length, nothing)
127-
kernel(dest, bc; ndrange = config[1], workgroupsize = config[2])
136+
heuristic = launch_heuristic(get_backend(dest), kernel, dest, bc, 1,
137+
common_length; elements, elements_per_thread)
138+
config = launch_configuration(get_backend(dest), heuristic;
139+
elements, elements_per_thread)
140+
kernel(dest, bc, config.elements_per_thread,
141+
common_length; ndrange = config.threads)
128142

129143
if eltype(dest) <: BrokenBroadcast
130144
throw(ArgumentError("Map operation resulting in $(eltype(eltype(dest))) is not GPU compatible"))

0 commit comments

Comments
 (0)