Skip to content

Commit 91103f1

Browse files
authored
Merge pull request #58 from JuliaGPU/vc/lets
Ensure that constify doesn't cause arguments to be captured
2 parents dc5b924 + f6a535e commit 91103f1

File tree

3 files changed

+23
-28
lines changed

3 files changed

+23
-28
lines changed

src/backends/cpu.jl

Lines changed: 8 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -59,12 +59,12 @@ function (obj::Kernel{CPU})(args...; ndrange=nothing, workgroupsize=nothing, dep
5959
ndrange = nothing
6060
end
6161

62-
t = Threads.@spawn __run(obj, ndrange, iterspace, args, dependencies)
62+
t = Threads.@spawn __run(obj, ndrange, iterspace, args, dependencies, Val(dynamic))
6363
return CPUEvent(t)
6464
end
6565

6666
# Inference barriers
67-
function __run(obj, ndrange, iterspace, args, dependencies)
67+
function __run(obj, ndrange, iterspace, args, dependencies, ::Val{dynamic}) where dynamic
6868
__waitall(CPU(), dependencies, yield)
6969
N = length(iterspace)
7070
Nthreads = Threads.nthreads()
@@ -79,16 +79,16 @@ function __run(obj, ndrange, iterspace, args, dependencies)
7979
len, rem = 1, 0
8080
end
8181
if Nthreads == 1
82-
__thread_run(1, len, rem, obj, ndrange, iterspace, args)
82+
__thread_run(1, len, rem, obj, ndrange, iterspace, args, Val(dynamic))
8383
else
8484
@sync for tid in 1:Nthreads
85-
Threads.@spawn __thread_run(tid, len, rem, obj, ndrange, iterspace, args)
85+
Threads.@spawn __thread_run(tid, len, rem, obj, ndrange, iterspace, args, Val(dynamic))
8686
end
8787
end
8888
return nothing
8989
end
9090

91-
function __thread_run(tid, len, rem, obj, ndrange, iterspace, args)
91+
function __thread_run(tid, len, rem, obj, ndrange, iterspace, args, ::Val{dynamic}) where dynamic
9292
# compute this thread's iterations
9393
f = 1 + ((tid-1) * len)
9494
l = f + len - 1
@@ -105,27 +105,16 @@ function __thread_run(tid, len, rem, obj, ndrange, iterspace, args)
105105
# run this thread's iterations
106106
for i = f:l
107107
block = @inbounds blocks(iterspace)[i]
108-
# TODO: how do we use the information that the iteration space maps perfectly to
109-
# the ndrange without incurring a 2x compilation overhead
110-
# if dynamic
111-
ctx = mkcontextdynamic(obj, block, ndrange, iterspace)
108+
ctx = mkcontext(obj, block, ndrange, iterspace, Val(dynamic))
112109
Cassette.overdub(ctx, obj.f, args...)
113-
# else
114-
# ctx = mkcontext(obj, blocks, ndrange, iterspace)
115-
# Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
116110
end
117111
return nothing
118112
end
119113

120114
Cassette.@context CPUCtx
121115

122-
function mkcontext(kernel::Kernel{CPU}, I, _ndrange, iterspace)
123-
metadata = CompilerMetadata{ndrange(kernel), false}(I, _ndrange, iterspace)
124-
Cassette.disablehooks(CPUCtx(pass = CompilerPass, metadata=metadata))
125-
end
126-
127-
function mkcontextdynamic(kernel::Kernel{CPU}, I, _ndrange, iterspace)
128-
metadata = CompilerMetadata{ndrange(kernel), true}(I, _ndrange, iterspace)
116+
function mkcontext(kernel::Kernel{CPU}, I, _ndrange, iterspace, ::Val{dynamic}) where dynamic
117+
metadata = CompilerMetadata{ndrange(kernel), dynamic}(I, _ndrange, iterspace)
129118
Cassette.disablehooks(CPUCtx(pass = CompilerPass, metadata=metadata))
130119
end
131120

src/macros.jl

Lines changed: 14 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -57,20 +57,23 @@ end
5757
# The easy case, transform the function for GPU execution
5858
# - mark constant arguments by applying `constify`.
5959
function transform_gpu!(def, constargs)
60-
new_stmts = Expr[]
60+
let_constargs = Expr[]
6161
for (i, arg) in enumerate(def[:args])
6262
if constargs[i]
63-
push!(new_stmts, :($arg = $constify($arg)))
63+
push!(let_constargs, :($arg = $constify($arg)))
6464
end
6565
end
6666

67-
def[:body] = quote
67+
body = quote
6868
if $__validindex()
69-
$(new_stmts...)
7069
$(def[:body])
7170
end
7271
return nothing
7372
end
73+
def[:body] = Expr(:let,
74+
Expr(:block, let_constargs...),
75+
body,
76+
)
7477
end
7578

7679
# The hard case, transform the function for CPU execution
@@ -81,19 +84,22 @@ end
8184
# - hoist workgroup definitions
8285
# - hoist uniform variables
8386
function transform_cpu!(def, constargs)
84-
new_stmts = Expr[]
87+
let_constargs = Expr[]
8588
for (i, arg) in enumerate(def[:args])
8689
if constargs[i]
87-
push!(new_stmts, :($arg = $constify($arg)))
90+
push!(let_constargs, :($arg = $constify($arg)))
8891
end
8992
end
90-
93+
new_stmts = Expr[]
9194
body = MacroTools.flatten(def[:body])
9295
push!(new_stmts, Expr(:aliasscope))
9396
append!(new_stmts, split(body.args))
9497
push!(new_stmts, Expr(:popaliasscope))
9598
push!(new_stmts, :(return nothing))
96-
def[:body] = Expr(:block, new_stmts...)
99+
def[:body] = Expr(:let,
100+
Expr(:block, let_constargs...),
101+
Expr(:block, new_stmts...)
102+
)
97103
end
98104

99105
struct WorkgroupLoop

test/test.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -121,7 +121,7 @@ end
121121
let kernel = constarg(CPU(), 8, (1024,))
122122
# this is poking at internals
123123
iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}();
124-
ctx = KernelAbstractions.mkcontext(kernel, 1, nothing, iterspace)
124+
ctx = KernelAbstractions.mkcontext(kernel, 1, nothing, iterspace, Val(false))
125125
AT = Array{Float32, 2}
126126
IR = sprint() do io
127127
code_llvm(io, KernelAbstractions.Cassette.overdub,

0 commit comments

Comments
 (0)