Skip to content

Commit 1b5accf

Browse files
committed
Fix cache leaks for pe._cached_abstract_eval
The function `pe._cached_abstract_eval` uses `util.cache`, while most other cached functions in JAX use `weakref_lru_cache`. The main difference is that `util.cache` keeps strong references to the function arguments. The modified test `lax_control_flow_test::test_cond_memory_leak` (added a jit for one of the branches) is failing without this fix. This is because the Jaxpr including the closed-over constant leaks due to the strong references kept by the `util.cache` used in `pe._cached_abstract_eval`. We cannot use directly `weakref_lru_cache` because it keeps weak references only to the first positional argument. We add here a variant `multi_weakref_lru_cache` that uses weak references for all the positional and keyword arguments for which `util.is_weakref_cache_key_type` is true. This is set for `Jaxpr`, `ClosedJaxpr` and `Callable`. Eventually, we can decide to generalize the existing `weakref_lru_cache` to have this behavior.
1 parent 2c5bfac commit 1b5accf

File tree

6 files changed

+257
-13
lines changed

6 files changed

+257
-13
lines changed

jax/_src/core.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@
5151
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
5252
tuple_delete, cache,
5353
HashableFunction, HashableWrapper, weakref_lru_cache,
54-
partition_list, StrictABCMeta, foreach)
54+
partition_list, StrictABCMeta, foreach,
55+
weakref_cache_key_types)
5556
import jax._src.pretty_printer as pp
5657
from jax._src.named_sharding import NamedSharding
5758
from jax._src.lib import jax_jit
@@ -196,6 +197,7 @@ def replace(self, **kwargs):
196197
raise ValueError(f"Unknown keyword arguments: {kwargs}")
197198
return jaxpr
198199

200+
weakref_cache_key_types.add(Jaxpr)
199201

200202
def join_effects(*effects: Effects) -> Effects:
201203
return set().union(*effects) if effects else no_effects
@@ -288,6 +290,9 @@ def pretty_print(self, *, source_info=False, print_shapes=True,
288290
def _repr_pretty_(self, p, cycle):
289291
return p.text(self.pretty_print(use_color=True))
290292

293+
weakref_cache_key_types.add(ClosedJaxpr)
294+
295+
291296
@curry
292297
def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
293298
# TODO(dougalm): remove this hack when we add contexts to jaxpr.

jax/_src/dispatch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ def apply_primitive(prim, *args, **params):
9393
lib.jax_jit.swap_thread_local_state_disable_jit(prev)
9494
return outs
9595

96+
# TODO(necula): this cache will contain strong references to all
97+
# Jaxprs on which it is used. This is not immediately fixable by using
98+
# util.multi_weakref_lru_cache, because the `params` (including the Jaxpr)
99+
# are closed over in the `prim_fun` lambda. Leaving this fix for a later PR.
96100
@util.cache()
97101
def xla_primitive_callable(prim: core.Primitive, **params):
98102
util.test_event("xla_primitive_callable_cache_miss")

jax/_src/interpreters/partial_eval.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@
4747
from jax._src.tree_util import PyTreeDef, treedef_tuple, register_static
4848
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
4949
merge_lists, partition_list, OrderedSet,
50-
as_hashable_function, weakref_lru_cache, subs_list,
51-
HashableFunction, foreach, cache)
50+
as_hashable_function, weakref_lru_cache,
51+
multi_weakref_lru_cache, subs_list,
52+
HashableFunction, foreach)
5253

5354

5455
map, unsafe_map = safe_map, map
@@ -1627,7 +1628,7 @@ def move_outvars_to_back(jaxpr: ClosedJaxpr, to_move: Sequence[bool]) -> ClosedJ
16271628
return _move_outvars_to_back(jaxpr, tuple(to_move))
16281629

16291630
@weakref_lru_cache
1630-
def _move_outvars_to_back(jaxpr, to_move):
1631+
def _move_outvars_to_back(jaxpr: core.ClosedJaxpr, to_move):
16311632
new_outvars = ([e for e, m in zip(jaxpr.jaxpr.outvars, to_move) if not m] +
16321633
[e for e, m in zip(jaxpr.jaxpr.outvars, to_move) if m])
16331634
return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(outvars=new_outvars))
@@ -1879,7 +1880,7 @@ def vars(atom: Atom) -> list[Var]:
18791880
return constvars, constvals
18801881

18811882

1882-
@cache()
1883+
@multi_weakref_lru_cache
18831884
def _cached_abstract_eval(primitive: core.Primitive, *aval_qdds, **params):
18841885
return primitive.abstract_eval(*aval_qdds, **params)
18851886

