Skip to content

Commit 26553a4

Browse files
committed
Generalize to arbitrary args.
1 parent 81a8143 commit 26553a4

File tree

2 files changed

+43
-10
lines changed

2 files changed

+43
-10
lines changed

src/compiler/compilation.jl

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,9 +374,41 @@ function generate_opaque_closure(config::CompilerConfig, src::CodeInfo,
374374
return OpaqueClosure{id, typeof(env), sig, rt}(env)
375375
end
376376

377+
# generated function `ccall`, working around the restriction that ccall type
378+
# tuples need to be literals. this relies on ccall internals...
379+
@inline @generated function generated_ccall(f::Ptr, _rettyp, _types, vals...)
380+
ex = quote end
381+
382+
rettyp = _rettyp.parameters[1]
383+
types = _types.parameters[1].parameters
384+
args = [:(vals[$i]) for i in 1:length(vals)]
385+
386+
# cconvert
387+
cconverted = [Symbol("cconverted_$i") for i in 1:length(vals)]
388+
for (dst, typ, src) in zip(cconverted, types, args)
389+
append!(ex.args, (quote
390+
$dst = Base.cconvert($typ, $src)
391+
end).args)
392+
end
393+
394+
# unsafe_convert
395+
unsafe_converted = [Symbol("unsafe_converted_$i") for i in 1:length(vals)]
396+
for (dst, typ, src) in zip(unsafe_converted, types, cconverted)
397+
append!(ex.args, (quote
398+
$dst = Base.unsafe_convert($typ, $src)
399+
end).args)
400+
end
401+
402+
call = Expr(:foreigncall, :f, rettyp, Core.svec(types...), 0,
403+
QuoteNode(:ccall), unsafe_converted..., cconverted...)
404+
push!(ex.args, call)
405+
return ex
406+
end
407+
377408
# device-side call to an opaque closure
378-
function (oc::OpaqueClosure{F})(a, b) where F
409+
function (oc::OpaqueClosure{F,E,A,R})(args...) where {F,E,A,R}
379410
ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), F)
380411
assume(ptr != C_NULL)
381-
return ccall(ptr, Int, (Int, Int), a, b)
412+
#ccall(ptr, R, (A...), args...)
413+
generated_ccall(ptr, R, A, args...)
382414
end

test/execution.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,21 +1110,22 @@ end
11101110

11111111
# basic closure, constructed from CodeInfo
11121112
let
1113-
ir, rettyp = only(Base.code_typed(+, (Int, Int)))
1113+
ir, rettyp = only(Base.code_typed(*, (Int, Int, Int)))
11141114
oc = CUDA.OpaqueClosure(ir)
11151115

1116-
c = CuArray([0])
1117-
a = CuArray([1])
1118-
b = CuArray([2])
1116+
d = CuArray([1])
1117+
a = CuArray([2])
1118+
b = CuArray([3])
1119+
c = CuArray([4])
11191120

1120-
function kernel(oc, c, a, b)
1121+
function kernel(oc, d, a, b, c)
11211122
i = threadIdx().x
1122-
@inbounds c[i] = oc(a[i], b[i])
1123+
@inbounds d[i] = oc(a[i], b[i], c[i])
11231124
return
11241125
end
1125-
@cuda threads=1 kernel(oc, c, a, b)
1126+
@cuda threads=1 kernel(oc, d, a, b, c)
11261127

1127-
@test Array(c)[] == 3
1128+
@test Array(d)[] == 24
11281129
end
11291130

11301131
end

0 commit comments

Comments
 (0)