Fix cache leaks for pe._cached_abstract_eval; add util.multi_weakref_lru_cache #30009
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
The function
pe._cached_abstract_eval
usesutil.cache
, while most other cached functions in JAX useweakref_lru_cache
. The main difference is thatutil.cache
keeps strong references to the function arguments.The modified test
lax_control_flow_test::test_cond_memory_leak
(added a jit for one of the branches) is failing without this fix. This is because the Jaxpr including the closed-over constant leaks due to the strong references to**params
kept by theutil.cache
used inpe._cached_abstract_eval
.We cannot use directly
weakref_lru_cache
because it keeps weak references only to the first positional argument. We add here a variantmulti_weakref_lru_cache
that uses weak references for all the positional and keyword arguments for whichutil.is_weakref_cache_key_type
is true. This is currently set forJaxpr
,ClosedJaxpr
andCallable
.The
multi_weakref_lru_cache
is a wrapper aroundweakref_lru_cache
, and in fact if there is exactly oneweakref, it behaves exactly like
weakref_lru_cache
. Eventually, we can decide to generalize the existingweakref_lru_cache
to have this behavior.