Help with Custom VJPs #18199
I have a complex backward pass I'm trying to implement in JAX, and some pointers to resources or ideas would be much appreciated. I've reduced the problem to the component I'm struggling with most. Let's say I have an input x, a model W, and a linear transformation matrix A. My forward pass is defined as follows (f is my model and A is defined elsewhere):

```python
import jax

def g(x, W):
    z_hat = f(x, W)            # apply the parameters W to x
    x_hat = A @ z_hat          # "reconstruct" x from z_hat using the matrix A
    loss = (x_hat - x).mean()
    return loss

def forward(x, W):
    grad_fn = jax.value_and_grad(g, argnums=1)
    _, grad = grad_fn(x, W)    # gradient of the loss with respect to W
    W_new = W - grad           # one gradient-descent step to update W
    z = f(x, W_new)            # apply the updated parameters to x to produce the final output
    return z
```

Essentially, I'm doing a self-supervised reconstruction task to update my model. The context may not make sense without the bigger picture, so let me know if I can clarify anything above. I'm attempting to get the gradient of `grad` with respect to `x_hat`, i.e. d(grad)/d(x_hat). However, as you can see, `x_hat` is constructed while computing `grad`. Is there any way to take the gradient with respect to some intermediate variable created inside a function? Any tips would be extremely helpful. Thanks!
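For contrast, differentiating the inner gradient with respect to `x` itself is straightforward, because `x` is an argument and JAX transformations compose (a derivative of a gradient is just a nested transformation). A minimal sketch under hypothetical assumptions: `f` is a single tanh layer, the shapes are toy sizes, and `A` is passed as an explicit argument instead of being closed over. Since `grad` has the shape of `W`, its derivative is a Jacobian, computed here with `jax.jacobian`:

```python
import jax
import jax.numpy as jnp

def f(x, W):
    # Hypothetical stand-in for the model: one tanh layer.
    return jnp.tanh(W @ x)

def g(x, W, A):
    z_hat = f(x, W)          # apply W to x
    x_hat = A @ z_hat        # reconstruct x from z_hat
    return jnp.mean(x_hat - x)

def inner_grad(x, W, A):
    # Gradient of the reconstruction loss with respect to W.
    return jax.grad(g, argnums=1)(x, W, A)

kx, kw, ka = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (4,))
W = jax.random.normal(kw, (3, 4))
A = jax.random.normal(ka, (4, 3))

# d(grad)/dx: x is an argument, so jax.jacobian can target it directly.
dgrad_dx = jax.jacobian(inner_grad, argnums=0)(x, W, A)
print(dgrad_dx.shape)  # (3, 4, 4): grad's shape (3, 4) followed by x's shape (4,)
```

`x_hat` has no such argument slot, which is exactly the problem the question raises.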
Derivatives are really only defined with respect to the input arguments of a function. One trick to differentiate "with respect to an intermediate expression" is to add a perturbation argument to that expression. Here's an example. Say I have some function:

```python
import jax
import jax.numpy as jnp

def f(x):
    a = jnp.exp(x)
    b = jnp.sin(a)
    c = b * 5.
    return c
```

And I want to differentiate it "with respect to `b`," at the point of `b`'s value implied by the value of `x`. I can rewrite it as:

```python
def f_pert(x, z):
    a = jnp.exp(x)
    b = jnp.sin(a) + z   # z perturbs the intermediate value b
    c = b * 5.
    return c
```

and take the derivative with respect to the perturbation argument at 0:

```python
jax.grad(f_pert, argnums=1)(..., 0.)
```
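Applied to the question, the same idea means adding a perturbation `eps` to `x_hat` inside the loss and differentiating the inner gradient with respect to `eps` at zero. Because `grad` is an array and `x_hat` is a vector, the result is a Jacobian, so this sketch uses `jax.jacobian` rather than `jax.grad`. Assumptions, all hypothetical: `f` is a single tanh layer, the shapes are toy sizes, `A` is passed in explicitly, and the loss is a squared error. With the question's signed mean `(x_hat - x).mean()`, the loss is linear in `x_hat`, so the gradient with respect to `W` does not depend on `x_hat` and d(grad)/d(x_hat) would be identically zero.

```python
import jax
import jax.numpy as jnp

def f(x, W):
    # Hypothetical stand-in for the model: one tanh layer.
    return jnp.tanh(W @ x)

def g_pert(x, W, A, eps):
    z_hat = f(x, W)
    x_hat = A @ z_hat + eps            # eps perturbs the intermediate x_hat
    return jnp.mean((x_hat - x) ** 2)  # assumed squared-error reconstruction loss

def inner_grad(x, W, A, eps):
    # Gradient of the loss with respect to W, viewed as a function of eps.
    return jax.grad(g_pert, argnums=1)(x, W, A, eps)

kx, kw, ka = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (4,))
W = jax.random.normal(kw, (3, 4))
A = jax.random.normal(ka, (4, 3))
eps0 = jnp.zeros_like(x)  # x_hat has the same shape as x

# d(grad)/d(x_hat) at the x_hat implied by (x, W): shape (3, 4) + (4,).
dgrad_dxhat = jax.jacobian(inner_grad, argnums=3)(x, W, A, eps0)
print(dgrad_dxhat.shape)  # (3, 4, 4)
```

The outer `forward` step can then use this Jacobian (or a vector-Jacobian product via `jax.vjp` on `inner_grad`) in whatever custom backward rule the larger model needs.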