Skip to content

Commit 02caa17

Browse files
authored
Mark ccalls as gc safe (#605)
1 parent dbad788 commit 02caa17

File tree

8 files changed

+199
-125
lines changed

8 files changed

+199
-125
lines changed

gen/hip/generator.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# NOTE: for this to work, in /opt/rocm/include/hip/hip_runtime_api.h
2+
# add the following:
3+
# #define __HIP_PLATFORM_AMD__
4+
# right before:
5+
# #if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__)
6+
7+
using Clang.Generators
8+
9+
include_dir = normpath("/opt/rocm/include")
10+
hip_dir = joinpath(include_dir, "hip")
11+
options = load_options("hip/hip-generator.toml")
12+
13+
args = get_default_args()
14+
push!(args, "-I$include_dir")
15+
push!(args, "-I$hip_dir")
16+
17+
headers = [
18+
joinpath(hip_dir, header)
19+
for header in readdir(hip_dir)
20+
if header == "hip_runtime_api.h"
21+
]
22+
23+
ctx = create_context(headers, args, options)
24+
build!(ctx)

gen/hip/hip-generator.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[general]
2+
library_name = "libhip"
3+
output_file_path = "./libhip.jl"
4+
5+
[codegen]
6+
use_ccall_macro = true
7+
always_NUL_terminated_string = true

src/hip/HIP.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import PrettyTables
77
import ..AMDGPU
88
import ..AMDGPU.libhip
99

10+
include("call.jl")
1011
include("libhip_common.jl")
1112
include("error.jl")
1213
include("libhip.jl")
@@ -35,12 +36,7 @@ function HIPContext(device::HIPDevice)
3536
context_ref = Ref{hipContext_t}()
3637
hipCtxCreate(context_ref, Cuint(0), device.device) |> check
3738
context = HIPContext(context_ref[], true)
38-
3939
device!(device)
40-
finalizer(context) do c
41-
c.valid = false
42-
hipCtxDestroy(c.context) |> check
43-
end
4440
return context
4541
end
4642
end

src/hip/call.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
## version of ccall that calls jl_gc_safe_enter|leave around the inner ccall
2+
3+
# TODO: replace with JuliaLang/julia#49933 once merged
4+
5+
# note that this is generally only safe with functions that do not call back into Julia.
6+
# when callbacks occur, the code should ensure the GC is not running by wrapping the code
7+
# in the `@gcunsafe` macro
8+
9+
function ccall_macro_lower(func, rettype, types, args, nreq)
10+
# instead of re-using ccall or Expr(:foreigncall) to perform argument conversion,
11+
# we need to do so ourselves in order to insert a jl_gc_safe_enter|leave
12+
# just around the inner ccall
13+
14+
cconvert_exprs = []
15+
cconvert_args = []
16+
for (typ, arg) in zip(types, args)
17+
var = gensym("$(func)_cconvert")
18+
push!(cconvert_args, var)
19+
push!(cconvert_exprs, quote
20+
$var = Base.cconvert($(esc(typ)), $(esc(arg)))
21+
end)
22+
end
23+
24+
unsafe_convert_exprs = []
25+
unsafe_convert_args = []
26+
for (typ, arg) in zip(types, cconvert_args)
27+
var = gensym("$(func)_unsafe_convert")
28+
push!(unsafe_convert_args, var)
29+
push!(unsafe_convert_exprs, quote
30+
$var = Base.unsafe_convert($(esc(typ)), $arg)
31+
end)
32+
end
33+
34+
call = quote
35+
$(unsafe_convert_exprs...)
36+
37+
gc_state = @ccall(jl_gc_safe_enter()::Int8)
38+
ret = ccall($(esc(func)), $(esc(rettype)), $(Expr(:tuple, map(esc, types)...)),
39+
$(unsafe_convert_args...))
40+
@ccall(jl_gc_safe_leave(gc_state::Int8)::Cvoid)
41+
ret
42+
end
43+
44+
quote
45+
$(cconvert_exprs...)
46+
47+
GC.@preserve $(cconvert_args...) $(call)
48+
end
49+
end
50+
51+
"""
52+
@gcsafe_ccall ...
53+
54+
Call a foreign function just like `@ccall`, but marking it safe for the GC to run. This is
55+
useful for functions that may block, so that the GC isn't blocked from running, but may also
56+
be required to prevent deadlocks (see JuliaGPU/CUDA.jl#2261).
57+
58+
Note that this is generally only safe with non-Julia C functions that do not call back into
59+
Julia. When using callbacks, the code should make sure to transition back into GC-unsafe
60+
mode using the `@gcunsafe` macro.
61+
"""
62+
macro gcsafe_ccall(expr)
63+
ccall_macro_lower(Base.ccall_macro_parse(expr)...)
64+
end

0 commit comments

Comments
 (0)