Yes:

```python
import jax

def f(x, flag):
    return x if flag else x + 1

f_jit = jax.jit(f, static_argnums=1)
# _cache_size() is a private method and may change between JAX versions
print(f_jit._cache_size())  # 0

f_jit(1.0, True)
print(f_jit._cache_size())  # 1

f_jit(1.0, False)
print(f_jit._cache_size())  # 2

# re-wrapped function hits the same cache
f_jit_2 = jax.jit(f, static_argnums=1)
print(f_jit_2._cache_size())  # 2

# cache hits don't increase the cache size
f_jit(100.0, True)
f_jit(100.0, False)
print(f_jit._cache_size())  # 2
```
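If you would rather not depend on the private `_cache_size()` helper, a rough alternative is to count traces. This sketch relies on the fact that `jax.jit` only runs the Python body of the function while tracing it, so a Python side effect fires on a cache miss and is skipped when a cached compilation is reused. The list `trace_log` is a hypothetical name used only for illustration.

```python
import jax

trace_log = []  # one entry is appended each time f is traced

def f(x, flag):
    # Python side effect: runs only while JAX traces f.
    # When jit reuses a cached compilation, this line is not executed.
    trace_log.append(flag)
    return x if flag else x + 1

f_jit = jax.jit(f, static_argnums=1)

f_jit(1.0, True)     # cache miss: trace and compile for flag=True
f_jit(1.0, False)    # cache miss: trace and compile for flag=False
f_jit(100.0, True)   # cache hit: same static value, same abstract x
f_jit(100.0, False)  # cache hit
print(trace_log)  # [True, False]
```

Only two traces are recorded, one per distinct static value, which matches the cache size reported above.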
I was searching the internet and the documentation to see whether `jax.jit` retains a cache of compiled functions for previously used static arguments. I couldn't find any information, so I ran a quick test, and it seems that `jax.jit` does retain the cache.
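The static values are not the only part of the cache key: the shapes and dtypes of the non-static (traced) arguments matter too, so a new shape triggers a fresh trace even when the static argument is unchanged. A minimal sketch of that, again using a trace-time side effect; `g` and `traces` are hypothetical names for illustration.

```python
import jax
import jax.numpy as jnp

traces = []  # records (shape, dtype, flag) each time g is traced

def g(x, flag):
    # At trace time x is an abstract tracer, so only its shape/dtype are known.
    traces.append((x.shape, x.dtype, flag))
    return x if flag else x + 1

g_jit = jax.jit(g, static_argnums=1)

g_jit(jnp.ones(3), True)    # miss: shape (3,), flag=True
g_jit(jnp.zeros(3), True)   # hit: same shape, dtype and flag
g_jit(jnp.ones(4), True)    # miss: the shape changed
g_jit(jnp.ones(3), False)   # miss: the static flag changed
print(len(traces))  # 3
```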