Skip to content

Commit 0c62908

Browse files
committed
[Orc] Implement default resolver
1 parent dd7a7ac commit 0c62908

File tree

2 files changed

+189
-11
lines changed

2 files changed

+189
-11
lines changed

src/orc.jl

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,52 @@ struct OrcModule
4242
end
4343
Base.convert(::Type{API.LLVMOrcModuleHandle}, mod::OrcModule) = mod.handle
4444

45-
function compile!(orc::OrcJIT, mod::Module, resolver = C_NULL, ctx = C_NULL; lazy=false)
45+
"""
46+
resolver(name, ctx)
47+
48+
Lookup the symbol `name`. Iff `ctx` is passed to this function it should be a
49+
pointer to the OrcJIT we are compiling for.
50+
"""
51+
function resolver(name, ctx)
52+
name = unsafe_string(name)
53+
## Step 0: Should have already resolved it iff it was in the
54+
## same module
55+
## Step 1: See if it's something known to the execution engine
56+
ptr = C_NULL
57+
if ctx != C_NULL
58+
orc = OrcJIT(ctx)
59+
ptr = pointer(address(orc, name))
60+
end
61+
62+
## Step 2: Search the program symbols
63+
if ptr == C_NULL
64+
#
65+
# SearchForAddressOfSymbol expects an unmangled 'C' symbol name.
66+
# Iff we are on Darwin, strip the leading '_' off.
67+
@static if Sys.isapple()
68+
if name[1] == '_'
69+
name = name[2:end]
70+
end
71+
end
72+
ptr = LLVM.find_symbol(name)
73+
end
74+
75+
## Step 4: Lookup in libatomic
76+
# TODO: Do we need to do this?
77+
78+
if ptr == C_NULL
79+
error("OrcJIT: Symbol `$name` lookup failed. Aborting!")
80+
end
81+
82+
return UInt64(reinterpret(UInt, ptr))
83+
end
84+
85+
function compile!(orc::OrcJIT, mod::Module, resolver = @cfunction(resolver, UInt64, (Cstring, Ptr{Cvoid})), resolver_ctx = orc; lazy=false)
4686
r_mod = Ref{API.LLVMOrcModuleHandle}()
4787
if lazy
48-
API.LLVMOrcAddLazilyCompiledIR(orc, r_mod, mod, resolver, ctx)
88+
API.LLVMOrcAddLazilyCompiledIR(orc, r_mod, mod, resolver, resolver_ctx)
4989
else
50-
API.LLVMOrcAddEagerlyCompiledIR(orc, r_mod, mod, resolver, ctx)
90+
API.LLVMOrcAddEagerlyCompiledIR(orc, r_mod, mod, resolver, resolver_ctx)
5191
end
5292
OrcModule(r_mod[])
5393
end
@@ -56,9 +96,9 @@ function Base.delete!(orc::OrcJIT, mod::OrcModule)
5696
LLVM.API.LLVMOrcRemoveModule(orc, mod)
5797
end
5898

59-
function add!(orc::OrcJIT, obj::MemoryBuffer, resolver = C_NULL, ctx = C_NULL)
99+
function add!(orc::OrcJIT, obj::MemoryBuffer, resolver = @cfunction(resolver, UInt64, (Cstring, Ptr{Cvoid})), resolver_ctx = orc)
60100
r_mod = Ref{API.LLVMOrcModuleHandle}()
61-
API.LLVMOrcAddObjectFile(orc, r_mod, obj, resolver, ctx)
101+
API.LLVMOrcAddObjectFile(orc, r_mod, obj, resolver, resolver_ctx)
62102
return OrcModule(r_mod[])
63103
end
64104

test/orc.jl

Lines changed: 144 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,44 @@
11
@testset "orc" begin
22

3-
let ctx = Context()
3+
@testset "Undefined Symbol" begin
4+
ctx = Context()
5+
tm = JITTargetMachine()
6+
orc = OrcJIT(tm)
7+
8+
mod = LLVM.Module("jit", ctx)
9+
T_Int32 = LLVM.Int32Type(ctx)
10+
ft = LLVM.FunctionType(T_Int32, [T_Int32, T_Int32])
11+
fn = LLVM.Function(mod, "mysum", ft)
12+
linkage!(fn, LLVM.API.LLVMExternalLinkage)
13+
14+
fname = mangle(orc, "wrapper")
15+
wrapper = LLVM.Function(mod, fname, ft)
16+
# generate IR
17+
Builder(ctx) do builder
18+
entry = BasicBlock(wrapper, "entry", ctx)
19+
position!(builder, entry)
20+
21+
tmp = call!(builder, fn, [parameters(wrapper)...])
22+
ret!(builder, tmp)
23+
end
24+
25+
triple!(mod, triple(tm))
26+
ModulePassManager() do pm
27+
add_library_info!(pm, triple(mod))
28+
add_transform_info!(pm, tm)
29+
run!(pm, mod)
30+
end
31+
verify(mod)
32+
33+
orc_mod = compile!(orc, mod)
34+
@test_throws ErrorException address(orc, fname)
35+
36+
delete!(orc, orc_mod)
37+
dispose(orc)
38+
end
39+
40+
@testset "Custom Resolver" begin
41+
ctx = Context()
442
tm = JITTargetMachine()
543
orc = OrcJIT(tm)
644

@@ -14,10 +52,10 @@ let ctx = Context()
1452
end
1553
fnames[name] += 1
1654

