
How to do value assertion gracefully in JAX? #28439

Answered by jakevdp
HeavyCrab asked this question in Q&A

Unfortunately, there's no supported way to do runtime value assertions in JAX. One main reason for this is that some accelerators supported by JAX literally have no path for raising errors at runtime.
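
To see why a plain Python check is not enough, consider the minimal sketch below. `safe_sqrt` and `jit_assert_fails` are hypothetical helpers written only for this illustration. Outside of `jit` the `assert` sees concrete values and works. Under `jit` the comparison is an abstract tracer, so turning it into a Python bool fails while the function is being traced, before any value exists to check.

```python
import jax
import jax.numpy as jnp


def safe_sqrt(x):
    # Hypothetical helper: a plain Python assert needs a concrete boolean,
    # but under jit `x` is an abstract tracer with no value yet.
    assert jnp.all(x >= 0), "x must be non-negative"
    return jnp.sqrt(x)


def jit_assert_fails(x):
    # Hypothetical helper: returns True if tracing safe_sqrt under jit fails
    # because the assert tried to convert a traced value to a Python bool.
    try:
        jax.jit(safe_sqrt)(x)
    except jax.errors.ConcretizationTypeError:
        return True
    return False


x = jnp.array([4.0, 9.0])
print(safe_sqrt(x))         # eager mode: the assert runs on concrete values
print(jit_assert_fails(x))  # jit mode: the assert cannot be evaluated
```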

You can do various imperfect workarounds using callbacks, but those are not recommended. A better approach is to modify your programming style to avoid the need for runtime value assertions. It takes some getting used to, but we find it works pretty well in practice.
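
One possible way to "modify your programming style" is sketched below. This is an illustration of the general idea, not a pattern prescribed in this thread. Instead of asserting, the jitted function returns a validity flag with its result and marks invalid outputs as NaN. The caller then inspects the flag on concrete values outside of `jit`. `safe_log` is a hypothetical name. The inner `jnp.where` substitutes a harmless input so `jnp.log` never sees a non-positive value, which also keeps gradients finite.

```python
import jax
import jax.numpy as jnp


@jax.jit
def safe_log(x):
    # Hypothetical example: compute a validity mask instead of asserting.
    valid = x > 0
    # Replace invalid entries with 1.0 so log is only applied to safe inputs.
    safe_x = jnp.where(valid, x, 1.0)
    # Mark invalid positions as NaN so they are visible in the output.
    out = jnp.where(valid, jnp.log(safe_x), jnp.nan)
    # Return a scalar flag the caller can check after the compiled call.
    return out, jnp.all(valid)


out, ok = safe_log(jnp.array([1.0, -2.0, jnp.e]))
# The decision to fail happens outside jit, on concrete values.
if not ok:
    print("invalid input detected:", out)
```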

Another option is to use the jax.experimental.checkify module, which allows you to propagate value-based errors using a functional approach supported by JAX. It's still somewhat rough around the edges, but there's some ongoing work t…
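
A minimal `checkify` sketch follows. The function name `safe_sqrt` is hypothetical, and because the module is experimental, details may differ between JAX versions. `checkify.check` records a functional error rather than raising. `checkify.checkify` transforms the function to return an `(error, output)` pair, and that pair composes with `jit`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_sqrt(x):
    # Hypothetical example: record an error value instead of raising.
    checkify.check(jnp.all(x >= 0), "x must be non-negative")
    return jnp.sqrt(x)


# The checkified function returns (error, output) and can be jitted.
checked_sqrt = jax.jit(checkify.checkify(safe_sqrt))

err, out = checked_sqrt(jnp.array([4.0, 9.0]))
print(err.get())  # None when no check failed

err, out = checked_sqrt(jnp.array([4.0, -9.0]))
print(err.get())  # a message describing the failed check
# err.throw() would raise this error on the host, outside the compiled code.
```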

Replies: 1 comment 3 replies

@patrick-kidger

@jakevdp

@patrick-kidger

Answer selected by HeavyCrab
Category
Q&A
Labels
None yet
3 participants