gradient question about GatherScatterMode #17675
-
Hello, in https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.GatherScatterMode, it says (1) … I want to know whether gradients will be correct when indices are out of bounds in modes other than PROMISE_IN_BOUNDS. These two questions are, to some degree, the same question.
-
Only the `PROMISE_IN_BOUNDS` mode has incorrect gradients; the other modes have correct gradients. What `PROMISE_IN_BOUNDS` does is up to the implementation, but it is equivalent to `CLIP` in the forward direction, while its gradient is equivalent to `FILL_OR_DROP`. (These don't match, which is why the result isn't correct.)
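A minimal sketch of the mismatch described above, assuming a recent JAX version on the CPU backend. It gathers one out-of-bounds element (index 5 from a length-4 array) with `jax.lax.gather` under each mode and differentiates the result. The helper `gather_one` and the `results` dict are hypothetical names used only for illustration. The forward value under `PROMISE_IN_BOUNDS` is implementation-defined, so the sketch does not rely on it.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(4.0)      # [0., 1., 2., 3.]
idx = jnp.array([[5]])   # a single start index (shape (1, 1)) that is out of bounds

# Gather one scalar per row of idx from the 1-D operand; output shape is (1,).
dnums = lax.GatherDimensionNumbers(
    offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))

def gather_one(x, mode):
    # Hypothetical helper: the only thing that varies between runs is `mode`.
    return lax.gather(x, idx, dnums, slice_sizes=(1,), mode=mode)

results = {}
for mode in (lax.GatherScatterMode.CLIP,
             lax.GatherScatterMode.FILL_OR_DROP,
             lax.GatherScatterMode.PROMISE_IN_BOUNDS):
    value = gather_one(x, mode)
    # The gradient of a gather is a scatter-add that uses the same mode.
    grad = jax.grad(lambda x: gather_one(x, mode).sum())(x)
    results[mode.name] = (float(value[0]), grad.tolist())
    print(mode.name, value, grad)
```

With `CLIP`, the forward pass reads the clamped element `x[3]` and the gradient flows back to that same element, so the two agree. With `FILL_OR_DROP`, the forward pass returns the fill value (NaN by default for floats) and the gradient is zero everywhere, which is also consistent. With `PROMISE_IN_BOUNDS`, the gradient is all zeros (the out-of-bounds update is dropped) even though the forward pass may have returned the clamped `x[3]`, which is the inconsistency the answer describes.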