Skip to content

Commit 1c46e15

Browse files
authored
@ka_code_typed implementation for KernelAbstractions
1 parent 8c7052d commit 1c46e15

File tree

4 files changed

+71
-15
lines changed

4 files changed

+71
-15
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ version = "0.4.1"
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
99
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
10+
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
11+
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1012
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1113
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1214
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

src/KernelAbstractions.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ instead.
146146
See also [`@uniform`](@ref).
147147
"""
148148
macro private(T, dims)
    # A bare integer literal given as `dims` is wrapped into a 1-tuple so the
    # expansion below always receives the shape in tuple form.
    shape = dims isa Integer ? (dims,) : dims
    # Expand to a `Scratchpad` call; `T` and the shape are escaped so they are
    # resolved in the caller's scope, and the shape is lifted into a `Val`.
    return quote
        $Scratchpad($(esc(T)), Val($(esc(shape))))
    end
end
@@ -384,7 +387,7 @@ function partition(kernel, ndrange, workgroupsize)
384387
if static_ndrange <: StaticSize
385388
static_blocks = StaticSize{blocks}
386389
blocks = nothing
387-
else
390+
else
388391
static_blocks = DynamicSize
389392
blocks = CartesianIndices(blocks)
390393
end
@@ -466,4 +469,6 @@ include("backends/cuda.jl")
466469
###
467470

468471
include("extras/extras.jl")
472+
473+
include("reflection.jl")
469474
end #module

src/backends/cpu.jl

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ function wait(cpu::CPU, ev::MultiEvent, progress=nothing)
3838
alldone = ntuple(N) do i
3939
if alldone[i]
4040
true
41-
else
41+
else
4242
isdone(events[i])
4343
end
4444
end
@@ -48,7 +48,7 @@ function wait(cpu::CPU, ev::MultiEvent, progress=nothing)
4848
progress()
4949
end
5050
end
51-
51+
5252
if any(failed, events)
5353
ex = CompositeException()
5454
for event in events
@@ -78,31 +78,38 @@ function async_copy!(::CPU, A, B; dependencies=nothing, progress=nothing)
7878
end
7979

8080
# Launch a CPU kernel.
#
# The launch geometry is derived by `launch_config` from the kernel and the
# `ndrange` / `workgroupsize` keywords. A single `Event` passed as
# `dependencies` is wrapped into a tuple. When the resulting iteration space
# has no blocks, nothing is scheduled and a `MultiEvent` over the dependencies
# is returned; otherwise an `Event` running `__run` is returned.
function (obj::Kernel{CPU})(args...; ndrange=nothing, workgroupsize=nothing, dependencies=nothing, progress=nothing)
    # The normalised workgroup size is not needed past this point.
    ndrange, _, iterspace, dynamic = launch_config(obj, ndrange, workgroupsize)

    deps = dependencies isa Event ? (dependencies,) : dependencies

    # Empty iteration space: skip scheduling entirely.
    nblocks = length(blocks(iterspace))
    nblocks == 0 && return MultiEvent(deps)

    return Event(__run, obj, ndrange, iterspace, args, dynamic;
                 dependencies=deps, progress=progress)
end
94+
95+
function launch_config(kernel::Kernel{CPU}, ndrange, workgroupsize)
8196
if ndrange isa Integer
8297
ndrange = (ndrange,)
8398
end
8499
if workgroupsize isa Integer
85100
workgroupsize = (workgroupsize, )
86101
end
87-
if dependencies isa Event
88-
dependencies = (dependencies,)
89-
end
90102

91-
if KernelAbstractions.workgroupsize(obj) <: DynamicSize && workgroupsize === nothing
103+
if KernelAbstractions.workgroupsize(kernel) <: DynamicSize && workgroupsize === nothing
92104
workgroupsize = (1024,) # Vectorization, 4x unrolling, minimal grain size
93105
end
94-
iterspace, dynamic = partition(obj, ndrange, workgroupsize)
106+
iterspace, dynamic = partition(kernel, ndrange, workgroupsize)
95107
# partition checked that the ndrange's agreed
96-
if KernelAbstractions.ndrange(obj) <: StaticSize
108+
if KernelAbstractions.ndrange(kernel) <: StaticSize
97109
ndrange = nothing
98110
end
99111

100-
if length(blocks(iterspace)) == 0
101-
return MultiEvent(dependencies)
102-
end
103-
104-
Event(__run, obj, ndrange, iterspace, args, dynamic,
105-
dependencies=dependencies, progress=progress)
112+
return ndrange, workgroupsize, iterspace, dynamic
106113
end
107114

108115
# Inference barriers

src/reflection.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import InteractiveUtils
2+
export @ka_code_typed
3+
4+
"""
    ka_code_typed(kernel, argtypes; ndrange=nothing, workgroupsize=nothing, kwargs...)

