Skip to content

Commit b2867cc

Browse files
Merge pull request #60 from SciML/myb/generic_body
Make body more generic
2 parents 557a861 + 9b02f89 commit b2867cc

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

src/RuntimeGeneratedFunctions.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ end
5050
"""
5151

5252
"$_rgf_docs"
53-
struct RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id} <: Function
54-
body::Expr
53+
struct RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id, B} <: Function
54+
body::B
5555
function RuntimeGeneratedFunction(cache_tag, context_tag, ex; opaque_closures = true)
5656
def = splitdef(ex)
5757
args, body = normalize_args(def[:args]), def[:body]
@@ -61,20 +61,22 @@ struct RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id} <: Functio
6161
end
6262
id = expr_to_id(body)
6363
cached_body = _cache_body(cache_tag, id, body)
64-
new{Tuple(args), cache_tag, context_tag, id}(cached_body)
64+
new{Tuple(args), cache_tag, context_tag, id, typeof(cached_body)}(cached_body)
6565
end
6666

6767
# For internal use in deserialize() - doesen't check whether the body is in the cache!
68-
function RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id}(body::Expr) where {
68+
function RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id}(body) where {
6969
argnames,
7070
cache_tag,
7171
context_tag,
72-
id
72+
id,
7373
}
74-
new{argnames, cache_tag, context_tag, id}(body)
74+
new{argnames, cache_tag, context_tag, id, typeof(body)}(body)
7575
end
7676
end
7777

78+
drop_expr(::RuntimeGeneratedFunction{A, C1, C2, ID}) where {A, C1, C2, ID} = RuntimeGeneratedFunction{A, C1, C2, ID}(nothing)
79+
7880
function _check_rgf_initialized(mods...)
7981
for mod in mods
8082
if !isdefined(mod, _tagname)
@@ -298,8 +300,7 @@ end
298300
# We write an explicit deserialize() here to trigger caching of the body on a
299301
# remote node when using Serialialization.jl (in Distributed.jl and elsewhere)
300302
function Serialization.deserialize(s::AbstractSerializer,
301-
::Type{
302-
RuntimeGeneratedFunction{argnames, cache_tag,
303+
::Type{<:RuntimeGeneratedFunction{argnames, cache_tag,
303304
context_tag, id}}) where {
304305
argnames,
305306
cache_tag,

0 commit comments

Comments
 (0)