
How to compute first & second derivatives of NN with respect to the input variable? #23021

Answered by dfm
galah92 asked this question in Q&A

In this case, you could define a batched version of that function by moving the vmap inside, so that jax.vmap wraps jax.grad(f) before the JVP is taken:

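import jax
import jax.numpy as jnp

def first_and_second_derivative_batched(f):
  # f must map a scalar to a scalar; vmap(grad(f)) evaluates f' elementwise.
  def wrapped(x):
    dx = jnp.ones_like(x)
    # Primal output: f'(x) per element; tangent output: f''(x) per element.
    df, df2 = jax.jvp(jax.vmap(jax.grad(f)), (x,), (dx,))
    return df, df2
  return wrapped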
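Why the tangent output is the second derivative, assuming f maps a scalar to a scalar: vmap(grad(f)) is an elementwise map g, so its Jacobian is diagonal, and pushing an all-ones tangent through it picks out that diagonal.

$$
g(x) = \big(f'(x_1), \dots, f'(x_n)\big), \qquad
\frac{\partial g_i}{\partial x_j} = \delta_{ij}\, f''(x_i)
$$

$$
\operatorname{jvp}(g, x, \mathbf{1}) = \big(g(x),\; J_g(x)\,\mathbf{1}\big) = \big((f'(x_i))_{i=1}^{n},\; (f''(x_i))_{i=1}^{n}\big)
$$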

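fun = jnp.sin  # assumed example: the residuals below imply fun(x) = sin(x)
x = jnp.linspace(-1, 1, 5)
dfdx, df2dx2 = first_and_second_derivative_batched(fun)(x)
# Both residuals should be zero up to floating-point error
dfdx - jnp.cos(x), df2dx2 + jnp.sin(x)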
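As a sanity check, here is a minimal sketch comparing the "vmap inside" version against two alternatives: vmapping a scalar (unbatched) helper from the outside, and nesting jax.grad twice. The helper name first_and_second_derivative_scalar is hypothetical and is not necessarily the unbatched function referred to earlier in the thread; jnp.sin is used as the test function.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar (unbatched) helper: x is a single scalar.
def first_and_second_derivative_scalar(f):
    def wrapped(x):
        # jvp of grad(f) along tangent 1.0 returns (f'(x), f''(x)).
        return jax.jvp(jax.grad(f), (x,), (jnp.ones_like(x),))
    return wrapped

# The batched version from above, repeated so this block is self-contained.
def first_and_second_derivative_batched(f):
    def wrapped(x):
        dx = jnp.ones_like(x)
        return jax.jvp(jax.vmap(jax.grad(f)), (x,), (dx,))
    return wrapped

x = jnp.linspace(-1.0, 1.0, 5)

# vmap inside (batched helper) vs. vmap outside (scalar helper).
d1_in, d2_in = first_and_second_derivative_batched(jnp.sin)(x)
d1_out, d2_out = jax.vmap(first_and_second_derivative_scalar(jnp.sin))(x)

# Reference second derivative: nest grad twice and vmap over the batch.
d2_ref = jax.vmap(jax.grad(jax.grad(jnp.sin)))(x)

print(jnp.max(jnp.abs(d1_in - d1_out)), jnp.max(jnp.abs(d2_in - d2_ref)))
```

All three approaches should agree; the jvp-based versions get f' and f'' from a single forward-over-reverse pass instead of differentiating twice separately.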

Depending on the shape of the output of f, you might need to squeeze or reshape it, since jax.grad only accepts functions that return a scalar, e.g.:

f_ = lambda x: f(x).reshape(x.shape[0])
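One way to apply this to a network, sketched under the assumption that the model takes a batch of shape (N, 1) and returns shape (N, 1): wrap it as a scalar-to-scalar function, since inside vmap(grad(f)) each call sees a single scalar input. The tiny MLP, its params dict, and the wrapper f_scalar below are all hypothetical stand-ins for your own model.

```python
import jax
import jax.numpy as jnp

# Hypothetical one-hidden-layer MLP: input (N, 1) -> output (N, 1).
def mlp(params, x):
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w1": jax.random.normal(key1, (1, 16)),
    "b1": jnp.zeros(16),
    "w2": jax.random.normal(key2, (16, 1)),
    "b2": jnp.zeros(1),
}

# Scalar-in, scalar-out view of the network: lift the scalar to shape (1, 1),
# run the model, and squeeze the (1, 1) output back to a scalar for jax.grad.
def f_scalar(x):
    return mlp(params, x.reshape(1, 1)).squeeze()

def first_and_second_derivative_batched(f):
    def wrapped(x):
        dx = jnp.ones_like(x)
        return jax.jvp(jax.vmap(jax.grad(f)), (x,), (dx,))
    return wrapped

x = jnp.linspace(-1.0, 1.0, 5)  # shape (N,): one scalar input per point
dfdx, df2dx2 = first_and_second_derivative_batched(f_scalar)(x)
```

The network parameters are closed over here for brevity; in practice you might pass them in with functools.partial or a lambda.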

dfm Aug 13, 2024
Collaborator

Answer selected by galah92