17-
return get(known_functions, name, OrcTargetAddress(C_NULL)).ptr
55+
return known_functions[name].ptr
1856
catch ex
19-
@error "Exception during lookup" exception=(ex, catch_backtrace())
20-
return UInt64(0)
57+
@error "Exception during lookup" name exception=(ex, catch_backtrace())
58+
error("OrcJIT: Could not find symbol")
2159
end
2260
end
2361

@@ -51,7 +89,7 @@ let ctx = Context()
5189

5290
f_lookup = @cfunction($lookup, UInt64, (Cstring, Ptr{Cvoid}))
5391
GC.@preserve f_lookup begin
54-
orc_mod = compile!(orc, mod, f_lookup, lazy=true) # will capture f_lookup
92+
orc_mod = compile!(orc, mod, f_lookup, C_NULL, lazy=true) # will capture f_lookup
5593

5694
addr = address(orc, fname)
5795
@test errormsg(orc) == ""
@@ -73,7 +111,101 @@ let ctx = Context()
73111
dispose(orc)
74112
end
75113

76-
let ctx = Context()
114+
@testset "Default Resolver + Stub" begin
115+
ctx = Context()
116+
tm = JITTargetMachine()
117+
orc = OrcJIT(tm)
118+
119+
mod = LLVM.Module("jit", ctx)
120+
T_Int32 = LLVM.Int32Type(ctx)
121+
ft = LLVM.FunctionType(T_Int32, [T_Int32, T_Int32])
122+
fn = LLVM.Function(mod, "mysum", ft)
123+
linkage!(fn, LLVM.API.LLVMExternalLinkage)
124+
125+
fname = mangle(orc, "wrapper")
126+
wrapper = LLVM.Function(mod, fname, ft)
127+
# generate IR
128+
Builder(ctx) do builder
129+
entry = BasicBlock(wrapper, "entry", ctx)
130+
position!(builder, entry)
131+
132+
tmp = call!(builder, fn, [parameters(wrapper)...])
133+
ret!(builder, tmp)
134+
end
135+
136+
triple!(mod, triple(tm))
137+
ModulePassManager() do pm
138+
add_library_info!(pm, triple(mod))
139+
add_transform_info!(pm, tm)
140+
run!(pm, mod)
141+
end
142+
verify(mod)
143+
144+
create_stub!(orc, mangle(orc, "mysum"), OrcTargetAddress(@cfunction(+, Int32, (Int32, Int32))))
145+
146+
orc_mod = compile!(orc, mod)
147+
148+
addr = address(orc, fname)
149+
@test errormsg(orc) == ""
150+
151+
r = ccall(pointer(addr), Int32, (Int32, Int32), 1, 2)
152+
@test r == 3
153+
154+
delete!(orc, orc_mod)
155+
dispose(orc)
156+
end
157+
158+
@testset "Default Resolver + Global Symbol" begin
159+
ctx = Context()
160+
tm = JITTargetMachine()
161+
orc = OrcJIT(tm)
162+
163+
mod = LLVM.Module("jit", ctx)
164+
T_Int32 = LLVM.Int32Type(ctx)
165+
ft = LLVM.FunctionType(T_Int32, [T_Int32, T_Int32])
166+
mysum = mangle(orc, "mysum")
167+
fn = LLVM.Function(mod, mysum, ft)
168+
linkage!(fn, LLVM.API.LLVMExternalLinkage)
169+
170+
fname = mangle(orc, "wrapper")
171+
wrapper = LLVM.Function(mod, fname, ft)
172+
# generate IR
173+
Builder(ctx) do builder
174+
entry = BasicBlock(wrapper, "entry", ctx)
175+
position!(builder, entry)
176+
177+
tmp = call!(builder, fn, [parameters(wrapper)...])
178+
ret!(builder, tmp)
179+
end
180+
181+
triple!(mod, triple(tm))
182+
ModulePassManager() do pm
183+
add_library_info!(pm, triple(mod))
184+
add_transform_info!(pm, tm)
185+
run!(pm, mod)
186+
end
187+
verify(mod)
188+
189+
# Should do pretty much the same as `@ccallable`
190+
LLVM.add_symbol(mysum, @cfunction(+, Int32, (Int32, Int32)))
191+
ptr = LLVM.find_symbol(mysum)
192+
@test ptr !== C_NULL
193+
@test ccall(ptr, Int32, (Int32, Int32), 1, 2) == 3
194+
195+
orc_mod = compile!(orc, mod, lazy=true)
196+
197+
addr = address(orc, fname)
198+
@test errormsg(orc) == ""
199+
200+
r = ccall(pointer(addr), Int32, (Int32, Int32), 1, 2)
201+
@test r == 3
202+
203+
delete!(orc, orc_mod)
204+
dispose(orc)
205+
end
206+
207+
@testset "Loading ObjectFile" begin
208+
ctx = Context()
77209
tm = JITTargetMachine()
78210
orc = OrcJIT(tm)
79211
sym = mangle(orc, "SomeFunction")
@@ -96,6 +228,12 @@ let ctx = Context()
96228

97229
@test addr.ptr != 0
98230
delete!(orc, orc_m)
231+
end
232+
233+
@testset "Stubs" begin
234+
ctx = Context()
235+
tm = JITTargetMachine()
236+
orc = OrcJIT(tm)
99237

100238
toggle = Ref{Bool}(false)
101239
on() = (toggle[] = true; nothing)

0 commit comments

Comments
 (0)