Skip to content

Commit fd43575

Browse files
authored
Fixes and test for at_ka_code_typed
1 parent 548ff24 commit fd43575

File tree

3 files changed

+123
-20
lines changed

3 files changed

+123
-20
lines changed

src/backends/cuda.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,28 +134,41 @@ end
134134
###
135135
# Kernel launch
136136
###
137-
function (obj::Kernel{CUDADevice})(args...; ndrange=nothing, dependencies=nothing, workgroupsize=nothing, progress=yield)
137+
function launch_config(kernel::Kernel{CUDADevice}, ndrange, workgroupsize)
138138
if ndrange isa Integer
139139
ndrange = (ndrange,)
140140
end
141141
if workgroupsize isa Integer
142-
workgroupsize = (workgroupsize,)
142+
workgroupsize = (workgroupsize, )
143143
end
144144

145-
if KernelAbstractions.workgroupsize(obj) <: DynamicSize && workgroupsize === nothing
145+
if KernelAbstractions.workgroupsize(kernel) <: DynamicSize && workgroupsize === nothing
146146
# TODO: allow for NDRange{1, DynamicSize, DynamicSize}(nothing, nothing)
147147
# and actually use CUDA autotuning
148148
workgroupsize = (256,)
149149
end
150+
151+
# partition checked that the ndrange's agreed
152+
if KernelAbstractions.ndrange(kernel) <: StaticSize
153+
ndrange = nothing
154+
end
155+
156+
iterspace, dynamic = partition(kernel, ndrange, workgroupsize)
157+
158+
return ndrange, workgroupsize, iterspace, dynamic
159+
end
160+
161+
function (obj::Kernel{CUDADevice})(args...; ndrange=nothing, dependencies=nothing, workgroupsize=nothing, progress=yield)
162+
163+
ndrange, workgroupsize, iterspace, dynamic = launch_config(obj, ndrange, workgroupsize)
164+
150165
# If the kernel is statically sized we can tell the compiler about that
151166
if KernelAbstractions.workgroupsize(obj) <: StaticSize
152167
maxthreads = prod(get(KernelAbstractions.workgroupsize(obj)))
153168
else
154169
maxthreads = nothing
155170
end
156171

157-
iterspace, dynamic = partition(obj, ndrange, workgroupsize)
158-
159172
nblocks = length(blocks(iterspace))
160173
threads = length(workitems(iterspace))
161174

src/reflection.jl

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
import InteractiveUtils
22
export @ka_code_typed
3+
using CUDA
34

