Skip to content

Fix cache leaks for pe._cached_abstract_eval; add util.multi_weakref_lru_cache #30009

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented Jul 7, 2025

The function pe._cached_abstract_eval uses util.cache, while most other cached functions in JAX use weakref_lru_cache. The main difference is that util.cache keeps strong references to the function arguments.

The modified test lax_control_flow_test::test_cond_memory_leak (added a jit for one of the branches) is failing without this fix. This is because the Jaxpr including the closed-over constant leaks due to the strong references to **params kept by the util.cache used in pe._cached_abstract_eval.

We cannot use directly weakref_lru_cache because it keeps weak references only to the first positional argument. We add here a variant multi_weakref_lru_cache that uses weak references for all the positional and keyword arguments for which util.is_weakref_cache_key_type is true. This is currently set for Jaxpr, ClosedJaxpr and Callable.

The multi_weakref_lru_cache is a wrapper around weakref_lru_cache, and in fact if there is exactly one
weakref, it behaves exactly like weakref_lru_cache. Eventually, we can decide to generalize the existing weakref_lru_cache to have this behavior.

@gnecula gnecula self-assigned this Jul 7, 2025
@gnecula gnecula force-pushed the weak_lu_cache branch 5 times, most recently from b8db4fb to 1b5accf Compare July 8, 2025 01:40
@gnecula gnecula requested a review from pschuh July 8, 2025 01:40
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jul 8, 2025
@gnecula gnecula force-pushed the weak_lu_cache branch 2 times, most recently from af081d1 to efc9edd Compare July 8, 2025 04:42
The function `pe._cached_abstract_eval` uses `util.cache`,
while most other cached functions in JAX use `weakref_lru_cache`.
The main difference is that `util.cache` keeps strong references to the function
arguments.

The modified test `lax_control_flow_test::test_cond_memory_leak` (added
a jit for one of the branches) is failing without this fix.
This is because the Jaxpr including the closed-over constant
leaks due to the strong references kept by the `util.cache` used in
`pe._cached_abstract_eval`.

We cannot use directly `weakref_lru_cache` because it keeps
weak references only to the first positional argument. We add here
a variant `multi_weakref_lru_cache` that uses weak references
for all the positional and keyword arguments for which
`util.is_weakref_cache_key_type` is true.
This is set for `Jaxpr`, `ClosedJaxpr` and `Callable`.
Eventually, we can decide to generalize the existing
`weakref_lru_cache` to have this behavior.
@gnecula gnecula requested a review from mwhittaker July 9, 2025 17:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
kokoro:force-run pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants