Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 9459fce

Browse files
committed
Optimize accumulate/logical indexing.
Use the launch configuration API, and eagerly put temporary arrays back in the pool.
1 parent 0f79e2c commit 9459fce

File tree

2 files changed

+34
-23
lines changed

2 files changed

+34
-23
lines changed

src/accumulate.jl

Lines changed: 23 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -25,33 +25,37 @@ function Base._accumulate!(op::Function, vout::CuVector{T}, v::CuVector, dims::N
2525
Δ = 1 # Δ = 2^d
2626
n = ceil(Int, log2(length(v)))
2727

28-
num_threads = 256
29-
num_blocks = ceil(Int, length(v) / num_threads)
28+
# partial in-place accumulation
29+
function kernel(op, vout, vin, Δ)
30+
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
3031

31-
for d in 0:n # passes through data
32-
@cuda blocks=num_blocks threads=num_threads _partial_accumulate!(op, vout, vin, Δ)
32+
@inbounds if i <= length(vin)
33+
if i > Δ
34+
vout[i] = op(vin[i - Δ], vin[i])
35+
else
36+
vout[i] = vin[i]
37+
end
38+
end
3339

34-
vin, vout = vout, vin
35-
Δ *= 2
40+
return
3641
end
3742

38-
return vin
39-
end
43+
function configurator(kernel)
44+
fun = kernel.fun
45+
config = launch_configuration(fun)
46+
blocks = cld(length(v), config.threads)
4047

41-
function _partial_accumulate!(op, vout, vin, Δ)
42-
@inbounds begin
43-
k = threadIdx().x + (blockIdx().x - 1) * blockDim().x
48+
return (threads=config.threads, blocks=blocks)
49+
end
4450

45-
if k <= length(vin)
46-
if k > Δ
47-
vout[k] = op(vin[k - Δ], vin[k])
48-
else
49-
vout[k] = vin[k]
50-
end
51-
end
51+
for d in 0:n # passes through data
52+
@cuda config=configurator kernel(op, vout, vin, Δ)
53+
54+
vin, vout = vout, vin
55+
Δ *= 2
5256
end
5357

54-
return
58+
return vin
5559
end
5660

5761
Base.accumulate_pairwise!(op, result::CuVector, v::CuVector) = accumulate!(op, result, v)

src/indexing.jl

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -23,9 +23,6 @@ function Base.getindex(xs::CuArray{T}, bools::CuArray{Bool}) where {T}
2323
ys = CuArray{T}(undef, n)
2424

2525
if n > 0
26-
num_threads = min(n, 256)
27-
num_blocks = ceil(Int, length(indices) / num_threads)
28-
2926
function kernel(ys::CuDeviceArray{T}, xs::CuDeviceArray{T}, bools, indices)
3027
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
3128

@@ -38,9 +35,19 @@ function Base.getindex(xs::CuArray{T}, bools::CuArray{Bool}) where {T}
3835
return
3936
end
4037

41-
@cuda blocks=num_blocks threads=num_threads kernel(ys, xs, bools, indices)
38+
function configurator(kernel)
39+
fun = kernel.fun
40+
config = launch_configuration(fun)
41+
blocks = cld(length(indices), config.threads)
42+
43+
return (threads=config.threads, blocks=blocks)
44+
end
45+
46+
@cuda config=configurator kernel(ys, xs, bools, indices)
4247
end
4348

49+
unsafe_free!(indices)
50+
4451
return ys
4552
end
4653

0 commit comments

Comments
 (0)