
Does JAX come with XLA binaries that I can use? #18496

Answered by hawkinsp
hayden-donnelly asked this question in Q&A

What you want is buried inside xla_extension.so, but most likely not behind an interface that is useful to you, since that module exposes a Python API.

There is a better way: the XLA-PJRT runtime that JAX uses has a C API, documented here: https://github.com/openxla/xla/blob/main/xla/pjrt/c/README.md

Some JAX plugins already use this API (Google TPU and Apple Metal, to name two). We plan to update JAX's NVIDIA GPU plugin to use this API shortly, and at that point JAX will provide a build of that plugin, exposing this API, as a .so file.
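As a rough illustration of what consuming such a plugin .so could look like, here is a minimal C sketch. It assumes, as we understand the PJRT C API documentation, that a plugin exports a function named GetPjrtApi returning a pointer to a PJRT_Api struct (declared in xla/pjrt/c/pjrt_c_api.h). To stay self-contained, the sketch treats that struct as opaque; the helper name load_pjrt_api and the plugin path are hypothetical.

```c
#include <assert.h>
#include <dlfcn.h>
#include <stdio.h>

/* Assumption: PJRT plugins export `GetPjrtApi`, which returns a pointer to a
   PJRT_Api struct from pjrt_c_api.h. Treated as opaque here to avoid
   depending on that header. */
typedef const void *(*GetPjrtApiFn)(void);

/* Hypothetical helper: opens a PJRT plugin shared library and returns its
   PJRT_Api pointer, or NULL if the library cannot be opened or does not
   export GetPjrtApi. On success, *handle_out receives the dlopen handle;
   the caller should dlclose it only after it has finished using the API. */
const void *load_pjrt_api(const char *plugin_path, void **handle_out) {
  assert(plugin_path != NULL && handle_out != NULL);
  *handle_out = NULL;

  void *handle = dlopen(plugin_path, RTLD_NOW | RTLD_LOCAL);
  if (handle == NULL) {
    fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return NULL;
  }

  /* Converting dlsym's void* result to a function pointer is the usual
     POSIX idiom. */
  GetPjrtApiFn get_api = (GetPjrtApiFn)dlsym(handle, "GetPjrtApi");
  if (get_api == NULL) {
    fprintf(stderr, "GetPjrtApi not found: %s\n", dlerror());
    dlclose(handle);
    return NULL;
  }

  *handle_out = handle;
  return get_api();
}
```

A real caller would include pjrt_c_api.h, cast the returned pointer to const PJRT_Api *, and check the API version the plugin reports before calling into it. On older glibc versions, linking may require -ldl.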

Answer selected by hayden-donnelly