AssertionError when computing gradient of a function with reduce_window #24754
Unanswered
LeonardoTredese asked this question in Q&A
Replies: 1 comment 1 reply
Based on your error, you may have noticed

```python
import jax

x = jax.numpy.ones((1, 28, 28))

def f(x):
    return jax.nn.max_pool(x, (2, 2), (2, 2), padding="SAME").sum()

gradient = jax.grad(f)(x)
print(gradient)
```
When running the following example
I get the following assertion error:
I am using Python 3.11.6 on Ubuntu 23.10 with these packages:
jax==0.4.35
jax-cuda12-pjrt==0.4.35
jax-cuda12-plugin==0.4.35
jaxlib==0.4.35
I believe this is not expected behaviour. Am I wrong? How can I compute the gradient of this function?
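A minimal sketch, not taken from this thread, of one way to get a differentiable 2×2 max pool: call `jax.lax.reduce_window` directly with `lax.max` as the reducer and `-jnp.inf` as the initial value. As far as I know, `jax.nn` does not expose a `max_pool` function, and my understanding is that JAX only maps `reduce_window` onto its differentiable max-reduction path when the reducer is `lax.max` and the initial value is the max identity (`-inf`). The helper name `max_pool_2x2` is hypothetical, and the sketch assumes pooling over the last two axes of a `(batch, height, width)` array.

```python
import jax
import jax.numpy as jnp
from jax import lax


def max_pool_2x2(x):
    # Hypothetical helper: 2x2 max pooling with stride 2 over the last two
    # axes of a (batch, height, width) array; the batch axis uses window 1.
    return lax.reduce_window(
        x,
        -jnp.inf,  # identity element for max, so JAX can treat this as a max reduction
        lax.max,   # use lax.max itself rather than a custom lambda
        window_dimensions=(1, 2, 2),
        window_strides=(1, 2, 2),
        padding="SAME",
    )


def f(x):
    # Same objective as in the question: sum of the pooled values.
    return max_pool_2x2(x).sum()


x = jnp.ones((1, 28, 28))
grad = jax.grad(f)(x)
print(grad.shape)  # (1, 28, 28), same shape as the input
```

Since each 2×2 window routes its gradient to a single maximal element, the gradient should be 1 at one position per window and 0 elsewhere. With an all-ones input, which position wins a tie depends on JAX's internal tie-breaking.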