Skip to content

Commit c4801a6

Browse files
Merge pull request #12 from c42f/cjf/global-lookup-fix
Ensure globals are looked up in the user's module
2 parents bbfaf60 + 3ab1da5 commit c4801a6

File tree

4 files changed

+58
-4
lines changed

4 files changed

+58
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RuntimeGeneratedFunctions"
22
uuid = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
33
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com> and contributors"]
4-
version = "0.3.2"
4+
version = "0.4.0"
55

66
[deps]
77
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"

src/RuntimeGeneratedFunctions.jl

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,13 @@ then calling the resulting function. The differences are:
3232
* The result is not a named generic function, and doesn't participate in
3333
generic function dispatch; it's more like a callable method.
3434
35+
You need to use `RuntimeGeneratedFunctions.init(your_module)` a single time at
36+
the top level of `your_module` before any other uses of the macro.
37+
3538
# Examples
3639
```
40+
RuntimeGeneratedFunctions.init(@__MODULE__) # Required at module top-level
41+
3742
function foo()
3843
expression = :((x,y)->x+y+1) # May be generated dynamically
3944
f = @RuntimeGeneratedFunction(expression)
@@ -42,8 +47,11 @@ end
4247
```
4348
"""
4449
macro RuntimeGeneratedFunction(ex)
45-
_ensure_cache_exists!(__module__)
4650
quote
51+
if !($(esc(:(@isdefined($_tagname)))))
52+
error("""You must use `RuntimeGeneratedFunctions.init(@__MODULE__)` at module
53+
top level before using runtime generated functions""")
54+
end
4755
RuntimeGeneratedFunction(
4856
$(esc(_tagname)),
4957
$(esc(ex))
@@ -59,7 +67,11 @@ end
5967

6068
(f::RuntimeGeneratedFunction)(args::Vararg{Any,N}) where N = generated_callfunc(f, args...)
6169

62-
@inline @generated function generated_callfunc(f::RuntimeGeneratedFunction{moduletag, id, argnames}, __args...) where {moduletag,id,argnames}
70+
# We'll generate a method of this function in every module which wants to use
71+
# @RuntimeGeneratedFunction
72+
function generated_callfunc end
73+
74+
function generated_callfunc_body(moduletag, id, argnames, __args)
6375
setup = (:($(argnames[i]) = @inbounds __args[$i]) for i in 1:length(argnames))
6476
body = _lookup_body(moduletag, id)
6577
@assert body !== nothing
@@ -122,13 +134,34 @@ function _lookup_body(moduletag, id)
122134
end
123135
end
124136

125-
function _ensure_cache_exists!(mod)
137+
"""
138+
RuntimeGeneratedFunctions.init(mod)
139+
140+
Use this at top level to set up your module `mod` before using
141+
`@RuntimeGeneratedFunction`.
142+
"""
143+
function init(mod)
126144
lock(_cache_lock) do
127145
if !isdefined(mod, _cachename)
128146
mod.eval(quote
129147
const $_cachename = Dict()
130148
struct $_tagname
131149
end
150+
151+
# We create method of `generated_callfunc` in the user's module
152+
# so that any global symbols within the body will be looked up
153+
# in the user's module scope.
154+
#
155+
# This is straightforward but clunky. A neater solution should
156+
# be to explicitly expand in the user's module and return a
157+
# CodeInfo from `generated_callfunc`, but it seems we'd need
158+
# `jl_expand_and_resolve` which doesn't exist until Julia 1.3
159+
# or so. See:
160+
# https://github.com/JuliaLang/julia/pull/32902
161+
# https://github.com/NHDaly/StagedFunctions.jl/blob/master/src/StagedFunctions.jl#L30
162+
@inline @generated function $RuntimeGeneratedFunctions.generated_callfunc(f::$RuntimeGeneratedFunctions.RuntimeGeneratedFunction{$_tagname, id, argnames}, __args...) where {id,argnames}
163+
$RuntimeGeneratedFunctions.generated_callfunc_body($_tagname, id, argnames, __args)
164+
end
132165
end)
133166
end
134167
end

test/precomp/RGFPrecompTest.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module RGFPrecompTest
22
using RuntimeGeneratedFunctions
3+
RuntimeGeneratedFunctions.init(@__MODULE__)
34

45
f = @RuntimeGeneratedFunction(:((x,y)->x+y))
56
end

test/runtests.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using RuntimeGeneratedFunctions, BenchmarkTools
22
using Test
33

4+
RuntimeGeneratedFunctions.init(@__MODULE__)
5+
46
function f(_du,_u,_p,_t)
57
@inbounds _du[1] = _u[1]
68
@inbounds _du[2] = _u[2]
@@ -107,3 +109,21 @@ for k=1:4
107109
end
108110
@test all(all.(fetch.(tasks)))
109111

112+
113+
# Test that globals are resolved within the correct scope
114+
115+
module GlobalsTest
116+
using RuntimeGeneratedFunctions
117+
RuntimeGeneratedFunctions.init(@__MODULE__)
118+
119+
y = 40
120+
f = @RuntimeGeneratedFunction(:(x->x+y))
121+
end
122+
123+
@test GlobalsTest.f(2) == 42
124+
125+
@test_throws ErrorException @eval(module NotInitTest
126+
using RuntimeGeneratedFunctions
127+
# RuntimeGeneratedFunctions.init(@__MODULE__) # <-- missing
128+
f = @RuntimeGeneratedFunction(:(x->x+y))
129+
end)

0 commit comments

Comments
 (0)