Have you figured it out? Is writing the forward and backward passes with `custom_vjp` or `custom_jvp` required or not?
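Below is a minimal sketch of the `custom_jvp` route. It assumes an elementwise `sin` kernel as a stand-in for the real one; the names `pallas_sin`, `_sin_kernel`, `_cos_kernel` and `_elementwise` are hypothetical, not from the original post. It also uses `interpret=True` so it runs on CPU. The JVP rule runs Pallas kernels only on primal values and applies the tangent with a plain `jnp` multiply, so `jax.grad` only has to transpose that multiply rather than the `pallas_call` itself. Treat it as a workaround sketch, not a diagnosis of the `AssertionError` in the question.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _sin_kernel(x_ref, o_ref):
    # Hypothetical forward kernel: elementwise sin over the whole block.
    o_ref[...] = jnp.sin(x_ref[...])


def _cos_kernel(x_ref, o_ref):
    # Hypothetical derivative kernel: d/dx sin(x) = cos(x).
    o_ref[...] = jnp.cos(x_ref[...])


def _elementwise(kernel, x):
    # No grid/BlockSpecs: the kernel sees the full array as one block.
    # interpret=True runs the kernel through the Pallas interpreter (CPU-friendly).
    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,
    )(x)


@jax.custom_jvp
def pallas_sin(x):
    return _elementwise(_sin_kernel, x)


@pallas_sin.defjvp
def pallas_sin_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    y = _elementwise(_sin_kernel, x)
    # Pallas only touches primal values here; the tangent enters through an
    # ordinary jnp multiply, which JAX can transpose for reverse mode.
    dydx = _elementwise(_cos_kernel, x)
    return y, dydx * t


x = jnp.linspace(-1.0, 1.0, 8, dtype=jnp.float32)
grad_x = jax.grad(lambda v: pallas_sin(v).sum())(x)
# grad_x should match jnp.cos(x) up to float32 rounding.
```

Keeping the tangent path in plain JAX is the design choice that matters: a JVP rule that pushed the tangent itself through a `pallas_call` would require JAX to transpose that call for `jax.grad`, which may not be supported.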
Reading this section in the docs, I would expect that calling `jax.grad` on a Pallas call would work, though potentially with a performance hit. However, using the following snippet:

results in an `AssertionError`:

What is happening here? Is this some issue with using `interpret=True`? Is the expected usage to define separate Pallas kernels for the forward and backward passes using `jax.custom_jvp`?
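For comparison, here is a minimal sketch of the "separate forward and backward kernels" approach using `jax.custom_vjp` rather than `jax.custom_jvp`. It assumes an elementwise `sin` kernel as a stand-in; `pallas_sin`, `_sin_kernel`, `_sin_bwd_kernel` and `_sin_pallas` are hypothetical names, and `interpret=True` is kept so it runs on CPU. The forward pass saves `x` as the residual, and the backward pass is its own Pallas kernel computing `g * cos(x)`. This sidesteps differentiating through `pallas_call` entirely; it does not explain the `AssertionError` itself.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _sin_kernel(x_ref, o_ref):
    # Hypothetical forward kernel: y = sin(x), elementwise.
    o_ref[...] = jnp.sin(x_ref[...])


def _sin_bwd_kernel(x_ref, g_ref, o_ref):
    # Hypothetical backward kernel: dL/dx = dL/dy * cos(x).
    o_ref[...] = g_ref[...] * jnp.cos(x_ref[...])


def _sin_pallas(x):
    return pl.pallas_call(
        _sin_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # run via the Pallas interpreter
    )(x)


@jax.custom_vjp
def pallas_sin(x):
    return _sin_pallas(x)


def pallas_sin_fwd(x):
    # Return the primal output plus x as the residual for the backward pass.
    return _sin_pallas(x), x


def pallas_sin_bwd(x, g):
    # g is the cotangent of the output; one cotangent per primal input.
    dx = pl.pallas_call(
        _sin_bwd_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,
    )(x, g)
    return (dx,)


pallas_sin.defvjp(pallas_sin_fwd, pallas_sin_bwd)

x = jnp.linspace(-1.0, 1.0, 8, dtype=jnp.float32)
grad_x = jax.grad(lambda v: pallas_sin(v).sum())(x)
# grad_x should match jnp.cos(x) up to float32 rounding.
```

With `custom_vjp`, reverse mode works, but forward-mode calls such as `jax.jvp` on `pallas_sin` are not supported; `custom_jvp` is the option if both modes are needed.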