Skip to content

Commit 514c506

Browse files
committed
Generalize to arbitrary args.
1 parent 71a0515 commit 514c506

File tree

2 files changed

+43
-10
lines changed

2 files changed

+43
-10
lines changed

src/compiler/compilation.jl

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -538,9 +538,41 @@ function generate_opaque_closure(config::CompilerConfig, src::CodeInfo,
538538
return OpaqueClosure{id, typeof(env), sig, rt}(env)
539539
end
540540

541+
# generated function `ccall`, working around the restriction that ccall type
542+
# tuples need to be literals. this relies on ccall internals...
543+
@inline @generated function generated_ccall(f::Ptr, _rettyp, _types, vals...)
544+
ex = quote end
545+
546+
rettyp = _rettyp.parameters[1]
547+
types = _types.parameters[1].parameters
548+
args = [:(vals[$i]) for i in 1:length(vals)]
549+
550+
# cconvert
551+
cconverted = [Symbol("cconverted_$i") for i in 1:length(vals)]
552+
for (dst, typ, src) in zip(cconverted, types, args)
553+
append!(ex.args, (quote
554+
$dst = Base.cconvert($typ, $src)
555+
end).args)
556+
end
557+
558+
# unsafe_convert
559+
unsafe_converted = [Symbol("unsafe_converted_$i") for i in 1:length(vals)]
560+
for (dst, typ, src) in zip(unsafe_converted, types, cconverted)
561+
append!(ex.args, (quote
562+
$dst = Base.unsafe_convert($typ, $src)
563+
end).args)
564+
end
565+
566+
call = Expr(:foreigncall, :f, rettyp, Core.svec(types...), 0,
567+
QuoteNode(:ccall), unsafe_converted..., cconverted...)
568+
push!(ex.args, call)
569+
return ex
570+
end
571+
541572
# device-side call to an opaque closure
542-
function (oc::OpaqueClosure{F})(a, b) where F
573+
function (oc::OpaqueClosure{F,E,A,R})(args...) where {F,E,A,R}
543574
ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), F)
544575
assume(ptr != C_NULL)
545-
return ccall(ptr, Int, (Int, Int), a, b)
576+
#ccall(ptr, R, (A...), args...)
577+
generated_ccall(ptr, R, A, args...)
546578
end

test/core/execution.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,21 +1118,22 @@ end
11181118

11191119
# basic closure, constructed from CodeInfo
11201120
let
1121-
ir, rettyp = only(Base.code_typed(+, (Int, Int)))
1121+
ir, rettyp = only(Base.code_typed(*, (Int, Int, Int)))
11221122
oc = CUDA.OpaqueClosure(ir)
11231123

1124-
c = CuArray([0])
1125-
a = CuArray([1])
1126-
b = CuArray([2])
1124+
d = CuArray([1])
1125+
a = CuArray([2])
1126+
b = CuArray([3])
1127+
c = CuArray([4])
11271128

1128-
function kernel(oc, c, a, b)
1129+
function kernel(oc, d, a, b, c)
11291130
i = threadIdx().x
1130-
@inbounds c[i] = oc(a[i], b[i])
1131+
@inbounds d[i] = oc(a[i], b[i], c[i])
11311132
return
11321133
end
1133-
@cuda threads=1 kernel(oc, c, a, b)
1134+
@cuda threads=1 kernel(oc, d, a, b, c)
11341135

1135-
@test Array(c)[] == 3
1136+
@test Array(d)[] == 24
11361137
end
11371138

11381139
end

0 commit comments

Comments
 (0)