45
function ka_code_typed(kernel, argtypes; ndrange=nothing, workgroupsize=nothing, dependencies=nothing, kwargs...)
56
# get the iterspace and dynamic of a kernel
67
ndrange, workgroupsize, iterspace, dynamic = KernelAbstractions.launch_config(kernel, ndrange, workgroupsize)
7-
# get the first block
8-
block = @inbounds KernelAbstractions.blocks(iterspace)[1]
9-
# get a context of the kernel based on the first block
10-
ctx = KernelAbstractions.mkcontext(kernel, block, ndrange, iterspace, dynamic)
8+
9+
if isa(kernel, Kernel{CPU})
10+
# get the first block
11+
block = @inbounds KernelAbstractions.blocks(iterspace)[1]
12+
# get a context of the kernel based on the first block
13+
ctx = KernelAbstractions.mkcontext(kernel, block, ndrange, iterspace, dynamic)
14+
else
15+
ctx = KernelAbstractions.mkcontext(kernel, ndrange, iterspace)
16+
end
1117
# reformat
1218
if argtypes isa Type
1319
argtypes = argtypes.parameters
@@ -17,25 +23,63 @@ function ka_code_typed(kernel, argtypes; ndrange=nothing, workgroupsize=nothing,
1723
end
1824

1925

26+
"""
27+
Get the typed IR for a kernel
28+
29+
# Examples
30+
```
31+
@ka_code_typed kernel(args..., ndrange=...)
32+
@ka_code_typed kernel(args..., ndrange=..., workgroupsize=...)
33+
@ka_code_typed optimize=false kernel(args..., ndrange=...)
34+
```
35+
If ndrange is statically defined, then you could call
36+
```
37+
@ka_code_typed kernel(args...)
38+
```
39+
Works for CPU or CUDA kernels, with static or dynamic declarations
40+
"""
2041
macro ka_code_typed(ex0...)
2142
ex = ()
43+
args = gensym(:args)
44+
old_args = nothing
45+
kern = nothing
2246
for i = 1:length(ex0)
2347
if ex0[i].head == :call
48+
# inside kernel() expr
2449
while length(ex0[i].args) > 2
25-
kw = ex0[i].args[end]
26-
@assert kw.head == :kw
27-
kw.args[2] = esc(kw.args[2])
28-
kw.head = Symbol("=")
29-
resize!(ex0[i].args, length(ex0[i].args) - 1)
30-
ex = (kw,)..., ex...
50+
if isa(ex0[i].args[end], Expr)
51+
# at expr (like ndrange=10)
52+
kw = ex0[i].args[end]
53+
@assert kw.head == :kw
54+
kw.args[2] = esc(kw.args[2])
55+
kw.head = Symbol("=")
56+
resize!(ex0[i].args, length(ex0[i].args) - 1)
57+
ex = (kw,)..., ex...
58+
else
59+
# only symbols left
60+
break
61+
end
3162
end
63+
# save kernel args
64+
old_args = Expr(:tuple, map(esc, ex0[i].args[2:end])...)
65+
resize!(ex0[i].args, 2)
66+
ex0[i].args[2] = Expr(:..., args)
67+
kern = esc(ex0[i].args[1])
3268
end
3369
ex = ex..., ex0[i]
3470
end
71+
@assert(old_args != nothing)
72+
@assert(kern != nothing)
3573

3674
thecall = InteractiveUtils.gen_call_with_extracted_types_and_kwargs(__module__, :ka_code_typed, ex)
3775

3876
quote
77+
local $(esc(args)) = $(old_args)
78+
if isa($kern, Kernel{CUDADevice})
79+
# translate CuArray to CuDeviceArray
80+
$(esc(args)) = map(CUDA.cudaconvert, $(esc(args)))
81+
end
82+
3983
local results = $thecall
4084
length(results) == 1 ? results[1] : results
4185
end

test/reflection.jl

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,49 @@
1-
using KernelAbstractions, Test
1+
using KernelAbstractions, Test, CUDA
22

33
@kernel function mul2(A)
44
I = @index(Global)
55
A[I] = 2 * A[I]
66
end
77

8+
@kernel function add3(A, B, C)
9+
I = @index(Global)
10+
A[I] = B[I] + C[I]
11+
end
12+
813
function test_typed_kernel_dynamic()
914
A = ones(1024, 1024)
1015
kernel = mul2(CPU())
1116
res = @ka_code_typed kernel(A, ndrange=size(A), workgroupsize=16)
1217
@test isa(res, Pair{Core.CodeInfo, DataType})
1318
@test isa(res[1].code, Array{Any,1})
19+
20+
if has_cuda_gpu()
21+
A = CUDA.ones(1024, 1024)
22+
kernel = mul2(CUDADevice())
23+
res = @ka_code_typed kernel(A, ndrange=size(A), workgroupsize=(32, 32))
24+
@test isa(res, Pair{Core.CodeInfo, DataType})
25+
@test isa(res[1].code, Array{Any,1})
26+
end
1427
end
1528

1629
function test_typed_kernel_dynamic_no_info()
1730
A = ones(1024, 1024)
18-
kernel = mul2(CPU())
19-
res = @ka_code_typed kernel(A, ndrange=size(A))
31+
B = similar(A)
32+
C = similar(A)
33+
kernel = add3(CPU())
34+
res = @ka_code_typed kernel(A, B, C, ndrange=size(A))
2035
@test isa(res, Pair{Core.CodeInfo, DataType})
2136
@test isa(res[1].code, Array{Any,1})
37+
38+
if has_cuda_gpu()
39+
A = CUDA.ones(1024, 1024)
40+
B = similar(A)
41+
C = similar(A)
42+
kernel = add3(CUDADevice())
43+
res = @ka_code_typed kernel(A, B, C, ndrange=size(A))
44+
@test isa(res, Pair{Core.CodeInfo, DataType})
45+
@test isa(res[1].code, Array{Any,1})
46+
end
2247
end
2348

2449
function test_typed_kernel_static()
@@ -27,15 +52,36 @@ function test_typed_kernel_static()
2752
res = @ka_code_typed kernel(A, ndrange=size(A))
2853
@test isa(res, Pair{Core.CodeInfo, DataType})
2954
@test isa(res[1].code, Array{Any,1})
55+
56+
if has_cuda_gpu()
57+
A = CUDA.ones(1024, 1024)
58+
kernel = mul2(CUDADevice(), (32, 32))
59+
res = @ka_code_typed kernel(A, ndrange=size(A))
60+
@test isa(res, Pair{Core.CodeInfo, DataType})
61+
@test isa(res[1].code, Array{Any,1})
62+
end
3063
end
3164

3265
function test_typed_kernel_no_optimize()
3366
A = ones(1024, 1024)
34-
kernel = mul2(CPU(), 16)
67+
B = similar(A)
68+
C = similar(A)
69+
kernel = add3(CPU(), 16)
3570
res = @ka_code_typed optimize=false kernel(A, ndrange=size(A))
36-
@test isa(res, Pair{Core.CodeInfo, DataType})
71+
@test isa(res, Pair{Core.CodeInfo,Core.TypeofBottom})
3772
res_opt = @ka_code_typed kernel(A, ndrange=size(A))
3873
@test size(res[1].code) < size(res_opt[1].code)
74+
75+
if has_cuda_gpu()
76+
A = CUDA.ones(1024, 1024)
77+
B = similar(A)
78+
C = similar(A)
79+
kernel = add3(CUDADevice(), (32, 32))
80+
res = @ka_code_typed optimize=false kernel(A, ndrange=size(A))
81+
@test isa(res, Pair{Core.CodeInfo,Core.TypeofBottom})
82+
res_opt = @ka_code_typed kernel(A, ndrange=size(A))
83+
@test size(res[1].code) < size(res_opt[1].code)
84+
end
3985
end
4086

4187
test_typed_kernel_dynamic()

0 commit comments

Comments
 (0)