A concise way to define a custom jvp rule for a wrapper with many inputs and defaults? #27669

Answered by dfm
lankef asked this question in Q&A

In this case, you want to use the nondiff_argnums argument to custom_jvp:

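from functools import partial
import jax
import jax.numpy as jnp

# Argument 2 (config) is marked as non-differentiable.
@partial(jax.custom_jvp, nondiff_argnums=(2,))
def f(x, y, config):
  print(config['test'])
  return jnp.sin(x) * y + config['constant']

# The non-differentiable argument is passed to the JVP rule first;
# primals and tangents then hold only the differentiable arguments.
@f.defjvp
def f_jvp(config, primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y, config)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out

# Gradient with respect to x (jax.grad differentiates argument 0 by default).
jax.grad(f)(1., 2., {'test': 'good', 'constant': 0.1})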

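Note that this also changes the signature of f_jvp: the non-differentiable argument (config) is now passed first, and primals and tangents contain only x and y. Does that do the trick in your case?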
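To make the signature change concrete, here is a minimal sketch with two non-differentiable arguments. The function g and the argument names scale and offset are hypothetical, chosen only for illustration; they are not from the original question. The arguments listed in nondiff_argnums are handed to the JVP rule first, in that order, followed by primals and tangents for the remaining arguments.

```python
from functools import partial
import jax
import jax.numpy as jnp

# Hypothetical example: arguments 1 and 2 are treated as non-differentiable.
@partial(jax.custom_jvp, nondiff_argnums=(1, 2))
def g(x, scale, offset):
  return scale * jnp.sin(x) + offset

@g.defjvp
def g_jvp(scale, offset, primals, tangents):
  # scale and offset come first, in nondiff_argnums order.
  (x,) = primals       # only the differentiable argument is left
  (x_dot,) = tangents
  primal_out = g(x, scale, offset)
  tangent_out = scale * jnp.cos(x) * x_dot
  return primal_out, tangent_out

# Forward-mode check of the rule, then the gradient with respect to x.
primal, tangent = jax.jvp(lambda x: g(x, 2.0, 0.5), (1.0,), (1.0,))
grad_x = jax.grad(g)(1.0, 2.0, 0.5)
```

Both tangent and grad_x should agree with scale * cos(x), which is a quick sanity check that the rule is wired up the way you expect.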

dfm
Apr 2, 2025
Collaborator

Answer selected by lankef