Clipping changing the gradients #29929
In this code, the clip operation does not do anything to the inputs. However, if we check the gradient with and without the clipping operation, we get two different values. Is this because the max and min functions, which are used behind the scenes, get a gradient of 0.5 for the edge cases? This goes away if we choose the input to not include the upper and lower limits!
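
The snippet the question refers to is not preserved in this copy of the thread. As a rough sketch of the kind of comparison being described (not the original code), assume a hypothetical input `x` whose first and last entries sit exactly on the clip limits 0 and 1, and two hypothetical loss functions that differ only by the clip:

```python
import jax
import jax.numpy as jnp

# Hypothetical input: the first and last entries sit exactly on the clip limits.
x = jnp.array([0.0, 0.25, 0.75, 1.0])

def loss_plain(x):
    # No clipping at all.
    return jnp.sum(x)

def loss_clipped(x):
    # Clipping to [0, 1] leaves every entry of x unchanged in the forward pass.
    return jnp.sum(jnp.clip(x, 0.0, 1.0))

grad_plain = jax.grad(loss_plain)(x)
grad_clipped = jax.grad(loss_clipped)(x)

print("forward values equal:", bool(loss_plain(x) == loss_clipped(x)))
print("grad without clip:", grad_plain)
print("grad with clip:   ", grad_clipped)
```

Under these assumptions the two forward values agree, while the clipped gradient differs from the plain one only at the two entries that lie exactly on a limit. Moving those entries slightly inside the interval should make the gradients match again.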
Replies: 1 comment 3 replies
Hi - thanks for the question! I think this is working more-or-less as expected.
`clip` is a function with a discontinuous gradient, and you're asking for the gradient at the precise point of the discontinuity. There are four reasonable choices for the gradient of `clip(x, 0, 1)` evaluated at `x = 1.0`:
1. `NaN`
2. `1.0` (this would result in both your gradients being equal)
3. `0.0` (this is what is done by `jax.lax.clamp` for the same case)
4. `0.5` (this is what is done by `jnp.clip`, via the autodiff rule of ja…
From a strict mathematical perspective, (1) might be the best, but introducing NaNs into autodiff code can be problematic. So we choose one of the other three options, each of which is a reasonable choice. What do you think?
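
To make the boundary behaviour concrete, here is a small sketch that evaluates the gradient at exactly `x = 1.0` in three ways. The first two use `jnp.clip` and `jax.lax.clamp` as discussed in this thread. The third, `clip01_passthrough`, is a hypothetical helper (not part of JAX) that uses `jax.custom_jvp` to pass the gradient through unchanged on the closed interval [0, 1]. It is only one possible way to get the `1.0` choice, assuming that is the behaviour you want.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.asarray(1.0)  # exactly on the upper limit of clip(x, 0, 1)
lo, hi = jnp.zeros_like(x), jnp.ones_like(x)

# jnp.clip at the boundary.
g_clip = jax.grad(lambda v: jnp.clip(v, 0.0, 1.0))(x)

# lax.clamp takes its arguments in the order (min, operand, max).
g_clamp = jax.grad(lambda v: lax.clamp(lo, v, hi))(x)

# Hypothetical helper: same forward pass as clip to [0, 1], but with a custom
# derivative that is 1 everywhere on the closed interval, limits included.
@jax.custom_jvp
def clip01_passthrough(v):
    return jnp.clip(v, 0.0, 1.0)

@clip01_passthrough.defjvp
def _clip01_passthrough_jvp(primals, tangents):
    (v,), (v_dot,) = primals, tangents
    inside = (v >= 0.0) & (v <= 1.0)
    # Pass the tangent through inside [0, 1]; zero it outside.
    return clip01_passthrough(v), jnp.where(inside, v_dot, jnp.zeros_like(v_dot))

g_pass = jax.grad(clip01_passthrough)(x)

print("jnp.clip:           ", g_clip)   # 0.5
print("lax.clamp:          ", g_clamp)  # 0.0
print("clip01_passthrough: ", g_pass)   # 1.0
```

Which of these is appropriate depends on the application: the pass-through variant keeps the gradient identical to the unclipped function whenever the clip is a no-op, at the cost of maintaining a custom derivative rule.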