Skip to content

Commit ad1a546

Browse files
committed
Add an experimental opaque closure type.
1 parent 1e5f0d6 commit ad1a546

File tree

2 files changed

+164
-0
lines changed

2 files changed

+164
-0
lines changed

src/compiler/compilation.jl

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,3 +425,121 @@ function run_and_collect(cmd)
425425

426426
return proc, log
427427
end
428+
429+
430+
431+
## opaque closures
432+
433+
# TODO: once stabilised, move bits of this into GPUCompiler.jl
434+
435+
using Core.Compiler: IRCode
436+
using Core: CodeInfo, MethodInstance, CodeInstance, LineNumberNode
437+
438+
"""
    OpaqueClosure{F,E,A,R}

Experimental opaque closure type for use in GPU code.

The type parameters describe, in order, the target function (`F`), the type of
the captured environment (`E`), the argument signature (`A`) and the return
type (`R`). The only field, `env`, stores the captured environment.
"""
struct OpaqueClosure{F,E,A,R}
    env::E
end
441+
442+
# XXX: because we can't call functions from other CUDA modules, we effectively need to
443+
# recompile when the target function changes. this, and because of how GPUCompiler's
444+
# deferred compilation mechanism currently works, is why we have `F` as a type param.
445+
446+
# XXX: because of GPU code requiring specialized signatures, we also need to recompile
447+
# when the environment or argument types change. together with the above, this
448+
# negates much of the benefit of opaque closures.
449+
450+
# TODO: support for constructing an opaque closure from source code
451+
452+
# TODO: complete support for passing an environment. this probably requires a split into
453+
# host and device structures to, e.g., root a CuArray and pass a CuDeviceArray.
454+
455+
# Compute the return type of `ir` by merging the types of every value returned
# by one of its `ReturnNode`s, then widening the result to a plain Julia type.
function compute_ir_rettype(ir::IRCode)
    merged = Union{}
    for idx in 1:length(ir.stmts)
        node = ir.stmts[idx][:inst]
        # only return statements that actually carry a value contribute
        node isa Core.Compiler.ReturnNode || continue
        isdefined(node, :val) || continue
        valtype = Core.Compiler.argextype(node.val, ir)
        merged = Core.Compiler.tmerge(valtype, merged)
    end
    return Core.Compiler.widenconst(merged)
end
465+
466+
# Build the argument signature (a `Tuple` type) of an opaque closure from `ir`.
#
# The first entry of `ir.argtypes` is skipped; the following `nargs` entries are
# widened and used as parameters. When `isva` is set, the last of those is
# treated as the vararg slot: a `Tuple` type is expanded into its parameters,
# anything else becomes `Vararg{Any}`.
function compute_oc_signature(ir::IRCode, nargs::Int, isva::Bool)
    params = Vector{Any}(undef, nargs)
    for slot in eachindex(params)
        params[slot] = Core.Compiler.widenconst(ir.argtypes[slot + 1])
    end

    # nothing more to do for a fixed-arity signature
    isva || return Tuple{params...}

    trailing = pop!(params)
    if trailing <: Tuple
        append!(params, trailing.parameters)
    else
        push!(params, Vararg{Any})
    end
    return Tuple{params...}
end
481+
482+
"""
    OpaqueClosure(ir::IRCode, env...; isva::Bool = false)

Construct an opaque closure from `ir`, capturing `env` as its environment.

The IR is copied, converted to a `CodeInfo`, and handed to
`generate_opaque_closure` together with the signature and return type computed
from it. The number of arguments is `length(ir.argtypes) - 1`; set `isva` when
the last argument is a vararg.
"""
function OpaqueClosure(ir::IRCode, @nospecialize env...; isva::Bool = false)
    # NOTE: we need ir.argtypes[1] == typeof(env)
    ir = Core.Compiler.copy(ir)
    nslots = length(ir.argtypes)
    nargs = nslots - 1
    sig = compute_oc_signature(ir, nargs, isva)
    rt = compute_ir_rettype(ir)

    # populate a fresh CodeInfo with slot metadata before lowering the IR into it
    src = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
    src.slotnames = Base.fill(:none, nslots)
    src.slotflags = Base.fill(zero(UInt8), nslots)
    src.slottypes = copy(ir.argtypes)
    src.rettype = rt
    src = Core.Compiler.ir_to_codeinf!(src, ir)

    config = compiler_config(device(); kernel=false)
    return generate_opaque_closure(config, src, sig, rt, nargs, isva, env...)
end
497+
498+
"""
    OpaqueClosure(src::CodeInfo, env...)

Construct an opaque closure from inferred source code, capturing `env` as its
environment.

The signature, argument count and vararg flag are taken from the method
instance that `src` belongs to (`src.parent`), and the return type from
`src.rettype`.

Throws an `ArgumentError` when `src` has not been inferred.
"""
function OpaqueClosure(src::CodeInfo, @nospecialize env...)
    src.inferred || throw(ArgumentError("Expected inferred src::CodeInfo"))
    mi = src.parent::Core.MethodInstance
    sig = Base.tuple_type_tail(mi.specTypes)
    method = mi.def::Method
    nargs = method.nargs-1
    isva = method.isva

    # build the compiler configuration the same way the `IRCode` constructor
    # does; `config` was previously used here without ever being defined.
    config = compiler_config(device(); kernel=false)
    return generate_opaque_closure(config, src, sig, src.rettype, nargs, isva, env...)
