Can Pallas Auto Diff? #19184
Unanswered
karan-dalal asked this question in General
Replies: 0 comments
The Pallas docs state that JAX transformations should work inside a Pallas kernel, and illustrate this with an example that uses jax.grad.
I'm wondering whether it would be possible to differentiate an entire model inside a kernel, since that would reduce to matrix multiplications and other operations that could be executed on Mosaic / Triton.
For example, could I run the forward pass and backprop through a 2-layer MLP with a non-linearity inside a Pallas kernel?
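To make the question concrete, a minimal sketch of the idea might look like the following. It is not taken from the Pallas docs: the names `mlp_loss`, `mlp_grad_kernel`, and `mlp_loss_and_grads` are hypothetical, the shapes are arbitrary, and the call uses `interpret=True` so the kernel is evaluated by the Pallas interpreter rather than compiled. Whether the same kernel lowers on Mosaic or Triton (for example, whether `tanh` and these matmul shapes are supported there) is exactly the open part of the question.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def mlp_loss(w1, w2, x, y):
    # Two-layer MLP with a tanh non-linearity and a mean-squared-error loss.
    h = jnp.tanh(jnp.dot(x, w1))
    pred = jnp.dot(h, w2)
    return jnp.mean((pred - y) ** 2)


def mlp_grad_kernel(w1_ref, w2_ref, x_ref, y_ref, loss_ref, dw1_ref, dw2_ref):
    # No grid is used, so each ref covers the whole array; load everything.
    w1, w2, x, y = w1_ref[...], w2_ref[...], x_ref[...], y_ref[...]
    # Forward pass and backprop are both traced inside the kernel body.
    loss, (dw1, dw2) = jax.value_and_grad(mlp_loss, argnums=(0, 1))(w1, w2, x, y)
    loss_ref[...] = jnp.reshape(loss, (1, 1))
    dw1_ref[...] = dw1
    dw2_ref[...] = dw2


def mlp_loss_and_grads(w1, w2, x, y, interpret=True):
    # Assumption: interpret=True, so no GPU/TPU lowering is exercised here.
    out_shape = (
        jax.ShapeDtypeStruct((1, 1), x.dtype),      # loss, stored as a 1x1 array
        jax.ShapeDtypeStruct(w1.shape, w1.dtype),   # gradient w.r.t. w1
        jax.ShapeDtypeStruct(w2.shape, w2.dtype),   # gradient w.r.t. w2
    )
    return pl.pallas_call(
        mlp_grad_kernel, out_shape=out_shape, interpret=interpret
    )(w1, w2, x, y)


# Arbitrary illustrative inputs.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (16, 32))
y = jax.random.normal(k2, (16, 16))
w1 = jax.random.normal(k3, (32, 64)) * 0.1
w2 = jax.random.normal(k4, (64, 16)) * 0.1

loss, dw1, dw2 = mlp_loss_and_grads(w1, w2, x, y)
```

The results can be compared against `jax.value_and_grad(mlp_loss, argnums=(0, 1))` called outside the kernel. This only shows that the differentiation traces inside the kernel body under the interpreter, not that it compiles for a real backend.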