Return the result of `InteractiveUtils.code_typed` for `Cassette.overdub`
applied to a context built for `kernel`, the kernel function `kernel.f`, and
the argument types `argtypes`.

`argtypes` may be a type (its `parameters` are used) or a collection of types.
`ndrange` and `workgroupsize` are passed to `launch_config`; the remaining
`kwargs` are forwarded to `code_typed`. The `dependencies` keyword is accepted
but not used by this function.
"""
function ka_code_typed(kernel, argtypes; ndrange=nothing, workgroupsize=nothing, dependencies=nothing, kwargs...)
    # Derive the launch geometry of the kernel.
    ndrange, workgroupsize, iterspace, dynamic =
        KernelAbstractions.launch_config(kernel, ndrange, workgroupsize)

    # The context is built from the first block of the iteration space.
    # NOTE(review): the index is unchecked (`@inbounds`); an empty iteration
    # space is not guarded against here.
    first_block = @inbounds KernelAbstractions.blocks(iterspace)[1]
    ctx = KernelAbstractions.mkcontext(kernel, first_block, ndrange, iterspace, dynamic)

    # Accept either a type such as `Tuple{A,B}` or an iterable of types.
    params = argtypes isa Type ? argtypes.parameters : argtypes

    signature = (typeof(ctx), typeof(kernel.f), params...)
    return InteractiveUtils.code_typed(KernelAbstractions.Cassette.overdub, signature; kwargs...)
end
18+
19+
20+
"""
    @ka_code_typed [key=value ...] kernel(args...; [ndrange=..., workgroupsize=..., ...])

Expand to a call of `ka_code_typed` on the kernel and the types of its
positional arguments, via
`InteractiveUtils.gen_call_with_extracted_types_and_kwargs`.

Keyword arguments written at the end of the kernel call (for example
`ndrange=1024`) are removed from the call and passed to `ka_code_typed` as
keywords, together with any `key=value` options written before the call.
The kernel may take any number of positional arguments.

If the result has exactly one element, that element is returned; otherwise
the whole result is returned.
"""
macro ka_code_typed(ex0...)
    ex = ()
    for item in ex0
        # Only call expressions carry kernel-launch keywords; anything else
        # (including non-`Expr` arguments such as symbols) is passed through.
        if item isa Expr && item.head === :call
            callargs = item.args
            # Peel keyword arguments off the end of the call. Stop at the
            # first trailing argument that is not a keyword, so kernels with
            # several positional arguments are handled, and never remove the
            # callee itself (`callargs[1]`).
            while length(callargs) > 1 && Meta.isexpr(callargs[end], :kw)
                kw = pop!(callargs)
                kw.args[2] = esc(kw.args[2])
                # Rewrite `Expr(:kw, k, v)` into the `k = v` form expected for
                # leading options by the InteractiveUtils helper.
                kw.head = Symbol("=")
                ex = (kw, ex...)
            end
        end
        ex = (ex..., item)
    end

    thecall = InteractiveUtils.gen_call_with_extracted_types_and_kwargs(__module__, :ka_code_typed, ex)

    quote
        local results = $thecall
        length(results) == 1 ? results[1] : results
    end
end

0 commit comments

Comments
 (0)