Performance of jax.grad when taking the derivative with respect to many inputs #26562

Answered by jakevdp
EmilianoG-byte asked this question in General

The reason is likely that computing the gradient with respect to a single input requires nearly as much computation as computing the gradient with respect to all three, because the intermediate values from the forward pass are shared and reused in the latter case. We can show this with a simplified example by looking at the jaxpr of the computation:

import jax
import jax.numpy as jnp

# Enable 64-bit floats so the jaxpr below shows f64 values.
jax.config.update("jax_enable_x64", True)

def f(x, y):
  return jnp.sin(x) * jnp.exp(y)

x, y = 2.0, 3.0

print(jax.make_jaxpr(jax.grad(f, argnums=0))(x, y))
# { lambda ; a:f64[] b:f64[]. let
#     c:f64[] = sin a
#     d:f64[] = cos a
#     e:f64[] = exp b
#     _:f64[] = mul c e
#     f:f64[] = mul 1.0 e
#     g:f64[] = mul f d
#   in (g,) }

print(jax.make_jaxpr(jax.grad(f, argnums=(0, 1)))(x, y))
# {…
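To check the claim numerically, here is a small sketch (using the same toy function) that confirms the x-gradient is identical whether or not the y-gradient is requested alongside it, and that both match the analytic derivatives. The `.jaxpr.eqns` attribute of a traced computation lists its primitive equations, so comparing the counts for the two cases shows how little extra work the second gradient adds.

```python
import math

import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) * jnp.exp(y)

x, y = 2.0, 3.0

# Gradient with respect to x alone, and with respect to both arguments.
gx = jax.grad(f, argnums=0)(x, y)
gx_both, gy_both = jax.grad(f, argnums=(0, 1))(x, y)

# The x-gradient is identical either way: d/dx sin(x)*exp(y) = cos(x)*exp(y).
assert abs(float(gx) - math.cos(x) * math.exp(y)) < 1e-3
assert abs(float(gx) - float(gx_both)) < 1e-6

# The y-gradient reuses the already-computed sin(x) and exp(y):
# d/dy sin(x)*exp(y) = sin(x)*exp(y).
assert abs(float(gy_both) - math.sin(x) * math.exp(y)) < 1e-3

# Compare the sizes of the two traced computations: the two-argument
# gradient shares the forward-pass intermediates (sin, cos, exp) and
# adds only a couple of extra multiplies.
n_single = len(jax.make_jaxpr(jax.grad(f, argnums=0))(x, y).jaxpr.eqns)
n_both = len(jax.make_jaxpr(jax.grad(f, argnums=(0, 1)))(x, y).jaxpr.eqns)
print(n_single, n_both)
```

The equation counts printed at the end differ only by the handful of multiplies needed to propagate the cotangent into the second argument, which is why asking for all gradients at once costs about the same as asking for one.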

Answer selected by EmilianoG-byte