Clipping changing the gradients #29929

Answered by jakevdp
SNMS95 asked this question in Show and tell
Jul 2, 2025 · 1 comment · 3 replies

Hi - thanks for the question! I think this is working more-or-less as expected. clip is a function with a discontinuous gradient, and you're asking for the gradient at the precise point of the discontinuity. There are four reasonable choices for the gradient of clip(x, 0, 1) evaluated at x = 1.0:

  1. taking a symmetric limit would give you NaN
  2. taking the limit from the left would give you 1.0 (this would result in both your gradients being equal)
  3. taking the limit from the right would give you 0.0 (this is what is done by jax.lax.clamp for the same case)
  4. to make the gradient equivalent under argument permutation, we can return 0.5 (this is what is done by jnp.clip, via the autodiff rule of ja…
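The behavior of the built-in functions can be checked directly. This is a minimal sketch, assuming a recent JAX release (autodiff rules can change between versions). It evaluates the gradients of jnp.clip and jax.lax.clamp at an interior point and exactly at x = 1.0; per the answer, the boundary values should correspond to choices 4 and 3 respectively. It also defines clip_closed, a hypothetical helper (not a JAX API) built with jax.custom_jvp, which passes the gradient through on the closed interval [lo, hi] so that at the upper bound it follows choice 2 (the left limit).

```python
import jax
import jax.numpy as jnp

# Gradients of the built-in clip/clamp with respect to x.
grad_clip = jax.grad(lambda x: jnp.clip(x, 0.0, 1.0))
grad_clamp = jax.grad(lambda x: jax.lax.clamp(0.0, x, 1.0))

print("interior x=0.5:", grad_clip(0.5), grad_clamp(0.5))
print("boundary x=1.0:", grad_clip(1.0), grad_clamp(1.0))


# Hypothetical helper (not part of JAX): same forward value as jnp.clip,
# but the gradient w.r.t. x is passed through on the closed interval
# [lo, hi]. At x == hi this matches the limit from the left (choice 2).
@jax.custom_jvp
def clip_closed(x, lo, hi):
    return jnp.clip(x, lo, hi)


@clip_closed.defjvp
def clip_closed_jvp(primals, tangents):
    x, lo, hi = primals
    x_dot, lo_dot, hi_dot = tangents
    out = clip_closed(x, lo, hi)
    inside = (x >= lo) & (x <= hi)
    # Inside the closed interval the output follows x; below it follows lo,
    # above it follows hi (assumes lo <= hi).
    t_out = jnp.where(inside, x_dot, jnp.where(x < lo, lo_dot, hi_dot))
    return out, t_out


grad_closed = jax.grad(lambda x: clip_closed(x, 0.0, 1.0))
print("clip_closed at x=1.0:", grad_closed(1.0))
```

Which convention is "right" depends on the application; a custom_jvp like this is one way to pick a different one explicitly rather than relying on the default rules.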

3 replies: @jakevdp, @SNMS95, @jakevdp
Answer selected by SNMS95
2 participants