
Help with Custom VJPs #18199

Answered by froystig
karan-dalal asked this question in Q&A

Is there any way to take the gradient with respect to some intermediate variable that we create inside a function?

Derivatives are really only defined with respect to the input arguments of a function. One trick to differentiate "with respect to an intermediate expression" is to add a perturbation argument to the function and add it to that expression.
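A minimal, self-contained sketch of the trick, assuming the function can be written as a hypothetical `suffix(prefix(x))` with the intermediate `h = prefix(x)` in between. The helper `grad_wrt_intermediate` and the toy `prefix`/`suffix` are illustrations only, not JAX APIs; the JAX calls used are `jax.grad` (with `argnums` selecting the perturbation) and `jax.numpy` functions.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper (not part of JAX): differentiate suffix(prefix(x))
# "with respect to" the intermediate h = prefix(x) by adding a
# perturbation z to h and differentiating with respect to z at z = 0.
def grad_wrt_intermediate(prefix, suffix, x):
    def perturbed(x, z):
        h = prefix(x) + z  # z = 0 leaves the original computation unchanged
        return suffix(h)
    # argnums=1 selects the perturbation argument z.
    return jax.grad(perturbed, argnums=1)(x, 0.0)

# Toy example: h = x**2, output = sin(h), so d(output)/dh = cos(h).
prefix = lambda x: x ** 2
suffix = jnp.sin
g = grad_wrt_intermediate(prefix, suffix, 2.0)
print(g, jnp.cos(4.0))  # both approximately cos(4.0)
```

Because the perturbation enters additively, differentiating at `z = 0` gives the derivative of the downstream computation evaluated at the intermediate's actual value (here `h = 4.0`).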

Here's an example. Say I have some function:

import jax.numpy as jnp

def f(x):
  a = jnp.exp(x)
  b = jnp.sin(a)
  c = b * 5.
  return c
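For contrast, `jax.grad` on `f` as written can only differentiate with respect to its input `x`. A minimal sketch, assuming an arbitrary evaluation point `x = 0.5`, checking the result against the chain rule:

```python
import jax
import jax.numpy as jnp

def f(x):
    a = jnp.exp(x)
    b = jnp.sin(a)
    c = b * 5.
    return c

x = 0.5  # arbitrary evaluation point (assumption for illustration)
# Gradient with respect to the actual input x; by the chain rule
# dc/dx = 5 * cos(exp(x)) * exp(x).
df_dx = jax.grad(f)(x)
expected = 5.0 * jnp.cos(jnp.exp(x)) * jnp.exp(x)
print(df_dx, expected)
```

This derivative mixes in the factors from `a = exp(x)` and `b = sin(a)`; getting the derivative of `c` with respect to `b` alone needs the rewrite that follows.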

And I want to differentiate it "with respect to b", at the value of b implied by the value of x. I can rewrite it as:

def f_pert(x, z):
  a = jnp.exp(x)
  b = jnp.sin(a) + z
  c = b * 5.
  return c

and take the derivative with respect to the perturbation …
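A minimal sketch of that last step, assuming the same arbitrary point `x = 0.5`: at `z = 0` the perturbed function reproduces `f`, and `jax.grad` with `argnums=1` differentiates with respect to the perturbation `z`. Since `c = 5 * b`, the result should be `5.0` for any `x`.

```python
import jax
import jax.numpy as jnp

def f(x):
    a = jnp.exp(x)
    b = jnp.sin(a)
    c = b * 5.
    return c

def f_pert(x, z):
    a = jnp.exp(x)
    b = jnp.sin(a) + z  # perturb the intermediate b
    c = b * 5.
    return c

x = 0.5  # arbitrary evaluation point (assumption for illustration)
# At z = 0 the perturbed function computes exactly f(x).
assert jnp.allclose(f_pert(x, 0.0), f(x))

# Differentiate with respect to the perturbation (argnums=1) at z = 0.
dc_db = jax.grad(f_pert, argnums=1)(x, 0.0)
print(dc_db)  # 5.0
```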

Answer selected by karan-dalal
Category: Q&A
2 participants