Skip to content

Commit afdace7

Browse files
authored
Code typed fixes (#144)
fixes for code typed with complicated arguments to kernels
1 parent b19e0b7 commit afdace7

File tree

2 files changed

+60
-52
lines changed

2 files changed

+60
-52
lines changed

src/reflection.jl

Lines changed: 36 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -52,26 +52,7 @@ function ka_code_llvm(kernel, argtypes; ndrange=nothing, workgroupsize=nothing,
5252
end
5353

5454

55-
"""
56-
Get the typed IR for a kernel
57-
58-
# Examples
59-
```
60-
@ka_code_typed kernel(args. ndrange=...)
61-
@ka_code_typed kernel(args. ndrange=... workgroupsize=...)
62-
@ka_code_typed optimize=false kernel(args. ndrange=...)
63-
```
64-
To use interactive mode (with Cthulhu), call
65-
```
66-
@ka_code_typed interactive=true kernel(args. ndrange=...)
67-
```
68-
If ndrange is statically defined, then you could call
69-
```
70-
@ka_code_typed kernel(args.)
71-
```
72-
Works for CPU or CUDA kernels, with static or dynamic declarations
73-
"""
74-
macro ka_code_typed(ex0...)
55+
function format_ex(ex0)
7556
ex = ()
7657
args = gensym(:args)
7758
old_args = nothing
@@ -83,7 +64,10 @@ macro ka_code_typed(ex0...)
8364
if isa(ex0[i].args[end], Expr)
8465
# at expr (like ndrange=10)
8566
kw = ex0[i].args[end]
86-
@assert kw.head == :kw
67+
if kw.head != :kw
68+
# if an expr in place of a variable, skip
69+
break
70+
end
8771
kw.args[2] = esc(kw.args[2])
8872
kw.head = Symbol("=")
8973
resize!(ex0[i].args, length(ex0[i].args) - 1)
@@ -103,6 +87,31 @@ macro ka_code_typed(ex0...)
10387
end
10488
@assert(old_args != nothing)
10589
@assert(kern != nothing)
90+
return ex, args, old_args, kern
91+
end
92+
93+
94+
"""
95+
Get the typed IR for a kernel
96+
97+
# Examples
98+
```
99+
@ka_code_typed kernel(args. ndrange=...)
100+
@ka_code_typed kernel(args. ndrange=... workgroupsize=...)
101+
@ka_code_typed optimize=false kernel(args. ndrange=...)
102+
```
103+
To use interactive mode (with Cthulhu), call
104+
```
105+
@ka_code_typed interactive=true kernel(args. ndrange=...)
106+
```
107+
If ndrange is statically defined, then you could call
108+
```
109+
@ka_code_typed kernel(args.)
110+
```
111+
Works for CPU or CUDA kernels, with static or dynamic declarations
112+
"""
113+
macro ka_code_typed(ex0...)
114+
ex, args, old_args, kern = format_ex(ex0)
106115

107116
thecall = InteractiveUtils.gen_call_with_extracted_types_and_kwargs(__module__, :ka_code_typed, ex)
108117

@@ -137,43 +146,18 @@ If ndrange is statically defined, then you could call
137146
Works for CPU kernels ONLY, with static or dynamic declarations
138147
"""
139148
macro ka_code_llvm(ex0...)
140-
ex = ()
141-
args = gensym(:args)
142-
old_args = nothing
143-
kern = nothing
144-
for i = 1:length(ex0)
145-
if ex0[i].head == :call
146-
# inside kernel() expr
147-
while length(ex0[i].args) > 2
148-
if isa(ex0[i].args[end], Expr)
149-
# at expr (like ndrange=10)
150-
kw = ex0[i].args[end]
151-
@assert kw.head == :kw
152-
kw.args[2] = esc(kw.args[2])
153-
kw.head = Symbol("=")
154-
resize!(ex0[i].args, length(ex0[i].args) - 1)
155-
ex = (kw,)..., ex...
156-
else
157-
# only symbols left
158-
break
159-
end
160-
end
161-
# save kernel args
162-
old_args = Expr(:tuple, map(esc, ex0[i].args[2:end])...)
163-
resize!(ex0[i].args, 2)
164-
ex0[i].args[2] = Expr(:..., args)
165-
kern = esc(ex0[i].args[1])
166-
end
167-
ex = ex..., ex0[i]
168-
end
169-
@assert(old_args != nothing)
170-
@assert(kern != nothing)
149+
ex, args, old_args, kern = format_ex(ex0)
171150

172151
thecall = InteractiveUtils.gen_call_with_extracted_types_and_kwargs(__module__, :ka_code_llvm, ex)
173152

174153
quote
175154
local $(esc(args)) = $(old_args)
176155

156+
if isa($kern, Kernel{CUDADevice})
157+
# does not support CUDA kernels
158+
error("@ka_code_llvm does not support CUDA kernels")
159+
end
160+
177161
local results = $thecall
178162
end
179163
end

test/reflection.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ end
1010
A[I] = B[I] + C[I]
1111
end
1212

13+
@kernel function addi(A, C, i)
14+
I = @index(Global)
15+
A[I] = i + C[I]
16+
end
17+
1318
function test_typed_kernel_dynamic()
1419
A = ones(1024, 1024)
1520
kernel = mul2(CPU())
@@ -84,7 +89,26 @@ function test_typed_kernel_no_optimize()
8489
end
8590
end
8691

92+
function test_expr_kernel()
93+
A = ones(1024, 1024)
94+
C = similar(A)
95+
kernel = addi(CPU())
96+
res = @ka_code_typed kernel(A, C, 1+2, ndrange=size(A))
97+
@test isa(res, Pair{Core.CodeInfo, DataType})
98+
@test isa(res[1].code, Array{Any,1})
99+
100+
if has_cuda_gpu()
101+
A = CUDA.ones(1024, 1024)
102+
C = similar(A)
103+
kernel = addi(CUDADevice(), (32, 32))
104+
res = @ka_code_typed kernel(A, C, 1+2, ndrange=size(A))
105+
@test isa(res, Pair{Core.CodeInfo, DataType})
106+
@test isa(res[1].code, Array{Any,1})
107+
end
108+
end
109+
87110
test_typed_kernel_dynamic()
88111
test_typed_kernel_dynamic_no_info()
89112
test_typed_kernel_static()
90113
test_typed_kernel_no_optimize()
114+
test_expr_kernel()

0 commit comments

Comments
 (0)