Things compile twice... why? And how to track cache hits/misses? #27186
TL;DR: I had assumed that
Situation
I was trying to understand which functions were being jitted, and how often, in a situation like:
In my case, an iterative algorithm can either jit the iteration steps or not, depending on an argument (skipping jit is helpful for debugging; jit is faster on real problems). I solved a performance issue by looking at the debug logs:
Nevertheless, I was surprised to find that, if related logs from my project
Relatedly, multiple compilations of the same function were yielding exactly two distinct Python objects. Here's a MWE:
MWE
from typing import Callable

from jax import jit

def foo():
    return 1

def do_foo(func: Callable, use_jit=True):
    if use_jit:
        func = jit(func)
    func()
    print(id(func))

do_foo(foo)
do_foo(foo)
do_foo(foo)
do_foo(foo)
print(jit(foo)._cache_size())
126667197689392
HOWEVER, this changed when I changed information around the call to
def do_foo(func: Callable, use_jit=True):
    if use_jit:
        func = jit(func)
    func()
    print(id(func))
    return func
Running
What's going on here?
Now there are four objects, one of which existed in the previous invocation, for five objects total (so it's not one per device, of which I have 4). I had previously thought
Replies: 1 comment 3 replies
The JIT cache is in play here, and as you've seen you end up with only one compilation (which can be tracked via `func._cache_size()`). You're correct that in both cases you've created four distinct Python function objects with different object ids, but each of these, when called, will execute the same cached XLA operation and thus avoid recompilation. The JIT cache is tied to the id of the original, wrapped function, not the functions returned by `jit`. I hope that's clear!
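For anyone who wants to check this on their own machine, here is a minimal sketch (not from the original thread) of one way to see how often `foo` is actually traced, independent of how many wrapper objects exist. The `trace_count` counter and the `wrappers` list are hypothetical names added for illustration. The sketch relies on the fact that Python side effects inside a jitted function run only during tracing, and on the private `_cache_size()` method mentioned above, which may change between JAX versions.

```python
from typing import Callable

from jax import jit

trace_count = 0  # hypothetical counter, not part of the original MWE


def foo():
    # Python side effects run only while JAX traces the function,
    # so this counts traces rather than calls.
    global trace_count
    trace_count += 1
    return 1


def do_foo(func: Callable, use_jit: bool = True) -> Callable:
    if use_jit:
        func = jit(func)
    func()
    return func


# Keeping the wrappers alive guarantees four distinct ids. If each wrapper
# is dropped after its call, CPython is free to reuse the same address,
# so printed ids can repeat.
wrappers = [do_foo(foo) for _ in range(4)]
distinct_ids = len({id(w) for w in wrappers})

print("distinct wrapper objects:", distinct_ids)
print("times foo was traced:", trace_count)
print("cache size:", wrappers[0]._cache_size())
```

If the trace count stays below the number of wrappers, the later wrappers reused the earlier work. To see compilations in the logs instead, `jax.config.update("jax_log_compiles", True)` should make JAX log a message each time it compiles a function.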