Skip to content

Commit 0f31c52

Browse files
committed
propagate whether we need dynamic boundschecking
1 parent dc5b924 commit 0f31c52

File tree

2 files changed

+9
-20
lines changed

2 files changed

+9
-20
lines changed

src/backends/cpu.jl

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,12 @@ function (obj::Kernel{CPU})(args...; ndrange=nothing, workgroupsize=nothing, dep
5959
ndrange = nothing
6060
end
6161

62-
t = Threads.@spawn __run(obj, ndrange, iterspace, args, dependencies)
62+
t = Threads.@spawn __run(obj, ndrange, iterspace, args, dependencies, Val(dynamic))
6363
return CPUEvent(t)
6464
end
6565

6666
# Inference barriers
67-
function __run(obj, ndrange, iterspace, args, dependencies)
67+
function __run(obj, ndrange, iterspace, args, dependencies, ::Val{dynamic}) where dynamic
6868
__waitall(CPU(), dependencies, yield)
6969
N = length(iterspace)
7070
Nthreads = Threads.nthreads()
@@ -79,16 +79,16 @@ function __run(obj, ndrange, iterspace, args, dependencies)
7979
len, rem = 1, 0
8080
end
8181
if Nthreads == 1
82-
__thread_run(1, len, rem, obj, ndrange, iterspace, args)
82+
__thread_run(1, len, rem, obj, ndrange, iterspace, args, Val(dynamic))
8383
else
8484
@sync for tid in 1:Nthreads
85-
Threads.@spawn __thread_run(tid, len, rem, obj, ndrange, iterspace, args)
85+
Threads.@spawn __thread_run(tid, len, rem, obj, ndrange, iterspace, args, Val(dynamic))
8686
end
8787
end
8888
return nothing
8989
end
9090

91-
function __thread_run(tid, len, rem, obj, ndrange, iterspace, args)
91+
function __thread_run(tid, len, rem, obj, ndrange, iterspace, args, ::Val{dynamic}) where dynamic
9292
# compute this thread's iterations
9393
f = 1 + ((tid-1) * len)
9494
l = f + len - 1
@@ -105,27 +105,16 @@ function __thread_run(tid, len, rem, obj, ndrange, iterspace, args)
105105
# run this thread's iterations
106106
for i = f:l
107107
block = @inbounds blocks(iterspace)[i]
108-
# TODO: how do we use the information that the iteration space maps perfectly to
109-
# the ndrange without incurring a 2x compilation overhead
110-
# if dynamic
111-
ctx = mkcontextdynamic(obj, block, ndrange, iterspace)
108+
ctx = mkcontext(obj, block, ndrange, iterspace, Val(dynamic))
112109
Cassette.overdub(ctx, obj.f, args...)
113-
# else
114-
# ctx = mkcontext(obj, blocks, ndrange, iterspace)
115-
# Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
116110
end
117111
return nothing
118112
end
119113

120114
Cassette.@context CPUCtx
121115

122-
function mkcontext(kernel::Kernel{CPU}, I, _ndrange, iterspace)
123-
metadata = CompilerMetadata{ndrange(kernel), false}(I, _ndrange, iterspace)
124-
Cassette.disablehooks(CPUCtx(pass = CompilerPass, metadata=metadata))
125-
end
126-
127-
function mkcontextdynamic(kernel::Kernel{CPU}, I, _ndrange, iterspace)
128-
metadata = CompilerMetadata{ndrange(kernel), true}(I, _ndrange, iterspace)
116+
function mkcontext(kernel::Kernel{CPU}, I, _ndrange, iterspace, ::Val{dynamic}) where dynamic
117+
metadata = CompilerMetadata{ndrange(kernel), dynamic}(I, _ndrange, iterspace)
129118
Cassette.disablehooks(CPUCtx(pass = CompilerPass, metadata=metadata))
130119
end
131120

test/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ end
121121
let kernel = constarg(CPU(), 8, (1024,))
122122
# this is poking at internals
123123
iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}();
124-
ctx = KernelAbstractions.mkcontext(kernel, 1, nothing, iterspace)
124+
ctx = KernelAbstractions.mkcontext(kernel, 1, nothing, iterspace, Val(false))
125125
AT = Array{Float32, 2}
126126
IR = sprint() do io
127127
code_llvm(io, KernelAbstractions.Cassette.overdub,

0 commit comments

Comments
 (0)