Handling random keys with higher order derivatives and a stochastic custom_vjp #18085
-
I have a stochastic function, and I can get this working nicely with a first-order derivative using a custom VJP implementation. I have something like this, which works for first-order derivatives:

```python
@jax.custom_vjp
def f(fwd_key, ...):
    return foo

def grad_f(grad_key, ...):
    return bar

def f_fwd(key, ...):
    fwd_key, grad_key = random.split(key)
    return f(fwd_key, ...), grad_f(grad_key, ...)

def f_bwd(res, g):
    return baz

f.defvjp(f_fwd, f_bwd)

key = random.PRNGKey(0)
jax.grad(f)(key, ...)  # works like a dream
```

**The problem**

Using this,

```python
key = random.PRNGKey(0)
jax.grad(jax.grad(f))(key, ...)
```

will reuse the same key in each backward pass.

**A working example**

```python
import jax
from jax import random
@jax.custom_vjp
def f(fwd_key, x):
    """A sample from a distribution of quadratic functions."""
    noise = random.uniform(fwd_key, (1,))[0]
    noisy_x = x * noise
    return noisy_x + noisy_x ** 2

def grad_f(grad_key, x):
    """A sample of a noisy gradient from the noisy quadratic functions.

    Note, in particular, that the derivative of the function is desired to be
    from a different sample, not the same as the one used in the forward
    pass."""
    noise = random.uniform(grad_key, (1,))[0]
    noisy_x = x * noise
    return 1.0 + 2 * noisy_x

def f_fwd(key, x):
    fwd_key, grad_key = random.split(key)
    return f(fwd_key, x), grad_f(grad_key, x)

def f_bwd(res, g):
    return None, res * g

f.defvjp(f_fwd, f_bwd)

key = random.PRNGKey(0)
jax.value_and_grad(f, argnums=1)(key, 1.0)
# 0.86768 and 1.2107: Good! These are uncorrelated :)

jax.value_and_grad(jax.grad(f, argnums=1), argnums=1)(key, 1.0)
# 1.2107 and 0.2107: Bad, these are correlated :(
```

Note that the forward-pass value and the first-order gradient are uncorrelated (penultimate line of the code above): this is what I want to achieve for gradients of all orders. But the second-order and first-order gradients are directly correlated (last line of the code above): I do not want this; I want another uncorrelated sample for the second-order gradient. This happens because the same key is reused in the calculation of the first-order and second-order gradients. So I'd love help finding a way for nested grad calls to use different random keys.

**An undesirable workaround**

Maintain random keys globally and access these from within the functions (see the sketch after the question below).

**Question**

Is there any way to deal with custom_vjps needing random keys for higher-order derivatives, in the manner above, without resorting to global random keys?
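For concreteness, the global-key workaround might look something like the minimal sketch below (the names `_GLOBAL_KEY` and `next_key` are illustrative, not an established pattern). Its impurity is part of why it's undesirable: under `jit` the global would be mutated with tracers.

```python
from jax import random

# A module-level key that is split on every access, so each nested
# backward pass can draw a fresh subkey.
_GLOBAL_KEY = random.PRNGKey(0)

def next_key():
    """Return a fresh subkey, advancing the global key as a side effect."""
    global _GLOBAL_KEY
    _GLOBAL_KEY, subkey = random.split(_GLOBAL_KEY)
    return subkey
```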
-
Could you edit your example code to something that is executable, rather than incomplete pseudocode, and then show the output of the second order function and how it's different than what you expect it to be? I'm having trouble filling in the missing pieces between your description and your pseudocode.
-
I think what you'll need to do to accomplish this is put another `custom_vjp` inside your existing `custom_vjp`, giving the desired behavior (pass in a key and use this to generate uncorrelated second-order gradients). For arbitrary-order gradients, you'd want to set up some kind of recursive procedure (a sketch of this follows below).

If it is easier for you to work through the mathematics, note that you can use a `custom_jvp` instead of a `custom_vjp` here. JAX will automatically synthesise the VJP from the JVP (deterministically and via transposition); this is also sketched below.

This aside, I'm quite curious what you're up to, that needs uncorrelated gradients? Essentially every modern use-case I know of (GANs, differentiating through SDE solves, ...) prefers correlated gradients. Off the top of my head the only case where I've seen uncorrelated gradients is Malliavin calculus, which AFAIK is largely superseded by the correlated-gradient approach wherever that's possible.
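For concreteness, here is a minimal sketch of the nested-`custom_vjp` suggestion, applied to the noisy quadratic from the question. The helpers `_sample_grad` and `_sample_grad2` (and the particular second-derivative model) are illustrative assumptions, not from the original post:

```python
import jax
from jax import random

def _sample_grad(key, x):
    """One sample of the noisy first derivative."""
    noise = random.uniform(key, (1,))[0]
    return 1.0 + 2 * x * noise

def _sample_grad2(key, x):
    """An independent sample of the noisy second derivative."""
    noise = random.uniform(key, (1,))[0]
    return 2.0 * noise

# Inner custom_vjp: grad_f now carries its own VJP rule, so differentiating
# through it (i.e. the second-order gradient of f) draws a fresh sample.
@jax.custom_vjp
def grad_f(key, x):
    key1, _ = random.split(key)  # split as in grad_f_fwd so primal and fwd agree
    return _sample_grad(key1, x)

def grad_f_fwd(key, x):
    key1, key2 = random.split(key)
    return _sample_grad(key1, x), _sample_grad2(key2, x)

def grad_f_bwd(res, g):
    return None, res * g

grad_f.defvjp(grad_f_fwd, grad_f_bwd)

# Outer custom_vjp, unchanged from the question.
@jax.custom_vjp
def f(fwd_key, x):
    noise = random.uniform(fwd_key, (1,))[0]
    noisy_x = x * noise
    return noisy_x + noisy_x ** 2

def f_fwd(key, x):
    fwd_key, grad_key = random.split(key)
    return f(fwd_key, x), grad_f(grad_key, x)

def f_bwd(res, g):
    return None, res * g

f.defvjp(f_fwd, f_bwd)

key = random.PRNGKey(0)
# The second-order sample now comes from an independent split of the key,
# so it is no longer correlated with the first-order sample.
jax.grad(jax.grad(f, argnums=1), argnums=1)(key, 1.0)
```

For third and higher orders, `_sample_grad2` would itself become a `custom_vjp` that splits its key again; hence the recursive procedure mentioned above.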
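And a minimal sketch of the `custom_jvp` alternative for the same noisy quadratic (again with illustrative helper logic; JAX transposes this JVP to obtain the VJP used by `jax.grad`). As above, uncorrelated gradients beyond first order would need further nested rules:

```python
import jax
from jax import random

@jax.custom_jvp
def f(key, x):
    key1, _ = random.split(key)  # split as in the JVP rule so the two agree
    noise = random.uniform(key1, (1,))[0]
    noisy_x = x * noise
    return noisy_x + noisy_x ** 2

@f.defjvp
def f_jvp(primals, tangents):
    key, x = primals
    _, x_dot = tangents  # the key carries no meaningful tangent
    key1, key2 = random.split(key)
    noise = random.uniform(key1, (1,))[0]
    noisy_x = x * noise
    primal_out = noisy_x + noisy_x ** 2
    # Directional derivative built from an independent noise sample; it is
    # linear in x_dot, so JAX can transpose it into a VJP for reverse mode.
    grad_noise = random.uniform(key2, (1,))[0]
    tangent_out = (1.0 + 2 * x * grad_noise) * x_dot
    return primal_out, tangent_out

key = random.PRNGKey(0)
jax.grad(f, argnums=1)(key, 1.0)  # gradient sample drawn from key2
```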