I use the following pattern, though I'm not sure it is the best approach:

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.mean(jnp.sin(x))


def grad_with_mask(func, mask=None):
    """Like jax.grad(func), but zero the gradient wherever mask is True."""
    def grad(*args, **kwargs):
        out = jax.grad(func)(*args, **kwargs)
        if mask is None:
            return out
        # jax.tree_map is deprecated in recent JAX; jax.tree_util.tree_map works on any pytree.
        return jax.tree_util.tree_map(lambda node, m: jnp.where(m, 0, node), out, mask)
    return grad


print(jax.grad(f)(jnp.array([3.0, 1.0])))
# [-0.49499625 0.27015114]
print(grad_with_mask(f, mask=jnp.array([True, False]))(jnp.array([3.0, 1.0])))
# [0. 0.27015114]
```

It can be used with pytrees too, as long as the mask has the same pytree structure as the gradient.
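A minimal sketch of the pytree case, assuming a hypothetical `loss` over a dict of parameters `w` and `b`. The mask is a dict with the same keys, and `True` marks entries whose gradient is zeroed. The helper is repeated here so the snippet runs on its own.

```python
import jax
import jax.numpy as jnp


def grad_with_mask(func, mask=None):
    """Like jax.grad(func), but zero the gradient wherever mask is True."""
    def grad(*args, **kwargs):
        out = jax.grad(func)(*args, **kwargs)
        if mask is None:
            return out
        # mask must have the same pytree structure as the gradient
        return jax.tree_util.tree_map(lambda g, m: jnp.where(m, 0.0, g), out, mask)
    return grad


# Hypothetical parameter pytree: a weight vector and a scalar bias.
def loss(params):
    return jnp.sum(params["w"] ** 2) + params["b"] ** 2


params = {"w": jnp.array([1.0, 2.0, 3.0]), "b": jnp.array(4.0)}
# Freeze w[1] and the whole bias.
mask = {"w": jnp.array([False, True, False]), "b": jnp.array(True)}

grads = grad_with_mask(loss, mask=mask)(params)
print(grads["w"])  # [2. 0. 6.]
print(grads["b"])  # 0.0
```

The unmasked gradient would be `2 * w = [2, 4, 6]` and `2 * b = 8`; the masked entries come back as zero.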
If I have a function whose input is a JAX array, can I stop the gradient with respect to a few selected elements?
Ideally, it would be something like this.
This should return `[0.0, 0.27]`, but instead it gives `[0.0, 0.0]`.

P.S. I am trying to do an operation like convolution on a 3D array. However, a few pixels in the middle are masked and have to be kept at a constant value. Ideally, the 3D convolution should only affect the non-masked pixels, but I couldn't achieve that, so I am trying this approach, where I simply cut off their gradients so that they can never change.
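A sketch of another option, not taken from the thread: route the masked elements through `jax.lax.stop_gradient` with `jnp.where`, so the function still sees their values but no gradient flows back to them. The helper name `freeze` is hypothetical. The second half applies it to a made-up 3D setup (random volume and target, a 3x3x3 averaging kernel via `jax.scipy.signal.convolve`, a small cube of frozen pixels) and runs a few plain gradient-descent steps; the shapes, kernel, loss, and learning rate are all assumptions for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve


def freeze(x, mask):
    """Hypothetical helper: same values as x, but no gradient where mask is True."""
    return jnp.where(mask, jax.lax.stop_gradient(x), x)


# Toy function from the thread.
def f(x):
    return jnp.mean(jnp.sin(x))


x = jnp.array([3.0, 1.0])
mask = jnp.array([True, False])
grad_frozen = jax.grad(lambda z: f(freeze(z, mask)))
g = grad_frozen(x)
print(g)  # first entry is 0; second is cos(1) / 2, about 0.27

# Hypothetical 3D setup (shapes, kernel, target and learning rate are assumptions).
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
volume = jax.random.normal(k1, (8, 8, 8))
target = jax.random.normal(k2, (8, 8, 8))
kernel = jnp.ones((3, 3, 3)) / 27.0  # 3x3x3 averaging kernel
pix_mask = jnp.zeros((8, 8, 8), dtype=bool).at[3:5, 3:5, 3:5].set(True)


def loss(v):
    # Frozen pixels still feed the convolution, but receive no gradient.
    out = convolve(freeze(v, pix_mask), kernel, mode="same")
    return jnp.mean((out - target) ** 2)


@jax.jit
def step(v):
    return v - 0.1 * jax.grad(loss)(v)  # plain gradient descent


v = volume
for _ in range(10):
    v = step(v)

print(jnp.array_equal(v[pix_mask], volume[pix_mask]))  # True: frozen pixels unchanged
```

Because the frozen pixels get an exactly zero gradient, a plain gradient step leaves them untouched, while the convolution still reads their values when updating their neighbours. If the optimizer changes parameters independently of the gradient (weight decay, for example), masking the update itself, or rebuilding the input inside the loss as `jnp.where(pix_mask, fixed_values, v)` with a hypothetical constant array `fixed_values`, is the safer choice.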