@@ -2729,7 +2730,7 @@ def _linearize_of_pmap_hack(f: lu.WrappedFun, jaxpr, consts) -> tuple[Jaxpr, lis
27292730

27302731

27312732
@weakref_lru_cache
2732-
def lower_jaxpr(hi_jaxpr):
2733+
def lower_jaxpr(hi_jaxpr: core.ClosedJaxpr):
27332734
lo_avals = [lo_ty for aval in hi_jaxpr.in_aval_qdds for lo_ty in aval.lo_ty()]
27342735
f = lu.wrap_init(partial(lower_traceable, hi_jaxpr),
27352736
debug_info=hi_jaxpr.jaxpr.debug_info)

jax/_src/util.py

Lines changed: 130 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616

1717
import abc
1818
from collections.abc import Callable, Iterable, Iterator, Sequence
19+
import dataclasses
1920
import functools
2021
from functools import partial
2122
import itertools as it
2223
import logging
2324
import math
2425
import operator
25-
from typing import (Any, Generic, SupportsIndex, TypeVar, overload, TYPE_CHECKING, cast)
26+
from typing import (Any, Generic, SupportsIndex, Type, TypeVar, overload, TYPE_CHECKING, cast)
2627
import weakref
2728

2829
import numpy as np
@@ -331,8 +332,8 @@ def weakref_lru_cache(call: Callable, maxsize=2048,
331332
Least recently used cache decorator with weakref support.
332333
333334
The cache will take a weakref to the first argument of the wrapped function
334-
and strong refs to all subsequent operations. In all other respects it should
335-
behave similar to `functools.lru_cache`.
335+
and strong refs to all other arguments. In all other respects it should
336+
behave similar to `functools.lru_cache`. The cache is thread local.
336337
"""
337338
cached_call = _weakref_lru_cache.weakref_lru_cache(
338339
config.trace_context if trace_context_in_key else _ignore, call, maxsize
@@ -341,6 +342,132 @@ def weakref_lru_cache(call: Callable, maxsize=2048,
341342
return cached_call
342343

343344

345+
@dataclasses.dataclass(frozen=True, slots=True, weakref_slot=True)
346+
class MultiWeakRefCacheKey:
347+
weakrefs: tuple[weakref.ref, ...] # Used only when len(weakrefs) >= 2
348+
349+
350+
class MultiWeakRefPlaceholder:
351+
# Stands for an arg/kwarg that was replaced with a weakref
352+
pass
353+
_multi_weakref_placeholder = MultiWeakRefPlaceholder()
354+
355+
# The types of arguments for which `multi_weakref_lru_cache` should keep
356+
# weak references.
357+
weakref_cache_key_types: set[Type] = set()
358+
def is_weakref_cache_key_type(v):
359+
return callable(v) or (type(v) in weakref_cache_key_types)
360+
361+
362+
def multi_weakref_lru_cache(
363+
call: Callable, *,
364+
maxsize=2048,
365+
trace_context_in_key: bool = True):
366+
"""
367+
Least recently used cache decorator with weakref support.
368+
369+
Similar to `weakref_lru_cache`, except that it keeps weak references
370+
to all positional and keyword arguments for which
371+
`is_weakref_cache_key_type()` is true, and strong references to
372+
other arguments. The cache entry is removed if any of the weakref
373+
arguments dies.
374+
"""
375+
# Keep strong references to the MultiWeakRefCacheKeys that resulted in
376+
# cache misses, and are cache keys. Indexed by id. Only keys with all
377+
# included weakrefs live are present.
378+
id_to_key: dict[int, MultiWeakRefCacheKey] = {}
379+
# For each `wr: weakref.ref` present in `key: MultiWeakRefCacheKey` we have
380+
# `id(key) in weakref_to_key_ids[wr]`.
381+
weakref_to_key_ids: dict[weakref.ref, set[int]] = {}
382+
383+
def remove_weakref(wr: weakref.ref):
384+
key_ids = weakref_to_key_ids.get(wr, set())
385+
for key_id in key_ids:
386+
try:
387+
del id_to_key[key_id]
388+
except KeyError:
389+
pass
390+
try:
391+
del weakref_to_key_ids[wr]
392+
except KeyError:
393+
pass
394+
395+
def weakrefs_to_sentinel(v, acc: list[Any]):
396+
if isinstance(v, tuple):
397+
return tuple(weakrefs_to_sentinel(v1, acc) for v1 in v)
398+
elif isinstance(v, dict):
399+
return {k: weakrefs_to_sentinel(v1, acc) for k, v1 in v.items()}
400+
elif is_weakref_cache_key_type(v):
401+
acc.append(v)
402+
return _multi_weakref_placeholder
403+
else:
404+
return v
405+
406+
def sentinel_to_referrents(v,
407+
it: Iterator[weakref.ref],
408+
key_id: int | None):
409+
# key_id is not None iff we use a MultiWeakRefCacheKey (>= 2 weakrefs)
410+
if isinstance(v, tuple):
411+
return tuple(sentinel_to_referrents(v1, it, key_id) for v1 in v)
412+
elif isinstance(v, dict):
413+
return {k: sentinel_to_referrents(v1, it, key_id)
414+
for k, v1 in v.items()}
415+
elif v is _multi_weakref_placeholder:
416+
wr = next(it)
417+
if key_id is not None:
418+
weakref_to_key_ids.setdefault(wr, set()).add(key_id)
419+
return wr()
420+
else:
421+
return v
422+
423+
def cache_miss(key: MultiWeakRefCacheKey | MultiWeakRefPlaceholder | Any,
424+
*args, **kwargs):
425+
if isinstance(key, MultiWeakRefCacheKey): # had at least 2 weakrefs
426+
# We know `key` is in `cached_call` cache, so store strong references
427+
key_id = id(key)
428+
id_to_key[key_id] = key
429+
orig_args, orig_kwargs = sentinel_to_referrents(
430+
(args, kwargs), iter(key.weakrefs), key_id)
431+
elif key is _multi_weakref_placeholder: # had 0 weakrefs
432+
orig_args = args
433+
orig_kwargs = kwargs
434+
else: # had 1 weakref, we had put it first as the `key`
435+
orig_args, orig_kwargs = sentinel_to_referrents(
436+
(args, kwargs), iter([weakref.ref(key)]), None)
437+
return call(*orig_args, **orig_kwargs)
438+
439+
440+
cached_call = _weakref_lru_cache.weakref_lru_cache(
441+
config.trace_context if trace_context_in_key else _ignore,
442+
cache_miss, maxsize
443+
)
444+
register_cache(cached_call, str(call))
445+
446+
@functools.wraps(call)
447+
def wrapper(*args, **kwargs):
448+
acc_weakrefs: list[Any] = []
449+
args, kwargs = weakrefs_to_sentinel((args, kwargs), acc_weakrefs)
450+
nr_weakrefs = len(acc_weakrefs)
451+
if nr_weakrefs == 0:
452+
return cached_call(_multi_weakref_placeholder, *args, **kwargs)
453+
elif nr_weakrefs == 1:
454+
# Put the single weakref first, and skip the MultiWeakRefCacheKey
455+
return cached_call(acc_weakrefs[0], *args, **kwargs)
456+
else:
457+
value_to_weakref = {v: weakref.ref(v, remove_weakref)
458+
for v in set(acc_weakrefs)}
459+
key = MultiWeakRefCacheKey(weakrefs=tuple(value_to_weakref[v]
460+
for v in acc_weakrefs))
461+
return cached_call(key, *args, **kwargs)
462+
463+
wrapper.cache_info = cached_call.cache_info
464+
wrapper.cache_clear = cached_call.cache_clear
465+
wrapper.cache_keys = cached_call.cache_keys
466+
wrapper._multi_weakref_id_to_key = id_to_key # stays alive as long as wrapper
467+
wrapper._multi_weakref_to_key_ids = weakref_to_key_ids
468+
return wrapper
469+
470+
344471
class Unhashable:
345472
__slots__ = ["val"]
346473

tests/lax_control_flow_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import collections
1717
import contextlib
18+
import gc
1819
from functools import partial
1920
import itertools
2021
import math
@@ -3162,18 +3163,21 @@ def test_cond_casting(self):
31623163
@jtu.thread_unsafe_test() # live_arrays count isn't thread-safe
31633164
def test_cond_memory_leak(self):
31643165
# https://github.com/jax-ml/jax/issues/12719
3165-
31663166
def leak():
31673167
data = jax.device_put(np.zeros((1024), dtype=np.float32) + 1)
31683168
def g():
31693169
return jax.lax.cond(
31703170
True,
3171-
lambda: data[0], # noqa: F821
3171+
jax.jit(lambda: data[0]), # noqa: F821
31723172
lambda: data[1], # noqa: F821
31733173
)
3174+
# _ = g() # TODO(necula): enable this, requires fixing leaks in the
3175+
# caching of dispatch.xla_primitive_callable.
31743176
jg = jax.jit(g)
31753177
_ = jg().block_until_ready()
3178+
jg.clear_cache()
31763179
del g, jg, data, _
3180+
gc.collect()
31773181

31783182
nbufs = lambda: len(jax.live_arrays())
31793183
base = nbufs()

tests/util_test.py

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import dataclasses
15+
import random
1516
from functools import partial
17+
import gc
1618
import operator
19+
import threading
1720

18-
from absl.testing import absltest
21+
from absl.testing import absltest, parameterized
1922
import jax
2023
from jax import api_util
2124
from jax._src import linear_util as lu
@@ -155,6 +158,106 @@ def reference_loop_generator(x):
155158
for _ in range(4097):
156159
reference_loop_generator(lambda x: x)
157160

161+
@parameterized.named_parameters(
162+
dict(weakref_count=weakref_count,
163+
testcase_name=f"_{weakref_count=}")
164+
for weakref_count in [0, 1, 3])
165+
def test_multi_weak_ref_cache(self, *, weakref_count=0):
166+
167+
class Key: # hashed by id
168+
def __init__(self, x):
169+
self.x = x
170+
171+
if weakref_count > 0:
172+
util.weakref_cache_key_types.add(Key)
173+
174+
@partial(util.multi_weakref_lru_cache, trace_context_in_key=False)
175+
def myfun(a, k1, *, k2, k3):
176+
return f"{a=}, {k1=}, {k2=}, {k3=}"
177+
178+
def check_invariant(expected_live_keys: int):
179+
self.assertLen(myfun._multi_weakref_id_to_key, expected_live_keys if weakref_count > 1 else 0)
180+
for key_id, key in myfun._multi_weakref_id_to_key.items():
181+
for wr in key.weakrefs:
182+
self.assertIn(wr, myfun._multi_weakref_to_key_ids)
183+
self.assertIn(key_id, myfun._multi_weakref_to_key_ids[wr])
184+
185+
k1 = Key(1)
186+
k3 = (k1, k1) if weakref_count > 1 else 4
187+
util.clear_all_caches()
188+
r1 = myfun(2, k1, k2=3, k3=k3) # miss
189+
c1 = myfun.cache_info()
190+
self.assertEqual((0, 1, 1), (c1.hits, c1.misses, c1.currsize))
191+
check_invariant(1)
192+
193+
for i in range(10):
194+
r2 = myfun(2, k1, k2=3, k3=k3) # all hits
195+
self.assertIs(r1, r2)
196+
c2 = myfun.cache_info()
197+
self.assertEqual((1 + i, 1, 1), (c2.hits, c2.misses, c2.currsize))
198+
check_invariant(1)
199+
200+
del k1, k3 # expect that the cache entries are removed (if weakref_count > 0)
201+
gc.collect()
202+
c3 = myfun.cache_info()
203+
self.assertEqual(c3.currsize, 0 if weakref_count > 0 else 1)
204+
check_invariant(0)
205+
206+
k1_2 = Key(2)
207+
k3_2 = (Key(3), Key(3)) if weakref_count > 1 else (3, 3)
208+
r4 = myfun(2, k1_2, k2=3, k3=k3_2) # miss
209+
c4 = myfun.cache_info()
210+
self.assertEqual((10, 2, (1 if weakref_count > 0 else 2)), (c4.hits, c4.misses, c4.currsize))
211+
check_invariant(1)
212+
213+
if weakref_count > 1:
214+
del k3_2 # clear the cache entry
215+
gc.collect()
216+
c5 = myfun.cache_info()
217+
self.assertEqual((10, 2, 0), (c5.hits, c5.misses, c5.currsize))
218+
check_invariant(0)
219+
220+
k3_3 = (Key(3), Key(3))
221+
r6 = myfun(2, k1_2, k2=3, k3=k3_3) # miss because Key hashed by it
222+
self.assertIsNot(r4, r6)
223+
c6 = myfun.cache_info()
224+
self.assertEqual((10, 3, 1), (c6.hits, c6.misses, c6.currsize))
225+
check_invariant(1)
226+
227+
del k1_2
228+
gc.collect()
229+
c7 = myfun.cache_info()
230+
self.assertEqual(0 if weakref_count > 0 else 2, c7.currsize )
231+
check_invariant(0)
232+
233+
def test_multi_weakref_lru_cache_threads(self):
234+
num_workers = 5
235+
num_live_keys_per_worker = 16
236+
size_key_space = 32
237+
@dataclasses.dataclass(frozen=True)
238+
class WRKey:
239+
f: int
240+
241+
util.weakref_cache_key_types.add(WRKey)
242+
243+
@partial(util.multi_weakref_lru_cache, maxsize=size_key_space // 2)
244+
def myfun(k: WRKey):
245+
return None
246+
247+
def Worker():
248+
keys = [None] * num_live_keys_per_worker # These are the live keys for this worker
249+
for i in range(1000):
250+
key_idx = random.randint(0, num_live_keys_per_worker - 1)
251+
key = WRKey(random.randint(0, size_key_space))
252+
myfun(key)
253+
keys[key_idx] = key # Kill some previous key and keep this live
254+
255+
workers = [threading.Thread(target=Worker()) for _ in range(num_workers)]
256+
for t in workers:
257+
t.start()
258+
for t in workers:
259+
t.join()
260+
158261

159262
class SafeMapTest(jtu.JaxTestCase):
160263

0 commit comments

Comments
 (0)