Skip to content

Commit bc363de

Browse files
Micky774jax authors
authored andcommitted
Copybara import of the project:
-- ac2c522 by Meekail Zain <zainmeekail@gmail.com>: [FIX] Added jaxlib version guard for CUDA compute capability check COPYBARA_INTEGRATE_REVIEW=#20237 from Micky774:add_version_guard ac2c522 PiperOrigin-RevId: 616925918
1 parent ee2631e commit bc363de

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

jax/_src/xla_bridge.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from jax._src.lib import xla_client
4747
from jax._src.lib import xla_extension
4848
from jax._src.lib import xla_extension_version
49+
from jax._src.lib import jaxlib
4950

5051
logger = logging.getLogger(__name__)
5152

@@ -333,8 +334,11 @@ def make_gpu_client(
333334
)
334335
if platform_name == "cuda":
335336
_check_cuda_versions()
336-
devices_to_check = allowed_devices if allowed_devices else range(cuda_versions.cuda_device_count())
337-
_check_cuda_compute_capability(devices_to_check)
337+
# TODO(micky774): remove this check when minimum jaxlib is v0.4.26
338+
if jaxlib.version.__version_info__ >= (0, 4, 26):
339+
devices_to_check = (allowed_devices if allowed_devices else
340+
range(cuda_versions.cuda_device_count()))
341+
_check_cuda_compute_capability(devices_to_check)
338342

339343
return xla_client.make_gpu_client(
340344
distributed_client=distributed.global_state.client,

0 commit comments

Comments
 (0)