I am wondering which approach is better for implementing a jitted function. Which one would an experienced JAX developer prefer?
Answered by jakevdp on Oct 19, 2023
It doesn't matter: JAX generates the same sequence of lowered operations either way. You can see this by generating the jaxpr for each approach:

```python
import functools

import jax
import jax.numpy as jnp


def create_f1(val: int):
    # Approach 1: capture `val` in a closure around a jitted function.
    @jax.jit
    def f(x: jnp.ndarray):
        return val * x
    return f

f1 = create_f1(2)


def create_f2(val: int):
    # Approach 2: pass `val` to the jitted function as a static argument.
    # ('val',) is a tuple; ('val') would just be the string 'val' (which jit also accepts).
    @functools.partial(jax.jit, static_argnames=('val',))
    def _f(x: jnp.ndarray, val: int):
        return val * x

    def f(x: jnp.ndarray):
        return _f(x, val)
    return f

f2 = create_f2(2)

x = jnp.arange(10)

print(jax.make_jaxpr(f1)(x))
# { lambda ; a:i32[10]. let
#     b:i32[10] = pjit[
#       jaxpr={ lambda ; c:i32[10]. let d:i32[10] = mul 2 c in (d,) }
#       name=f
#     ] a
#   in (b,) }

print(jax.make_jaxpr(f2)(x))
# { lambda ; a:i32[10]. let
#     b:i32[10] = pjit[
#       jaxpr={ lambda ; c:i32[10]. let d:i32[10] = mul 2 c in (d,) }
#       name=_f
#     ] a
#   in (b,) }
```

In both cases `val` is treated as a trace-time constant (the literal `2` in `mul 2 c`); the only difference is the name of the jitted function.
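If you prefer to check this programmatically rather than by eye, here is a minimal sketch, assuming a recent JAX release. `inner_primitives` is a hypothetical helper written for this example: it pulls the primitive names out of the single jitted call in each jaxpr so they can be compared, and the script also confirms that the two functions return identical results.

```python
import functools

import jax
import jax.numpy as jnp


def create_f1(val: int):
    # Closure approach: `val` is baked into the trace as a constant.
    @jax.jit
    def f(x):
        return val * x
    return f


def create_f2(val: int):
    # Static-argument approach: `val` is also a trace-time constant.
    @functools.partial(jax.jit, static_argnames=("val",))
    def _f(x, val: int):
        return val * x

    def f(x):
        return _f(x, val=val)
    return f


def inner_primitives(fn, x):
    """Hypothetical helper: primitive names inside the single jitted call."""
    outer = jax.make_jaxpr(fn)(x).jaxpr
    (call_eqn,) = outer.eqns          # the one jitted-call equation (printed as pjit above)
    inner = call_eqn.params["jaxpr"]  # ClosedJaxpr of the jitted body
    return [eqn.primitive.name for eqn in inner.jaxpr.eqns]


x = jnp.arange(10)
f1, f2 = create_f1(2), create_f2(2)

# Same primitives inside the jitted body for both approaches.
print(inner_primitives(f1, x), inner_primitives(f2, x))
print(bool(jnp.array_equal(f1(x), f2(x))))  # True
```

The helper deliberately ignores the call equation's own name and parameters, since those differ only in the function name (`f` vs `_f`).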
Answer selected by sh0416