end

# Entry point under the original name, kept so existing callers continue to work.
function OpaqueGPUClosure(src::CodeInfo, @nospecialize env...)
    return OpaqueClosure(src, env...)
end
507+
508+
# Create an `OpaqueClosure` for the source `src`, with argument signature `sig`
# and return type `rt`, capturing `env`.
#
# A fresh method is created for `src`, a method instance is looked up for the
# full signature (environment type followed by the parameters of `sig`), and a
# code instance wrapping `src` is stored in the compiler job's cache. The job is
# then registered in `GPUCompiler.deferred_codegen_jobs`; the index it is stored
# under becomes the `F` type parameter of the returned closure.
#
# `file` and `line` are recorded on the created method.
# NOTE(review): the `mod` keyword is accepted but not used; the method is
# always created in `Main`.
function generate_opaque_closure(config::CompilerConfig, src::CodeInfo,
                                 @nospecialize(sig), @nospecialize(rt),
                                 nargs::Int, isva::Bool, @nospecialize env...;
                                 mod::Module=@__MODULE__,
                                 file::Union{Nothing,Symbol}=nothing, line::Int=0)
    # set up a new method (like `jl_make_opaque_closure_method`)
    oc_method = ccall(:jl_new_method_uninit, Ref{Method}, (Any,), Main)
    oc_method.sig = Tuple
    oc_method.isva = isva                # XXX: probably not supported?
    oc_method.is_for_opaque_closure = 0  # XXX: do we want this?
    oc_method.name = Symbol("opaque gpu closure")
    oc_method.nargs = nargs + 1
    oc_method.file = file === nothing ? Symbol() : file
    oc_method.line = line
    ccall(:jl_method_set_source, Nothing, (Any, Any), oc_method, src)

    # the environment is the first argument, followed by the closure's own arguments
    full_sig = Tuple{typeof(env), sig.parameters...}
    instance = ccall(:jl_specializations_get_linfo, Ref{MethodInstance},
                     (Any, Any, Any), oc_method, full_sig, Core.svec())
    job = CompilerJob(instance, config)  # this captures the current world age

    # wrap the source in a code instance and put it in the job's cache
    codeinst = CodeInstance(instance, rt, C_NULL, src, Int32(0),
                            oc_method.primary_world, typemax(UInt),
                            UInt32(0), UInt32(0), nothing, UInt8(0))
    Core.Compiler.setindex!(GPUCompiler.ci_cache(job), codeinst, instance)

    # register the job for deferred compilation under the next free index
    job_id = length(GPUCompiler.deferred_codegen_jobs) + 1
    GPUCompiler.deferred_codegen_jobs[job_id] = job
    return OpaqueClosure{job_id, typeof(env), sig, rt}(env)
end
539+
540+
# device-side call to an opaque closure
541+
function (oc::OpaqueClosure{F})(a, b) where {F}
    # ask the deferred codegen mechanism for the function pointer identified by `F`
    fptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), F)
    assume(fptr != C_NULL)
    # NOTE: the call is currently hard-coded to two `Int` arguments and an `Int`
    #       result; the environment is not passed along.
    return ccall(fptr, Int, (Int, Int), a, b)
end

test/core/execution.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,3 +1093,49 @@ end
10931093
end
10941094

10951095
############################################################################################
1096+
1097+
if VERSION >= v"1.10-"
@testset "opaque closures" begin

# closure constructed from the IRCode of `+(::Int, ::Int)`
let
    ir, _ = only(Base.code_ircode(+, (Int, Int)))
    oc = CUDA.OpaqueClosure(ir)

    out = CuArray([0])
    lhs = CuArray([1])
    rhs = CuArray([2])

    # each thread calls the closure on its own pair of elements
    function kernel(oc, out, lhs, rhs)
        tid = threadIdx().x
        @inbounds out[tid] = oc(lhs[tid], rhs[tid])
        return
    end
    @cuda threads=1 kernel(oc, out, lhs, rhs)

    @test Array(out)[] == 3
end

# closure constructed from the typed CodeInfo of `+(::Int, ::Int)`
let
    src, _ = only(Base.code_typed(+, (Int, Int)))
    oc = CUDA.OpaqueClosure(src)

    out = CuArray([0])
    lhs = CuArray([1])
    rhs = CuArray([2])

    # each thread calls the closure on its own pair of elements
    function kernel(oc, out, lhs, rhs)
        tid = threadIdx().x
        @inbounds out[tid] = oc(lhs[tid], rhs[tid])
        return
    end
    @cuda threads=1 kernel(oc, out, lhs, rhs)

    @test Array(out)[] == 3
end

end
end
1140+
1141+
############################################################################################

0 commit comments

Comments
 (0)