Skip to content

Commit 783d5d2

Browse files
Jieying Luojax authors
authored andcommitted
[PJRT C API] Plumb plugin attributes from plugin to JAX python.
Also add a method for the plugin to return an xla_version plugin attribute. Currently jaxlib pins a TPU/GPU backend, and uses `xla_extension_version` for backend version. As we want to stop pinning TPU/GPU backend and allow pip install different backend separately, we need this `xla_version` for features that are not capture by PJRT C API version. `xla_extension_version` will still be used for API changes such as xla_client.py, or any XLA changes in jaxlib that are not part of plugins. PiperOrigin-RevId: 621672421
1 parent 92326db commit 783d5d2

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

jax/_src/xla_bridge.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,18 @@ def backend_pjrt_c_api_version(platform=None) -> tuple[int, int] | None:
10651065
return None
10661066

10671067

1068+
def backend_xla_version(platform=None) -> int | None:
1069+
"""Returns the XLA version of the backend.
1070+
1071+
Returns None if the backend does not use PJRT C API or does not have
1072+
xla_version in the plugin attributes. This methon can be used to skip features
1073+
that are not available before certain xla_version if the backend is a
1074+
plugin and uses xla_version.
1075+
"""
1076+
backend = get_backend(platform)
1077+
return getattr(backend, "xla_version", None)
1078+
1079+
10681080
@lru_cache
10691081
def local_devices(process_index: int | None = None,
10701082
backend: str | xla_client.Client | None = None,

0 commit comments

Comments
 (0)