Skip to content

Commit a9a3be2

Browse files
committed
Generalize to arbitrary args.
1 parent 243dbac commit a9a3be2

File tree

2 files changed

+43
-10
lines changed

2 files changed

+43
-10
lines changed

src/compiler/compilation.jl

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -538,9 +538,41 @@ function generate_opaque_closure(config::CompilerConfig, src::CodeInfo,
538538
return OpaqueClosure{id, typeof(env), sig, rt}(env)
539539
end
540540

541+
# generated function `ccall`, working around the restriction that ccall type
542+
# tuples need to be literals. this relies on ccall internals...
543+
@inline @generated function generated_ccall(f::Ptr, _rettyp, _types, vals...)
544+
ex = quote end
545+
546+
rettyp = _rettyp.parameters[1]
547+
types = _types.parameters[1].parameters
548+
args = [:(vals[$i]) for i in 1:length(vals)]
549+
550+
# cconvert
551+
cconverted = [Symbol("cconverted_$i") for i in 1:length(vals)]
552+
for (dst, typ, src) in zip(cconverted, types, args)
553+
append!(ex.args, (quote
554+
$dst = Base.cconvert($typ, $src)
555+
end).args)
556+
end
557+
558+
# unsafe_convert
559+
unsafe_converted = [Symbol("unsafe_converted_$i") for i in 1:length(vals)]
560+
for (dst, typ, src) in zip(unsafe_converted, types, cconverted)
561+
append!(ex.args, (quote
562+
$dst = Base.unsafe_convert($typ, $src)
563+
end).args)
564+
end
565+
566+
call = Expr(:foreigncall, :f, rettyp, Core.svec(types...), 0,
567+
QuoteNode(:ccall), unsafe_converted..., cconverted...)
568+
push!(ex.args, call)
569+
return ex
570+
end
571+
541572
# device-side call to an opaque closure
542-
function (oc::OpaqueClosure{F})(a, b) where F
573+
function (oc::OpaqueClosure{F,E,A,R})(args...) where {F,E,A,R}
543574
ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), F)
544575
assume(ptr != C_NULL)
545-
return ccall(ptr, Int, (Int, Int), a, b)
576+
#ccall(ptr, R, (A...), args...)
577+
generated_ccall(ptr, R, A, args...)
546578
end

test/core/execution.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,21 +1120,22 @@ end
11201120

11211121
# basic closure, constructed from CodeInfo
11221122
let
1123-
ir, rettyp = only(Base.code_typed(+, (Int, Int)))
1123+
ir, rettyp = only(Base.code_typed(*, (Int, Int, Int)))
11241124
oc = CUDA.OpaqueClosure(ir)
11251125

1126-
c = CuArray([0])
1127-
a = CuArray([1])
1128-
b = CuArray([2])
1126+
d = CuArray([1])
1127+
a = CuArray([2])
1128+
b = CuArray([3])
1129+
c = CuArray([4])
11291130

1130-
function kernel(oc, c, a, b)
1131+
function kernel(oc, d, a, b, c)
11311132
i = threadIdx().x
1132-
@inbounds c[i] = oc(a[i], b[i])
1133+
@inbounds d[i] = oc(a[i], b[i], c[i])
11331134
return
11341135
end
1135-
@cuda threads=1 kernel(oc, c, a, b)
1136+
@cuda threads=1 kernel(oc, d, a, b, c)
11361137

1137-
@test Array(c)[] == 3
1138+
@test Array(d)[] == 24
11381139
end
11391140

11401141
end

0 commit comments

Comments
 (0)