
Handling random keys with higher order derivatives and a stochastic custom_vjp #18085

Answered by patrick-kidger
atiyo asked this question in Q&A

I think what you'll need to do here is put another custom_vjp inside the backward pass of your existing custom_vjp. That gives the desired behavior: you pass in a key and use it to generate uncorrelated second-order gradients.
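A minimal sketch of that nesting follows. It assumes a toy function sin(x) whose primal value is exact but whose first- and second-order gradients are estimated with added Gaussian noise. The names noisy_sin, noisy_cos and NOISE, and the noise model itself, are hypothetical and only for illustration. The key point is that noisy_sin's backward pass calls noisy_cos, which is itself a custom_vjp, so differentiating the gradient a second time goes through noisy_cos's backward rule and its own key.

```python
import jax
import jax.numpy as jnp

NOISE = 0.1  # hypothetical scale of the Gaussian noise in each gradient estimate


@jax.custom_vjp
def noisy_cos(x, key_val, key_grad):
    # Stochastic estimate of cos(x), i.e. of the first-order gradient of sin.
    return jnp.cos(x) + NOISE * jax.random.normal(key_val, jnp.shape(x))


def noisy_cos_fwd(x, key_val, key_grad):
    return noisy_cos(x, key_val, key_grad), (x, key_grad)


def noisy_cos_bwd(res, g):
    x, key_grad = res
    # Stochastic estimate of -sin(x), drawn from its own key so it is
    # uncorrelated with the noise in the first-order estimate.
    d2 = -jnp.sin(x) + NOISE * jax.random.normal(key_grad, jnp.shape(x))
    return g * d2, None, None  # no cotangents for the integer keys


noisy_cos.defvjp(noisy_cos_fwd, noisy_cos_bwd)


@jax.custom_vjp
def noisy_sin(x, key):
    # The primal value is exact; only the derivatives are stochastic.
    return jnp.sin(x)


def noisy_sin_fwd(x, key):
    return jnp.sin(x), (x, key)


def noisy_sin_bwd(res, g):
    x, key = res
    key_val, key_grad = jax.random.split(key)
    # The backward pass calls another custom_vjp function, so differentiating
    # this gradient again uses noisy_cos_bwd rather than plain autodiff.
    return g * noisy_cos(x, key_val, key_grad), None


noisy_sin.defvjp(noisy_sin_fwd, noisy_sin_bwd)

key = jax.random.PRNGKey(0)
first = jax.grad(noisy_sin)(0.5, key)              # stochastic estimate of cos(0.5)
second = jax.grad(jax.grad(noisy_sin))(0.5, key)   # stochastic estimate of -sin(0.5)
```

Because the two derivative levels draw from different subkeys of the same key, the noise in the second-order estimate is independent of the noise in the first-order one. A real use case would swap the toy cos/-sin estimates for its own stochastic estimators.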

For gradients of arbitrary order, you'd want to set up some kind of recursive procedure, in which each backward pass is itself a custom_vjp that receives a fresh key.
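One way such a recursion might look, sketched under the same kind of toy assumptions: a single custom_vjp function, parameterised by the derivative order (passed via nondiff_argnums as a static Python int), whose backward pass calls the next order with a subkey split from its own key. The names noisy_deriv, exact_deriv, noisy_sin and NOISE are hypothetical, and sin is used only because its derivatives are easy to write down.

```python
from functools import partial

import jax
import jax.numpy as jnp

NOISE = 0.1  # hypothetical noise scale for every stochastic derivative estimate


def exact_deriv(order, x):
    # The order-th derivative of sin cycles through sin, cos, -sin, -cos.
    return [jnp.sin, jnp.cos, lambda v: -jnp.sin(v), lambda v: -jnp.cos(v)][order % 4](x)


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def noisy_deriv(order, x, key):
    # Stochastic estimate of the order-th derivative of sin at x.
    # Order 0 is the exact primal value; higher orders get Gaussian noise.
    k_now, _ = jax.random.split(key)
    value = exact_deriv(order, x)
    if order > 0:
        value = value + NOISE * jax.random.normal(k_now, jnp.shape(x))
    return value


def noisy_deriv_fwd(order, x, key):
    return noisy_deriv(order, x, key), (x, key)


def noisy_deriv_bwd(order, res, g):
    x, key = res
    _, k_next = jax.random.split(key)
    # The VJP is itself a call to the next order with its own key, so each
    # further jax.grad draws noise independent of every lower order.
    return g * noisy_deriv(order + 1, x, k_next), None


noisy_deriv.defvjp(noisy_deriv_fwd, noisy_deriv_bwd)


def noisy_sin(x, key):
    return noisy_deriv(0, x, key)


key = jax.random.PRNGKey(0)
third = jax.grad(jax.grad(jax.grad(noisy_sin)))(0.5, key)  # noisy estimate of -cos(0.5)
```

Keeping the order as a static argument means each level is traced with its own Python-level branch, and the recursion only unrolls as deep as the number of jax.grad calls actually applied.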

If it makes the mathematics easier for you to work through, note that you can use a custom_jvp instead of a custom_vjp here. JAX will automatically synthesise the VJP from the JVP (deterministically, via transposition).
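A rough sketch of the custom_jvp variant, again assuming a hypothetical noisy-sin toy (noisy_sin, noisy_cos and NOISE are illustrative names, not from the discussion). The tangent rule of noisy_sin is linear in x_dot, which is what lets JAX transpose it into a VJP for jax.grad; nesting a second custom_jvp inside it makes the second derivative stochastic too, with its own key.

```python
import jax
import jax.numpy as jnp

NOISE = 0.1  # hypothetical scale of the Gaussian noise in each derivative estimate


@jax.custom_jvp
def noisy_cos(x, key_val, key_tan):
    # Stochastic estimate of cos(x), i.e. of the first derivative of sin.
    return jnp.cos(x) + NOISE * jax.random.normal(key_val, jnp.shape(x))


@noisy_cos.defjvp
def noisy_cos_jvp(primals, tangents):
    x, key_val, key_tan = primals
    x_dot, _, _ = tangents  # the integer keys carry no meaningful tangent
    # Slope is a stochastic estimate of -sin(x) drawn from an independent key.
    slope = -jnp.sin(x) + NOISE * jax.random.normal(key_tan, jnp.shape(x))
    return noisy_cos(x, key_val, key_tan), slope * x_dot


@jax.custom_jvp
def noisy_sin(x, key):
    # The primal value is exact; only the derivatives are stochastic.
    return jnp.sin(x)


@noisy_sin.defjvp
def noisy_sin_jvp(primals, tangents):
    x, key = primals
    x_dot, _ = tangents
    key_val, key_tan = jax.random.split(key)
    # The tangent output is linear in x_dot, so JAX can transpose it to get the VJP.
    return noisy_sin(x, key), noisy_cos(x, key_val, key_tan) * x_dot


key = jax.random.PRNGKey(0)
first = jax.grad(noisy_sin)(0.5, key)              # reverse mode via transposition
second = jax.grad(jax.grad(noisy_sin))(0.5, key)   # uses noisy_cos's own JVP rule
```

With this formulation forward mode (jax.jvp) works as well, since the JVP rule is what is defined directly.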

As an aside, I'm quite curious what you're working on that needs uncorrelated gradients. Essentially every modern use case I know of (GANs, differentiating through SDE solves, ...) …

Replies: 2 comments 3 replies

Answer selected by atiyo
Category
Q&A
3 participants