How to solve the TypeError("primal and tangent arguments to jax.jvp do not match") when one of the primals has dtype int32? #17321
Unanswered
AndrewLiu0725 asked this question in Q&A
Replies: 1 comment
-
This looks similar to the issue addressed in #14570, but I'm not sure of the easiest way to do what you have in mind. I think passing |
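The reply is cut off above, but given the reference to #14570, one resolution consistent with the question is to pass a tangent of the special `float0` dtype for the integer primal instead of an ordinary float array. The sketch below is illustrative (the loss function and shapes are assumptions, not the asker's actual model), not necessarily what the replier had in mind:

```python
import numpy as np
import jax
import jax.numpy as jnp

def loss(params, tokens):
    # Gather parameter entries by integer index, then a simple squared loss.
    return jnp.sum(params[tokens] ** 2)

params = jnp.ones(4)
tokens = jnp.array([0, 1, 2], dtype=jnp.int32)

# Integer primals carry no derivative information, so JAX requires their
# tangent to have dtype float0 rather than being a float32 array of zeros.
zero_tan = np.zeros(tokens.shape, dtype=jax.dtypes.float0)

primal_out, tangent_out = jax.jvp(
    loss, (params, tokens), (jnp.ones(4), zero_tan)
)
```

With a `float0` tangent in the integer slot, `jax.jvp` differentiates only with respect to the float arguments, which matches the asker's goal of taking derivatives with respect to the model parameters alone.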
-
I have a neural network that takes discretized inputs, specifically an array of integers. When I attempted to use custom_jvp to define a custom gradient (with respect to the model parameters) for my loss function, I encountered the TypeError quoted in the title.
This error arises because one of the primals has dtype int32, and I am unsure how to resolve the mismatch. Since I only take derivatives with respect to the model parameters, should I manually generate a float0 array matching the shape of that primal and use it as the tangent? Any assistance is greatly appreciated.
Below is a minimal reproducible code that triggers the error:
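The original snippet did not survive extraction. Below is a hedged reconstruction of the class of failure being described (the loss function and shapes are illustrative, not the asker's code): supplying an ordinary float tangent for an int32 primal triggers the error.

```python
import jax
import jax.numpy as jnp

def loss(params, tokens):
    # Gather parameter entries by integer index, then a simple squared loss.
    return jnp.sum(params[tokens] ** 2)

params = jnp.ones(4)
tokens = jnp.array([0, 1, 2], dtype=jnp.int32)

try:
    # A float32 tangent for the int32 primal: dtypes do not match.
    jax.jvp(loss, (params, tokens), (jnp.ones(4), jnp.zeros(3)))
    raised = False
except TypeError as e:
    # "primal and tangent arguments to jax.jvp do not match; ..."
    raised = True
```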