@@ -75,8 +75,19 @@ struct RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id, B} <: Func
75
75
end
76
76
end
77
77
78
- function drop_expr (:: RuntimeGeneratedFunction{A, C1, C2, ID} ) where {A, C1, C2, ID}
79
- RuntimeGeneratedFunction {A, C1, C2, ID} (nothing )
78
+ function drop_expr (:: RuntimeGeneratedFunction{a, cache_tag, c, id} ) where {a, cache_tag, c,
79
+ id}
80
+ # When dropping the reference to the body from an RGF, we need to upgrade
81
+ # from a weak to a strong reference in the cache to prevent the body being
82
+ # GC'd.
83
+ lock (_cache_lock) do
84
+ cache = getfield (parentmodule (cache_tag), _cachename)
85
+ body = cache[id]
86
+ if body isa WeakRef
87
+ cache[id] = body. value
88
+ end
89
+ end
90
+ RuntimeGeneratedFunction {a, cache_tag, c, id} (nothing )
80
91
end
81
92
82
93
function _check_rgf_initialized (mods... )
@@ -119,7 +130,7 @@ function Base.show(io::IO, ::MIME"text/plain",
119
130
}
120
131
cache_mod = parentmodule (cache_tag)
121
132
context_mod = parentmodule (context_tag)
122
- func_expr = Expr (:-> , Expr (:tuple , argnames... ), f . body )
133
+ func_expr = Expr (:-> , Expr (:tuple , argnames... ), _lookup_body (cache_tag, id) )
123
134
print (io, " RuntimeGeneratedFunction(#=in $cache_mod =#, #=using $context_mod =#, " ,
124
135
repr (func_expr), " )" )
125
136
end
@@ -169,24 +180,38 @@ function _cache_body(cache_tag, id, body)
169
180
cache = getfield (parentmodule (cache_tag), _cachename)
170
181
# Caching is tricky when `id` is the same for different AST instances:
171
182
#
172
- # Tricky case #1: If a function body with the same `id` was cached
173
- # previously, we need to use that older instance of the body AST as the
174
- # canonical one rather than `body`. This ensures the lifetime of the
175
- # body in the cache will always cover the lifetime of the parent
176
- # `RuntimeGeneratedFunction`s when they share the same `id`.
177
- cached_body = haskey (cache, id) ? cache[id] : nothing
178
- cached_body = cached_body != = nothing ? cached_body : body
179
- # We cannot use WeakRef because we might drop body to make RGF GPU
180
- # compatible.
181
- cache[id] = cached_body
183
+ # 1. If a function body with the same `id` was cached previously, we need
184
+ # to use that older instance of the body AST as the canonical one
185
+ # rather than `body`. This ensures the lifetime of the body in the
186
+ # cache will always cover the lifetime of all RGFs which share the same
187
+ # `id`.
188
+ #
189
+ # 2. Unless we hold a separate reference to `cache[id].value`, the GC
190
+ # can collect it (causing it to become `nothing`). So root it in a
191
+ # local variable first.
192
+ #
193
+ cached_body = get (cache, id, nothing )
194
+ if ! isnothing (cached_body)
195
+ if cached_body isa WeakRef
196
+ # `value` may be nothing here if it was previously cached but GC'd
197
+ cached_body = cached_body. value
198
+ end
199
+ end
200
+ if isnothing (cached_body)
201
+ cached_body = body
202
+ # Use a WeakRef to allow `body` to be garbage collected. (After GC, the
203
+ # cache will still contain an empty entry with key `id`.)
204
+ cache[id] = WeakRef (cached_body)
205
+ end
182
206
return cached_body
183
207
end
184
208
end
185
209
186
210
function _lookup_body (cache_tag, id)
187
211
lock (_cache_lock) do
188
212
cache = getfield (parentmodule (cache_tag), _cachename)
189
- cache[id]
213
+ body = cache[id]
214
+ body isa WeakRef ? body. value : body
190
215
end
191
216
end
192
217
0 commit comments