
Commit 2239640

bors[bot] and vchuravy authored
Merge #28
28: Blocked iteration r=vchuravy a=vchuravy

fixes #22

@mwarusz thank you for the performance analysis. On the V100 I am running on:

| Kernel                | Time    | Speed of Light Mem % |
| --------------------- | ------- | -------------------- |
| naive (32, 32)        | 1.19 ms | 65.06 %              |
| blocked               | 1.20 ms | 64.38 %              |
| naive (1024, 1)       | 3.67 ms | 26.84 %              |
| naive (1024, 1) Const | 1.79 ms | 56.13 %              |
| naive (1, 1024)       | 3.66 ms | 49.53 %              |
| naive (1, 1024) Const | 3.03 ms | 60.02 %              |

Co-authored-by: Valentin Churavy <v.churavy@gmail.com>
2 parents 7fac139 + dfbd9a7 commit 2239640
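
A back-of-the-envelope check of the table above (my own arithmetic, not profiler output): the example added below transposes an 8192 x 8192 Float32 matrix, so each naive (32, 32) launch moves about 0.54 GB in roughly 1.19 ms.

# Effective bandwidth of the naive (32, 32) kernel, using the sizes from
# examples/performance.jl below and the 1.19 ms figure from the table.
N      = 256 * 32                     # grid_dim * block_dim = 8192
bytes  = 2 * N^2 * sizeof(Float32)    # read a, write b: ~0.54 GB
time_s = 1.19e-3
bytes / time_s / 1e9                  # ~451 GB/s; a V100 peaks at roughly 900 GB/s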

File tree

11 files changed: +476 -229 lines

examples/performance.jl

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+using KernelAbstractions
+using CUDAapi
+
+CUDAapi.has_cuda_gpu() || exit()
+
+using CuArrays
+using CUDAdrv
+using CUDAnative
+using CUDAnative.NVTX
+
+@kernel function transpose_kernel_naive!(b, a)
+    I = @index(Global, Cartesian)
+    i, j = I.I
+    @inbounds b[i, j] = a[j, i]
+end
+
+const block_dim = 32
+const grid_dim = 256
+
+@kernel function transpose_kernel!(b, a)
+    block_dim_x, block_dim_y = block_dim, block_dim
+    grid_dim_x, grid_dim_y = grid_dim, grid_dim
+
+    wgsize = prod(groupsize())
+
+    I = @index(Global)
+    L = @index(Local)
+    G = div(I - 1, wgsize) + 1
+
+    thread_idx_x = (L - 1) % block_dim_x + 1
+    thread_idx_y = div(L - 1, block_dim_x) + 1
+
+    block_idx_x = (G - 1) % grid_dim_x + 1
+    block_idx_y = div(G - 1, grid_dim_x) + 1
+
+    i = (block_idx_x - 1) * block_dim_x + thread_idx_x
+    j = (block_idx_y - 1) * block_dim_y + thread_idx_y
+
+    @inbounds b[i + size(b, 1) * (j - 1)] = a[j + size(a, 1) * (i - 1)]
+end
+
+const T = Float32
+const N = grid_dim * block_dim
+const shape = N, N
+const nreps = 10
+
+NVTX.@range "Naive transpose $block_dim, $block_dim" let
+    a = CuArray(rand(T, shape))
+    b = similar(a, shape[2], shape[1])
+    kernel! = transpose_kernel_naive!(CUDA(), (block_dim, block_dim), size(b))
+
+    event = kernel!(b, a)
+    wait(event)
+    @assert Array(b) == Array(a)'
+    @CUDAdrv.profile begin
+        for rep in 1:nreps
+            event = kernel!(b, a, dependencies=(event,))
+        end
+        wait(event)
+    end
+end
+
+NVTX.@range "Naive transpose $(block_dim^2), 1" let
+    a = CuArray(rand(T, shape))
+    b = similar(a, shape[2], shape[1])
+    kernel! = transpose_kernel_naive!(CUDA(), (block_dim*block_dim, 1), size(b))
+
+    event = kernel!(b, a)
+    wait(event)
+    @assert Array(b) == Array(a)'
+    @CUDAdrv.profile begin
+        for rep in 1:nreps
+            event = kernel!(b, a, dependencies=(event,))
+        end
+        wait(event)
+    end
+end
+
+NVTX.@range "Naive transpose 1, $(block_dim^2)" let
+    a = CuArray(rand(T, shape))
+    b = similar(a, shape[2], shape[1])
+    kernel! = transpose_kernel_naive!(CUDA(), (1, block_dim*block_dim), size(b))
+
+    event = kernel!(b, a)
+    wait(event)
+    @assert Array(b) == Array(a)'
+    @CUDAdrv.profile begin
+        for rep in 1:nreps
+            event = kernel!(b, a, dependencies=(event,))
+        end
+        wait(event)
+    end
+end
+
+NVTX.@range "Baseline transpose" let
+    a = CuArray(rand(T, shape))
+    b = similar(a, shape[2], shape[1])
+
+    kernel! = transpose_kernel!(CUDA(), (block_dim*block_dim), length(b))
+
+    event = kernel!(b, a)
+    wait(event)
+    @assert Array(b) == Array(a)'
+    @CUDAdrv.profile begin
+        for rep in 1:nreps
+            event = kernel!(b, a, dependencies=(event,))
+        end
+        wait(event)
+    end
+end
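
To make the index arithmetic in transpose_kernel! concrete, here is a small host-side sketch (an illustration of mine, not part of the commit) that applies the same linear-to-2D decomposition to a single work-item:

# Decompose a global linear index I and local index L into (i, j), exactly as
# transpose_kernel! does; values mirror block_dim = 32, grid_dim = 256 above.
block_dim, grid_dim = 32, 256
wgsize = block_dim * block_dim            # work-items per workgroup

I = 123_456                               # an arbitrary global linear index
L = (I - 1) % wgsize + 1                  # its local index within the workgroup
G = div(I - 1, wgsize) + 1                # its workgroup number

thread_idx_x = (L - 1) % block_dim + 1
thread_idx_y = div(L - 1, block_dim) + 1
block_idx_x  = (G - 1) % grid_dim + 1
block_idx_y  = div(G - 1, grid_dim) + 1

i = (block_idx_x - 1) * block_dim + thread_idx_x
j = (block_idx_y - 1) * block_dim + thread_idx_y
@assert 1 <= i <= grid_dim * block_dim && 1 <= j <= grid_dim * block_dim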

src/KernelAbstractions.jl

Lines changed: 34 additions & 56 deletions
@@ -76,7 +76,9 @@ function async_copy! end
 """
     groupsize()
 
-Query the workgroupsize on the device.
+Query the workgroupsize on the device. This function returns
+a tuple corresponding to kernel configuration. In order to get
+the total size you can use `prod(groupsize())`.
 """
 function groupsize end
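
A quick illustration of the clarified contract (my sketch, not part of this diff): inside a kernel, groupsize() is a tuple, so the flat work-item count is prod(groupsize()), which is exactly how examples/performance.jl above uses it.

using KernelAbstractions

@kernel function group_id_kernel!(out)
    I = @index(Global)                          # linear global index
    L = @index(Local)                           # linear index within the workgroup
    wgsize = prod(groupsize())                  # flat work-items per workgroup
    @inbounds out[I] = div(I - L, wgsize) + 1   # which workgroup this item belongs to
end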

@@ -131,10 +133,6 @@ macro index(locale, args...)
         indexkind = :Linear
     end
 
-    if indexkind === :Cartesian && locale === :Local
-        error("@index(Local, Cartesian) is not implemented yet")
-    end
-
     index_function = Symbol(:__index_, locale, :_, indexkind)
     Expr(:call, GlobalRef(KernelAbstractions, index_function), map(esc, args)...)
 end
@@ -167,31 +165,14 @@ struct CUDA <: GPU end
 # struct AMD <: GPU end
 # struct Intel <: GPU end
 
+include("nditeration.jl")
+using .NDIteration
+import .NDIteration: get
+
 ###
 # Kernel closure struct
 ###
 
-import Base.@pure
-
-abstract type _Size end
-struct DynamicSize <: _Size end
-struct StaticSize{S} <: _Size
-    function StaticSize{S}() where S
-        new{S::Tuple{Vararg{Int}}}()
-    end
-end
-
-@pure StaticSize(s::Tuple{Vararg{Int}}) = StaticSize{s}()
-@pure StaticSize(s::Int...) = StaticSize{s}()
-@pure StaticSize(s::Type{<:Tuple}) = StaticSize{tuple(s.parameters...)}()
-
-# Some @pure convenience functions for `StaticSize`
-@pure get(::Type{StaticSize{S}}) where {S} = S
-@pure get(::StaticSize{S}) where {S} = S
-@pure Base.getindex(::StaticSize{S}, i::Int) where {S} = i <= length(S) ? S[i] : 1
-@pure Base.ndims(::StaticSize{S}) where {S} = length(S)
-@pure Base.length(::StaticSize{S}) where {S} = prod(S)
-
 """
     Kernel{Device, WorkgroupSize, NDRange, Func}
@@ -206,14 +187,7 @@ end
 workgroupsize(::Kernel{D, WorkgroupSize}) where {D, WorkgroupSize} = WorkgroupSize
 ndrange(::Kernel{D, WorkgroupSize, NDRange}) where {D, WorkgroupSize,NDRange} = NDRange
 
-"""
-    partition(kernel, ndrange)
-
-Splits the maximum size of the iteration space by the workgroupsize.
-Returns the number of workgroups necessary and whether the last workgroup
-needs to perform dynamic bounds-checking.
-"""
-@inline function partition(kernel::Kernel, ndrange, workgroupsize)
+function partition(kernel, ndrange, workgroupsize)
     static_ndrange = KernelAbstractions.ndrange(kernel)
     static_workgroupsize = KernelAbstractions.workgroupsize(kernel)

@@ -225,42 +199,49 @@ needs to perform dynamic bounds-checking.
             You created a dynamically sized kernel, but forgot to provide runtime
             parameters for the kernel. Either provide them statically if known
             or dynamically.
-            NDRange(Static): $(typeof(static_ndrange))
+            NDRange(Static): $(static_ndrange)
             NDRange(Dynamic): $(ndrange)
-            Workgroupsize(Static): $(typeof(static_workgroupsize))
+            Workgroupsize(Static): $(static_workgroupsize)
             Workgroupsize(Dynamic): $(workgroupsize)
         """
         error(errmsg)
     end
 
-    if ndrange !== nothing && static_ndrange <: StaticSize
-        if prod(ndrange) != prod(get(static_ndrange))
-            error("Static NDRange and launch NDRange differ")
+    if static_ndrange <: StaticSize
+        if ndrange !== nothing && ndrange != get(static_ndrange)
+            error("Static NDRange ($static_ndrange) and launch NDRange ($ndrange) differ")
         end
+        ndrange = get(static_ndrange)
     end
 
     if static_workgroupsize <: StaticSize
-        @assert length(get(static_workgroupsize)) === 1
-        static_workgroupsize = get(static_workgroupsize)[1]
-        if workgroupsize !== nothing && workgroupsize != static_workgroupsize
-            error("Static WorkgroupSize and launch WorkgroupSize differ")
+        if workgroupsize !== nothing && workgroupsize != get(static_workgroupsize)
+            error("Static WorkgroupSize ($static_workgroupsize) and launch WorkgroupSize $(workgroupsize) differ")
         end
-        workgroupsize = static_workgroupsize
+        workgroupsize = get(static_workgroupsize)
     end
+
     @assert workgroupsize !== nothing
+    @assert ndrange !== nothing
+    blocks, workgroupsize, dynamic = NDIteration.partition(ndrange, workgroupsize)
 
     if static_ndrange <: StaticSize
-        maxsize = prod(get(static_ndrange))
-    else
-        maxsize = prod(ndrange)
+        static_blocks = StaticSize{blocks}
+        blocks = nothing
+    else
+        static_blocks = DynamicSize
+        blocks = CartesianIndices(blocks)
     end
 
-    nworkgroups = fld1(maxsize, workgroupsize)
-    dynamic = mod(maxsize, workgroupsize) != 0
-
-    dynamic || @assert(nworkgroups * workgroupsize == maxsize)
+    if static_workgroupsize <: StaticSize
+        static_workgroupsize = StaticSize{workgroupsize} # we might have padded workgroupsize
+        workgroupsize = nothing
+    else
+        workgroupsize = CartesianIndices(workgroupsize)
+    end
 
-    return nworkgroups, dynamic
+    iterspace = NDRange{length(ndrange), static_blocks, static_workgroupsize}(blocks, workgroupsize)
+    return iterspace, dynamic
 end
 
 ###
@@ -273,10 +254,7 @@ include("compiler.jl")
 # Compiler/Frontend
 ###
 
-@inline function __workitems_iterspace()
-    return 1:groupsize()
-end
-
+function __workitems_iterspace end
 function __validindex end
 
 include("macros.jl")
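
The partition rewrite above now delegates the actual splitting to NDIteration.partition (defined in src/nditeration.jl, one of the files in this commit that is not shown here) and only wraps the result in static or dynamic size types. As a rough, plain-Julia sketch of what such a split computes, my reading rather than the real implementation:

# Ceil-divide every dimension of the ndrange by the workgroup size; a partial
# trailing block in any dimension means bounds checks must stay dynamic.
function sketch_partition(ndrange::NTuple{N,Int}, workgroupsize::NTuple{N,Int}) where {N}
    blocks  = map(cld, ndrange, workgroupsize)
    dynamic = any(mod.(ndrange, workgroupsize) .!= 0)
    return blocks, workgroupsize, dynamic
end

sketch_partition((8192, 8192), (32, 32))   # ((256, 256), (32, 32), false)
sketch_partition((100, 7), (16, 4))        # ((7, 2), (16, 4), true)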

src/backends/cpu.jl

Lines changed: 46 additions & 42 deletions
@@ -14,86 +14,90 @@ function wait(ev::CPUEvent, progress=nothing)
 end
 
 function (obj::Kernel{CPU})(args...; ndrange=nothing, workgroupsize=nothing, dependencies=nothing)
-    if ndrange isa Int
+    if ndrange isa Integer
         ndrange = (ndrange,)
     end
+    if workgroupsize isa Integer
+        workgroupsize = (workgroupsize, )
+    end
     if dependencies isa Event
         dependencies = (dependencies,)
     end
+
     if KernelAbstractions.workgroupsize(obj) <: DynamicSize && workgroupsize === nothing
-        workgroupsize = 1024 # Vectorization, 4x unrolling, minimal grain size
+        workgroupsize = (1024,) # Vectorization, 4x unrolling, minimal grain size
     end
-    nblocks, dynamic = partition(obj, ndrange, workgroupsize)
+    iterspace, dynamic = partition(obj, ndrange, workgroupsize)
     # partition checked that the ndrange's agreed
     if KernelAbstractions.ndrange(obj) <: StaticSize
         ndrange = nothing
     end
-    if KernelAbstractions.workgroupsize(obj) <: StaticSize
-        workgroupsize = nothing
-    end
-    t = Threads.@spawn begin
+
+    t = __run(obj, ndrange, iterspace, args, dependencies)
+    return CPUEvent(t)
+end
+
+# Inference barrier
+function __run(obj, ndrange, iterspace, args, dependencies)
+    return Threads.@spawn begin
         if dependencies !== nothing
             Base.sync_end(map(e->e.task, dependencies))
         end
         @sync begin
-            for I in 1:(nblocks-1)
-                let ctx = mkcontext(obj, I, ndrange, workgroupsize)
-                    Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
+            # TODO: how do we use the information that the iteration space maps perfectly to
+            #       the ndrange without incurring a 2x compilation overhead
+            # if dynamic
+                for block in iterspace
+                    let ctx = mkcontextdynamic(obj, block, ndrange, iterspace)
+                        Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
+                    end
                 end
-            end
-
-            if dynamic
-                let ctx = mkcontextdynamic(obj, nblocks, ndrange, workgroupsize)
-                    Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
-                end
-            else
-                let ctx = mkcontext(obj, nblocks, ndrange, workgroupsize)
-                    Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
-                end
-            end
+            # else
+            #     for block in iterspace
+            #         let ctx = mkcontext(obj, blocks, ndrange, iterspace)
+            #             Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
+            #         end
+            #     end
+            # end
         end
     end
-    return CPUEvent(t)
 end
 
 Cassette.@context CPUCtx
 
-function mkcontext(kernel::Kernel{CPU}, I, _ndrange, _workgroupsize)
-    metadata = CompilerMetadata{workgroupsize(kernel), ndrange(kernel), false}(I, _ndrange, _workgroupsize)
+function mkcontext(kernel::Kernel{CPU}, I, _ndrange, iterspace)
+    metadata = CompilerMetadata{ndrange(kernel), false}(I, _ndrange, iterspace)
     Cassette.disablehooks(CPUCtx(pass = CompilerPass, metadata=metadata))
 end
 
-function mkcontextdynamic(kernel::Kernel{CPU}, I, _ndrange, _workgroupsize)
-    metadata = CompilerMetadata{workgroupsize(kernel), ndrange(kernel), true}(I, _ndrange, _workgroupsize)
+function mkcontextdynamic(kernel::Kernel{CPU}, I, _ndrange, iterspace)
+    metadata = CompilerMetadata{ndrange(kernel), true}(I, _ndrange, iterspace)
     Cassette.disablehooks(CPUCtx(pass = CompilerPass, metadata=metadata))
 end
 
-@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Local_Linear), idx)
-    return idx
+@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Local_Linear), idx::CartesianIndex)
+    indices = workitems(__iterspace(ctx.metadata))
+    return @inbounds LinearIndices(indices)[idx]
 end
 
-@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Global_Linear), idx)
-    workgroup = __groupindex(ctx.metadata)
-    (workgroup - 1) * __groupsize(ctx.metadata) + idx
+@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Global_Linear), idx::CartesianIndex)
+    I = @inbounds expand(__iterspace(ctx.metadata), __groupindex(ctx.metadata), idx)
+    @inbounds LinearIndices(__ndrange(ctx.metadata))[I]
 end
 
-@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Local_Cartesian), idx)
-    error("@index(Local, Cartesian) is not yet defined")
+@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Local_Cartesian), idx::CartesianIndex)
+    return idx
 end
 
-@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Global_Cartesian), idx)
-    workgroup = __groupindex(ctx.metadata)
-    indices = __ndrange(ctx.metadata)
-    lI = (workgroup - 1) * __groupsize(ctx.metadata) + idx
-    return @inbounds indices[lI]
+@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__index_Global_Cartesian), idx::CartesianIndex)
+    return @inbounds expand(__iterspace(ctx.metadata), __groupindex(ctx.metadata), idx)
 end
 
-@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__validindex), idx)
+@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__validindex), idx::CartesianIndex)
     # Turns this into a noop for code where we can turn of checkbounds of
     if __dynamic_checkbounds(ctx.metadata)
-        maxidx = prod(size(__ndrange(ctx.metadata)))
-        valid = idx <= mod1(maxidx, __groupsize(ctx.metadata))
-        return valid
+        I = @inbounds expand(__iterspace(ctx.metadata), __groupindex(ctx.metadata), idx)
+        return I in __ndrange(ctx.metadata)
     else
         return true
     end
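
The CPU backend now iterates Cartesian blocks of the iteration space and expands each (block, local work-item) pair into a global index, validating it against the ndrange when dynamic bounds checks are enabled. A dependency-free sketch of that iteration pattern, using ordinary loops instead of Cassette and Threads.@spawn (an illustration, not the actual implementation):

# Blocked iteration over a 100x7 range with 16x4 workgroups; the trailing,
# partial blocks are handled by skipping out-of-range indices, which is what
# __validindex does when dynamic bounds checking is on.
ndrange       = CartesianIndices((100, 7))
workgroupsize = CartesianIndices((16, 4))
nblocks       = CartesianIndices(cld.(size(ndrange), size(workgroupsize)))

out = zeros(Int, size(ndrange))
for block in nblocks, local_idx in workgroupsize
    # "expand" the (block, local) pair into a global Cartesian index
    I = CartesianIndex(Tuple(local_idx) .+ (Tuple(block) .- 1) .* size(workgroupsize))
    I in ndrange || continue      # skip the padded part of the last blocks
    out[I] += 1
end
@assert all(out .== 1)            # every index of the ndrange is visited exactly once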
