types changing in while_loop(), need help! #11049
derekpowell asked this question in General
The error message says the input is:

```
('ShapedArray(float32[11])', # summand
 'DIFFERENT ShapedArray(float32[11]) vs. ShapedArray(float32[])', # sum_numer
 'DIFFERENT ShapedArray(float32[11]) vs. ShapedArray(float32[])', # sum_denom
 'ShapedArray(float32[])', # a_plus_b
 'ShapedArray(int32[], weak_type=True)', # k
 'ShapedArray(float32[])', # a_plus_1
 'ShapedArray(float32[])', # digamma_ab
 'ShapedArray(float32[])', # digamma_a
 'ShapedArray(float32[11])') # x
```

Maybe you could also share the code that produces this error?
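In the listing, `sum_numer` and `sum_denom` appear as `float32[11]` on one side and `float32[]` (a scalar) on the other. That pattern typically shows up when an accumulator is initialised with a scalar such as `0.0` but a length-11 `summand` is added to it inside `body_fun`. Below is a minimal sketch of that failure and one way to fix it. `sum_series_broken`, `sum_series_fixed`, `make_body`, and the fixed 10-term loop are hypothetical stand-ins, not the code from the question.

```python
import jax
import jax.numpy as jnp


def cond_fun(carry):
    _, k = carry
    return k < 10  # fixed number of terms, just for illustration


def make_body(x):
    def body_fun(carry):
        total, k = carry
        summand = x ** k           # same shape as x, e.g. (11,)
        return total + summand, k + 1
    return body_fun


def sum_series_broken(x):
    # 0.0 is a scalar: the accumulator enters the loop as float32[] but leaves
    # with x's shape, so tracing fails whenever x is a vector.
    return jax.lax.while_loop(cond_fun, make_body(x), (0.0, 0))


def sum_series_fixed(x):
    # Give the accumulator x's shape and dtype from the start, so every
    # iteration returns a carry with exactly the types it received.
    return jax.lax.while_loop(cond_fun, make_body(x), (jnp.zeros_like(x), 0))


x = jnp.linspace(0.1, 0.5, 11)
total, k = sum_series_fixed(x)
print("fixed:", total.shape, int(k))
```

With a vector `x`, `sum_series_broken(x)` raises an error of the same kind as the one above. With a scalar `x`, it runs, because the carry keeps shape `()`. This is one plausible reason code like this works in some contexts but not others (for example, if it ends up being called with a whole data vector inside a sampler). Initialising accumulators from the data, e.g. with `jnp.zeros_like(x)`, keeps the carry shape fixed across iterations.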
I'm fumbling my way through implementing a custom JVP for the regularized incomplete beta function by porting some code from the Stan project (this and this). The functions use a while loop, which I am trying to implement with `jax.lax.while_loop()`. My implementations seem to work in some contexts but not others. When I try to run this inside a numpyro MCMC sampling context, I get the following error:

I gather I am somehow changing the types inside my `body_fun` (`_betainc_dda_while` below), but I can't figure out where. I'm hoping for some help debugging this, or even just reading the error message; I can't see where the issue is arising.

Here's my code:

Appreciate any help!
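When reading an error like this, it can help to check the carry before calling `jax.lax.while_loop()`. You can trace `body_fun` once with `jax.eval_shape` and compare each output leaf with the corresponding input leaf, which shows the position of every carry entry whose shape or dtype changes. The sketch below assumes a tuple carry like the one in the error message. `check_carry_types` is a hypothetical helper, not part of JAX. It deliberately ignores `weak_type` differences (which typically come from Python-scalar initial values like `0`) and reports only shape and dtype changes.

```python
import jax
import jax.numpy as jnp


def check_carry_types(body_fun, init_carry):
    """Hypothetical helper: list carry leaves whose shape or dtype body_fun changes."""
    # Abstract values (ShapeDtypeStruct) for the carry going into body_fun ...
    in_avals = jax.eval_shape(lambda carry: carry, init_carry)
    # ... and for the carry coming out of one iteration of body_fun.
    out_avals = jax.eval_shape(body_fun, init_carry)

    in_leaves, in_tree = jax.tree_util.tree_flatten(in_avals)
    out_leaves, out_tree = jax.tree_util.tree_flatten(out_avals)
    if in_tree != out_tree:
        return [f"carry structure changed: {in_tree} -> {out_tree}"]

    problems = []
    for i, (a, b) in enumerate(zip(in_leaves, out_leaves)):
        # weak_type differences are ignored; only shape and dtype are compared.
        if a.shape != b.shape or a.dtype != b.dtype:
            problems.append(
                f"carry[{i}]: input {a.dtype}{list(a.shape)} -> output {b.dtype}{list(b.shape)}"
            )
    return problems


# Usage with a hypothetical body_fun that adds a length-11 summand to a scalar accumulator.
x = jnp.linspace(0.1, 0.5, 11)


def body_fun(carry):
    total, k = carry
    return total + x ** k, k + 1


print(check_carry_types(body_fun, (0.0, 0)))                # reports carry[0]
print(check_carry_types(body_fun, (jnp.zeros_like(x), 0)))  # []
```

Running the helper on the real `_betainc_dda_while` with the same initial tuple should point at the entries flagged `DIFFERENT` in the error (positions 1 and 2 there, `sum_numer` and `sum_denom`). The fix is then to initialise those entries with the shape they have after one iteration.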