|
1 | 1 | import InteractiveUtils
|
2 |
| -export @ka_code_typed |
| 2 | +export @ka_code_typed, @ka_code_llvm |
3 | 3 | using CUDA
|
4 | 4 |
|
5 | 5 | using UUIDs
|
@@ -34,6 +34,24 @@ function ka_code_typed(kernel, argtypes; ndrange=nothing, workgroupsize=nothing,
|
34 | 34 | end
|
35 | 35 |
|
36 | 36 |
|
| 37 | +function ka_code_llvm(kernel, argtypes; ndrange=nothing, workgroupsize=nothing, dependencies=nothing, kwargs...) |
| 38 | + # get the iterspace and dynamic of a kernel |
| 39 | + ndrange, workgroupsize, iterspace, dynamic = KernelAbstractions.launch_config(kernel, ndrange, workgroupsize) |
| 40 | + |
| 41 | + # get the first block |
| 42 | + block = @inbounds KernelAbstractions.blocks(iterspace)[1] |
| 43 | + # get a context of the kernel based on the first block |
| 44 | + ctx = KernelAbstractions.mkcontext(kernel, block, ndrange, iterspace, dynamic) |
| 45 | + |
| 46 | + # reformat |
| 47 | + if argtypes isa Type |
| 48 | + argtypes = argtypes.parameters |
| 49 | + end |
| 50 | + # use code_typed |
| 51 | + return InteractiveUtils.code_llvm(KernelAbstractions.Cassette.overdub, (typeof(ctx), typeof(kernel.f), argtypes...); kwargs...) |
| 52 | +end |
| 53 | + |
| 54 | + |
37 | 55 | """
|
38 | 56 | Get the typed IR for a kernel
|
39 | 57 |
|
@@ -96,6 +114,66 @@ macro ka_code_typed(ex0...)
|
96 | 114 | end
|
97 | 115 |
|
98 | 116 | local results = $thecall
|
99 |
| - length(results) == 1 ? results[1] : results |
| 117 | + if results !== nothing |
| 118 | + length(results) == 1 ? results[1] : results |
| 119 | + end |
| 120 | + end |
| 121 | +end |
| 122 | + |
| 123 | + |
| 124 | +""" |
| 125 | +Get the llvm code for a kernel |
| 126 | +
|
| 127 | +# Examples |
| 128 | +``` |
| 129 | +@ka_code_llvm kernel(args. ndrange=...) |
| 130 | +@ka_code_llvm kernel(args. ndrange=... workgroupsize=...) |
| 131 | +@ka_code_llvm optimize=false kernel(args. ndrange=...) |
| 132 | +``` |
| 133 | +If ndrange is statically defined, then you could call |
| 134 | +``` |
| 135 | +@ka_code_llvm kernel(args.) |
| 136 | +``` |
| 137 | +Works for CPU kernels ONLY, with static or dynamic declarations |
| 138 | +""" |
| 139 | +macro ka_code_llvm(ex0...) |
| 140 | + ex = () |
| 141 | + args = gensym(:args) |
| 142 | + old_args = nothing |
| 143 | + kern = nothing |
| 144 | + for i = 1:length(ex0) |
| 145 | + if ex0[i].head == :call |
| 146 | + # inside kernel() expr |
| 147 | + while length(ex0[i].args) > 2 |
| 148 | + if isa(ex0[i].args[end], Expr) |
| 149 | + # at expr (like ndrange=10) |
| 150 | + kw = ex0[i].args[end] |
| 151 | + @assert kw.head == :kw |
| 152 | + kw.args[2] = esc(kw.args[2]) |
| 153 | + kw.head = Symbol("=") |
| 154 | + resize!(ex0[i].args, length(ex0[i].args) - 1) |
| 155 | + ex = (kw,)..., ex... |
| 156 | + else |
| 157 | + # only symbols left |
| 158 | + break |
| 159 | + end |
| 160 | + end |
| 161 | + # save kernel args |
| 162 | + old_args = Expr(:tuple, map(esc, ex0[i].args[2:end])...) |
| 163 | + resize!(ex0[i].args, 2) |
| 164 | + ex0[i].args[2] = Expr(:..., args) |
| 165 | + kern = esc(ex0[i].args[1]) |
| 166 | + end |
| 167 | + ex = ex..., ex0[i] |
| 168 | + end |
| 169 | + @assert(old_args != nothing) |
| 170 | + @assert(kern != nothing) |
| 171 | + |
| 172 | + thecall = InteractiveUtils.gen_call_with_extracted_types_and_kwargs(__module__, :ka_code_llvm, ex) |
| 173 | + |
| 174 | + quote |
| 175 | + local $(esc(args)) = $(old_args) |
| 176 | + |
| 177 | + local results = $thecall |
100 | 178 | end
|
101 | 179 | end
|
0 commit comments