Skip to content

Commit b19e0b7

Browse files
authored
@ka_code_llvm implementation- llvm code for kernel abstractions (#141)
1 parent 06469fa commit b19e0b7

File tree

1 file changed

+80
-2
lines changed

1 file changed

+80
-2
lines changed

src/reflection.jl

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import InteractiveUtils
2-
export @ka_code_typed
2+
export @ka_code_typed, @ka_code_llvm
33
using CUDA
44

55
using UUIDs
@@ -34,6 +34,24 @@ function ka_code_typed(kernel, argtypes; ndrange=nothing, workgroupsize=nothing,
3434
end
3535

3636

37+
function ka_code_llvm(kernel, argtypes; ndrange=nothing, workgroupsize=nothing, dependencies=nothing, kwargs...)
38+
# get the iterspace and dynamic of a kernel
39+
ndrange, workgroupsize, iterspace, dynamic = KernelAbstractions.launch_config(kernel, ndrange, workgroupsize)
40+
41+
# get the first block
42+
block = @inbounds KernelAbstractions.blocks(iterspace)[1]
43+
# get a context of the kernel based on the first block
44+
ctx = KernelAbstractions.mkcontext(kernel, block, ndrange, iterspace, dynamic)
45+
46+
# reformat
47+
if argtypes isa Type
48+
argtypes = argtypes.parameters
49+
end
50+
# use code_typed
51+
return InteractiveUtils.code_llvm(KernelAbstractions.Cassette.overdub, (typeof(ctx), typeof(kernel.f), argtypes...); kwargs...)
52+
end
53+
54+
3755
"""
3856
Get the typed IR for a kernel
3957
@@ -96,6 +114,66 @@ macro ka_code_typed(ex0...)
96114
end
97115

98116
local results = $thecall
99-
length(results) == 1 ? results[1] : results
117+
if results !== nothing
118+
length(results) == 1 ? results[1] : results
119+
end
120+
end
121+
end
122+
123+
124+
"""
125+
Get the llvm code for a kernel
126+
127+
# Examples
128+
```
129+
@ka_code_llvm kernel(args. ndrange=...)
130+
@ka_code_llvm kernel(args. ndrange=... workgroupsize=...)
131+
@ka_code_llvm optimize=false kernel(args. ndrange=...)
132+
```
133+
If ndrange is statically defined, then you could call
134+
```
135+
@ka_code_llvm kernel(args.)
136+
```
137+
Works for CPU kernels ONLY, with static or dynamic declarations
138+
"""
139+
macro ka_code_llvm(ex0...)
140+
ex = ()
141+
args = gensym(:args)
142+
old_args = nothing
143+
kern = nothing
144+
for i = 1:length(ex0)
145+
if ex0[i].head == :call
146+
# inside kernel() expr
147+
while length(ex0[i].args) > 2
148+
if isa(ex0[i].args[end], Expr)
149+
# at expr (like ndrange=10)
150+
kw = ex0[i].args[end]
151+
@assert kw.head == :kw
152+
kw.args[2] = esc(kw.args[2])
153+
kw.head = Symbol("=")
154+
resize!(ex0[i].args, length(ex0[i].args) - 1)
155+
ex = (kw,)..., ex...
156+
else
157+
# only symbols left
158+
break
159+
end
160+
end
161+
# save kernel args
162+
old_args = Expr(:tuple, map(esc, ex0[i].args[2:end])...)
163+
resize!(ex0[i].args, 2)
164+
ex0[i].args[2] = Expr(:..., args)
165+
kern = esc(ex0[i].args[1])
166+
end
167+
ex = ex..., ex0[i]
168+
end
169+
@assert(old_args != nothing)
170+
@assert(kern != nothing)
171+
172+
thecall = InteractiveUtils.gen_call_with_extracted_types_and_kwargs(__module__, :ka_code_llvm, ex)
173+
174+
quote
175+
local $(esc(args)) = $(old_args)
176+
177+
local results = $thecall
100178
end
101179
end

0 commit comments

Comments
 (0)