Skip to content

Commit 4b62e16

Browse files
committed
Generalize to arbitrary args.
1 parent 59ac05e commit 4b62e16

File tree

2 files changed

+43
-10
lines changed

2 files changed

+43
-10
lines changed

src/compiler/compilation.jl

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,41 @@ function generate_opaque_closure(config::CompilerConfig, src::CodeInfo,
370370
return OpaqueClosure{id, typeof(env), sig, rt}(env)
371371
end
372372

373+
# generated function `ccall`, working around the restriction that ccall type
374+
# tuples need to be literals. this relies on ccall internals...
375+
@inline @generated function generated_ccall(f::Ptr, _rettyp, _types, vals...)
376+
ex = quote end
377+
378+
rettyp = _rettyp.parameters[1]
379+
types = _types.parameters[1].parameters
380+
args = [:(vals[$i]) for i in 1:length(vals)]
381+
382+
# cconvert
383+
cconverted = [Symbol("cconverted_$i") for i in 1:length(vals)]
384+
for (dst, typ, src) in zip(cconverted, types, args)
385+
append!(ex.args, (quote
386+
$dst = Base.cconvert($typ, $src)
387+
end).args)
388+
end
389+
390+
# unsafe_convert
391+
unsafe_converted = [Symbol("unsafe_converted_$i") for i in 1:length(vals)]
392+
for (dst, typ, src) in zip(unsafe_converted, types, cconverted)
393+
append!(ex.args, (quote
394+
$dst = Base.unsafe_convert($typ, $src)
395+
end).args)
396+
end
397+
398+
call = Expr(:foreigncall, :f, rettyp, Core.svec(types...), 0,
399+
QuoteNode(:ccall), unsafe_converted..., cconverted...)
400+
push!(ex.args, call)
401+
return ex
402+
end
403+
373404
# device-side call to an opaque closure
374-
function (oc::OpaqueClosure{F})(a, b) where F
405+
function (oc::OpaqueClosure{F,E,A,R})(args...) where {F,E,A,R}
375406
ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), F)
376407
assume(ptr != C_NULL)
377-
return ccall(ptr, Int, (Int, Int), a, b)
408+
#ccall(ptr, R, (A...), args...)
409+
generated_ccall(ptr, R, A, args...)
378410
end

test/core/execution.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,21 +1110,22 @@ end
11101110

11111111
# basic closure, constructed from CodeInfo
11121112
let
1113-
ir, rettyp = only(Base.code_typed(+, (Int, Int)))
1113+
ir, rettyp = only(Base.code_typed(*, (Int, Int, Int)))
11141114
oc = CUDA.OpaqueClosure(ir)
11151115

1116-
c = CuArray([0])
1117-
a = CuArray([1])
1118-
b = CuArray([2])
1116+
d = CuArray([1])
1117+
a = CuArray([2])
1118+
b = CuArray([3])
1119+
c = CuArray([4])
11191120

1120-
function kernel(oc, c, a, b)
1121+
function kernel(oc, d, a, b, c)
11211122
i = threadIdx().x
1122-
@inbounds c[i] = oc(a[i], b[i])
1123+
@inbounds d[i] = oc(a[i], b[i], c[i])
11231124
return
11241125
end
1125-
@cuda threads=1 kernel(oc, c, a, b)
1126+
@cuda threads=1 kernel(oc, d, a, b, c)
11261127

1127-
@test Array(c)[] == 3
1128+
@test Array(d)[] == 24
11281129
end
11291130

11301131
end

0 commit comments

Comments
 (0)