Static argnames vs function closure #18179

Answered by jakevdp
sh0416 asked this question in Q&A

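It doesn't matter: JAX generates the same sequence of lowered operations either way. In both cases `val` is a plain Python value at trace time, so it ends up embedded in the jaxpr as a constant. You can see this by printing the jaxpr for each approach: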

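```python
import jax
import jax.numpy as jnp
import functools

def create_f1(val: int):
  # `val` is captured by the closure; it is a plain Python int at trace time.
  @jax.jit
  def f(x: jnp.ndarray):
    return val * x
  return f

f1 = create_f1(2)

def create_f2(val: int):
  # `val` is marked static, so it is also a plain Python int at trace time.
  # Note the trailing comma: ('val',) is a tuple, ('val') is just a string.
  @functools.partial(jax.jit, static_argnames=('val',))
  def _f(x: jnp.ndarray, val: int):
    return val * x

  def f(x: jnp.ndarray):
    return _f(x, val)
  return f

f2 = create_f2(2)

x = jnp.arange(10)

print(jax.make_jaxpr(f1)(x))
# { lambda ; a:i32[10]. let
#     b:i32[10] = pjit[
#       jaxpr={ lambda ; c:i32[10]. let d:i32[10] = mul 2 c in (d,) }

print(jax.make_jaxpr(f2)(x))
```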
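The jaxprs match, but how often JAX retraces depends on where the jitted function lives. The following is a minimal sketch, assuming a recent JAX version, that counts traces with a hypothetical `trace_counts` dict: the Python body of a jitted function only runs while JAX traces it, so incrementing a counter there counts traces rather than calls. One assumption to flag: in the `create_f2` above, `_f` is redefined on every factory call, so each call gets a fresh jit cache just like the closure version. The sketch hoists `_f` to module level to show the case where a static argument lets calls with the same `val` share one compiled function.

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical helper: incremented only while JAX traces, not on cached calls.
trace_counts = {"closure": 0, "static": 0}


def create_f1(val: int):
    @jax.jit
    def f(x):
        trace_counts["closure"] += 1  # runs at trace time only
        return val * x
    return f


# Assumption: _f is hoisted out of the factory so all callers share one jit cache.
@functools.partial(jax.jit, static_argnames=("val",))
def _f(x, val: int):
    trace_counts["static"] += 1  # runs at trace time only
    return val * x


def create_f2(val: int):
    def f(x):
        return _f(x, val)
    return f


x = jnp.arange(10)

# Closure: every factory call builds a new jitted function with its own cache.
f1_a = create_f1(2)
f1_a(x)
f1_a(x)  # cache hit, no retrace
f1_b = create_f1(2)
f1_b(x)  # new jit wrapper, so it traces again

# Static argument: one shared jitted function, cache keyed on the value of `val`.
f2_a = create_f2(2)
f2_a(x)
f2_a(x)  # cache hit, no retrace
f2_b = create_f2(2)
f2_b(x)  # same static value, cache hit
f2_c = create_f2(3)
f2_c(x)  # new static value, so it retraces

print(trace_counts)  # {'closure': 2, 'static': 2}
```

In this sketch the deciding factor is whether the jitted function object is shared across calls, not closure versus static argument as such.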

Answer selected by sh0416