KeyError: dtype([('float0', 'V')]) grad + odeint in JAX 0.3.3 #9951
-
Hi all, I've just upgraded to JAX 0.3.3 and am getting an error I've never seen before, in code that ran correctly in previous versions. I don't have a minimal example yet, but I wanted to check here in case anything rings a bell, because the error message is pretty uninformative (to me). The exception being raised is extremely long, but it is coming up when trying to take the
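For context: the dtype in the error message, `dtype([('float0', 'V')])`, is JAX's `float0`. This is the trivial tangent dtype JAX uses for integer- and boolean-valued inputs. The following is a minimal sketch showing where that dtype appears, assuming a recent JAX release. It does not reproduce the odeint error, and the function name `scale` is hypothetical.

```python
import jax

def scale(x, n):
    # n is meant to be integer-valued
    return x * n

# jax.grad rejects integer inputs unless allow_int=True. With it, the
# "gradient" for the integer argument comes back with dtype float0.
g = jax.grad(scale, argnums=1, allow_int=True)(2.0, 3)
print(g.dtype == jax.dtypes.float0)  # True
```

This may explain how integer-valued `*args` can end up producing float0 values during odeint's backward pass, as a later reply in this thread suggests.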
-
Thanks for raising this. I have a guess... but is there a chance you could share a runnable repro so I can test my guess?
-
Has there been a significant change to the odeint method (or to anything that directly affects how it runs) between v0.3.1 and v0.3.3? Some of my code now runs 50% slower.
-
(Starting a new unified top-level thread.) FYI @oracle3001 @DanPuzzuoli we pushed a new PyPI release, jax==0.3.4, with the #9959 fix merged in.
-
I'm getting a very similar error message, and this discussion was the only hit on Google when I searched. However, I have jax 0.3.13 installed, which already includes the fix discussed above. For anyone else running into this: the problem is caused by NumPy 1.23.0 and can be worked around by downgrading NumPy. See #11221.
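Several replies hinge on exact versions (jax 0.3.4 and 0.3.13, NumPy 1.22 vs 1.23.0), so it can help to paste the installed versions alongside a repro. Below is a small stdlib-only sketch using `importlib.metadata`. It was not posted in the thread, and the helper name `installed_versions` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_versions(packages=("jax", "jaxlib", "numpy")):
    """Return {package: version string}, or None for packages not installed."""
    found = {}
    for pkg in packages:
        try:
            found[pkg] = version(pkg)
        except PackageNotFoundError:
            found[pkg] = None
    return found

for pkg, ver in installed_versions().items():
    print(f"{pkg}: {ver}")
```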
-
Hello everyone,
The output is the following:
I am using
-
Hi @pragati903, how did you solve the problem? I'm having the same issue with jax 0.4.13 and numpy 1.23, and downgrading numpy to <= 1.22 didn't fix it. What's odd is that my code ran fine until just a few months ago.
-
Hello, I am getting this error again on Python 3.10, JAX 0.4.28:
In my case, the issue arises when one of the `*args` passed to `func` in `odeint` is an integer. For some reason, this causes the error above when backpropagating through `odeint`. I can work around it by passing the argument as a float and then casting it to an int inside `func` using
Here is the docstring for odeint for reference:
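The workaround described above (pass the integer as a float, then cast it back to an int inside `func`) can be sketched as follows. This is an illustration, not the poster's code. The decay dynamics and the names `final_state`, `rates`, and `idx_float` are hypothetical. The sketch assumes `jax.experimental.ode.odeint` with its `func(y, t, *args)` calling convention.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def func(y, t, rates, idx_float):
    # The integer index arrives as a float (so every *args entry is
    # floating point) and is cast back to an int here, where it is used.
    idx = idx_float.astype(jnp.int32)
    return -rates[idx] * y  # exponential decay at the selected rate

def final_state(rates, idx):
    y0 = jnp.array(1.0)
    ts = jnp.linspace(0.0, 1.0, 5)
    # Pass the integer to odeint as a float32 array instead of an int.
    ys = odeint(func, y0, ts, rates, jnp.asarray(idx, dtype=jnp.float32))
    return ys[-1]

rates = jnp.array([0.5, 2.0])
grads = jax.grad(final_state)(rates, 1)
# y(1) = exp(-rates[1]), so grads should be close to [0, -exp(-2)] ≈ [0, -0.1353]
print(grads)
```

The cast happens inside `func`, so no gradient flows into the index. That matches its role as a plain integer selector.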