Commit bed4f65 (1 parent c82deb2)
Author: jax authors

Remove support for CUDA 11.

Pin minimal required versions for CUDA to 12.1.

PiperOrigin-RevId: 618195554

14 files changed: +194 −161 lines

.bazelrc (0 additions, 24 deletions)

````diff
@@ -228,30 +228,6 @@ build:rbe_linux_cuda_base --config=rbe_linux
 build:rbe_linux_cuda_base --config=cuda
 build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
 
-build:rbe_linux_cuda11.8_nvcc_base --config=rbe_linux_cuda_base
-build:rbe_linux_cuda11.8_nvcc_base --config=cuda_clang
-build:rbe_linux_cuda11.8_nvcc_base --action_env=TF_NVCC_CLANG="1"
-build:rbe_linux_cuda11.8_nvcc_base --action_env=TF_CUDA_VERSION=11
-build:rbe_linux_cuda11.8_nvcc_base --action_env=TF_CUDNN_VERSION=8
-build:rbe_linux_cuda11.8_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.8"
-build:rbe_linux_cuda11.8_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
-build:rbe_linux_cuda11.8_nvcc_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda//crosstool:toolchain"
-build:rbe_linux_cuda11.8_nvcc_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda//crosstool:toolchain"
-build:rbe_linux_cuda11.8_nvcc_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda//crosstool:toolchain-linux-x86_64"
-build:rbe_linux_cuda11.8_nvcc_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_platform//:platform"
-build:rbe_linux_cuda11.8_nvcc_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_platform//:platform"
-build:rbe_linux_cuda11.8_nvcc_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_platform//:platform"
-build:rbe_linux_cuda11.8_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda"
-build:rbe_linux_cuda11.8_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_nccl"
-build:rbe_linux_cuda11.8_nvcc_py3.9 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.9"
-build:rbe_linux_cuda11.8_nvcc_py3.9 --python_path="/usr/local/bin/python3.9"
-build:rbe_linux_cuda11.8_nvcc_py3.10 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.10"
-build:rbe_linux_cuda11.8_nvcc_py3.10 --python_path="/usr/local/bin/python3.10"
-build:rbe_linux_cuda11.8_nvcc_py3.11 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.11"
-build:rbe_linux_cuda11.8_nvcc_py3.11 --python_path="/usr/local/bin/python3.11"
-build:rbe_linux_cuda11.8_nvcc_py3.12 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.12"
-build:rbe_linux_cuda11.8_nvcc_py3.12 --python_path="/usr/local/bin/python3.12"
-
 build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base
 build:rbe_linux_cuda12.3_nvcc_base --config=cuda_clang
 build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_NVCC_CLANG="1"
````

CHANGELOG.md (4 additions, 0 deletions)

````diff
@@ -32,6 +32,10 @@ Remember to align the itemized text with the first line of an item within a list
 
 ## jaxlib 0.4.26
 
+* Changes
+  * JAX now supports CUDA 12.1 or newer only. Support for CUDA 11.8 has been
+    dropped.
+
 ## jax 0.4.25 (Feb 26, 2024)
 
 * New Features
````

docs/installation.md (11 additions, 18 deletions)

````diff
@@ -61,7 +61,7 @@ NVIDIA has dropped support for Kepler GPUs in its software.
 
 You must first install the NVIDIA driver. We
 recommend installing the newest driver available from NVIDIA, but the driver
-must be version >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux.
+must be version >= 525.60.13 for CUDA 12 on Linux.
 If you need to use a newer CUDA toolkit with an older driver, for example
 on a cluster where you cannot update the NVIDIA driver easily, you may be
 able to use the
@@ -82,10 +82,6 @@ pip install --upgrade pip
 # CUDA 12 installation
 # Note: wheels only available on linux.
 pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-
-# CUDA 11 installation
-# Note: wheels only available on linux.
-pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 ```
 
 If JAX detects the wrong version of the CUDA libraries, there are several things
@@ -113,14 +109,19 @@ able to use the
 [CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
 that NVIDIA provides for this purpose.
 
-JAX currently ships two CUDA wheel variants:
-* CUDA 12.3, cuDNN 8.9, NCCL 2.16
-* CUDA 11.8, cuDNN 8.6, NCCL 2.16
+JAX currently ships one CUDA wheel variant:
+
+| Built with | Compatible with |
+|------------|-----------------|
+| CUDA 12.3  | CUDA 12.1+      |
+| cuDNN 8.9  | cuDNN 8.9+      |
+| NCCL 2.19  | NCCL 2.18+      |
 
-You may use a JAX wheel provided the major version of your CUDA, cuDNN, and NCCL
-installations match, and the minor versions are the same or newer.
 JAX checks the versions of your libraries, and will report an error if they are
 not sufficiently new.
+Setting the `JAX_SKIP_CUDA_CONSTRAINTS_CHECK` environment variable will disable
+the check, but using older versions of CUDA may lead to errors, or incorrect
+results.
 
 NCCL is an optional dependency, required only if you are performing multi-GPU
 computations.
@@ -134,9 +135,6 @@ pip install --upgrade pip
 # Note: wheels only available on linux.
 pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 
-# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
-# Note: wheels only available on linux.
-pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 ```
 
 **These `pip` installations do not work with Windows, and may fail silently; see
@@ -188,11 +186,6 @@ pip install -U --pre libtpu-nightly -f https://storage.googleapis.com/jax-releas
 pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
 ```
 
-* Jaxlib GPU (Cuda 11):
-```bash
-pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html
-```
-
 ## Google TPU
 
 ### pip installation: Google Cloud TPU
````
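The compatibility table above reduces to an integer comparison: NVIDIA encodes CUDA runtime versions as `major * 1000 + minor * 10` (so 12.1 is 12010, the minimum this commit pins). A minimal sketch of that constraint; the helper name here is hypothetical, not a JAX API:

```python
# Hypothetical sketch of the documented minimum-version constraint.
# CUDA encodes versions as integers: major * 1000 + minor * 10 (12.1 -> 12010).

MIN_CUDA_RUNTIME = 12010  # CUDA 12.1, the documented minimum


def cuda_version_ok(installed: int, minimum: int = MIN_CUDA_RUNTIME) -> bool:
    """Return True if the installed CUDA runtime satisfies the minimum."""
    return installed >= minimum


print(cuda_version_ok(12030))  # CUDA 12.3 -> True
print(cuda_version_ok(11080))  # CUDA 11.8, no longer supported -> False
```

Because the encoding is monotonic in (major, minor), a single integer `>=` captures the "12.1 or newer" rule.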

docs/tutorials/installation.md (11 additions, 21 deletions)

````diff
@@ -72,7 +72,7 @@ NVIDIA has dropped support for Kepler GPUs in its software.
 
 You must first install the NVIDIA driver. You're
 recommended to install the newest driver available from NVIDIA, but the driver
-version must be >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux.
+version must be >= 525.60.13 for CUDA 12 on Linux.
 
 If you need to use a newer CUDA toolkit with an older driver, for example
 on a cluster where you cannot update the NVIDIA driver easily, you may be
@@ -99,10 +99,6 @@ pip install --upgrade pip
 # NVIDIA CUDA 12 installation
 # Note: wheels only available on linux.
 pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-
-# NVIDIA CUDA 11 installation
-# Note: wheels only available on linux.
-pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 ```
 
 If JAX detects the wrong version of the NVIDIA CUDA libraries, there are several things
@@ -131,15 +127,19 @@ able to use the
 [CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
 that NVIDIA provides for this purpose.
 
-JAX currently ships two NVIDIA CUDA wheel variants:
+JAX currently ships one CUDA wheel variant:
 
-- CUDA 12.2, cuDNN 8.9, NCCL 2.16
-- CUDA 11.8, cuDNN 8.6, NCCL 2.16
+| Built with | Compatible with |
+|------------|-----------------|
+| CUDA 12.3  | CUDA 12.1+      |
+| cuDNN 8.9  | cuDNN 8.9+      |
+| NCCL 2.19  | NCCL 2.18+      |
 
-You may use a JAX wheel provided the major version of your CUDA, cuDNN, and NCCL
-installations match, and the minor versions are the same or newer.
 JAX checks the versions of your libraries, and will report an error if they are
 not sufficiently new.
+Setting the `JAX_SKIP_CUDA_CONSTRAINTS_CHECK` environment variable will disable
+the check, but using older versions of CUDA may lead to errors, or incorrect
+results.
 
 NCCL is an optional dependency, required only if you are performing multi-GPU
 computations.
@@ -152,10 +152,6 @@ pip install --upgrade pip
 # Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 8.9 or newer.
 # Note: wheels only available on linux.
 pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-
-# Installs the wheel compatible with NVIDIA CUDA 11 and cuDNN 8.6 or newer.
-# Note: wheels only available on linux.
-pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 ```
 
 **These `pip` installations do not work with Windows, and may fail silently; refer to the table
@@ -212,12 +208,6 @@ pip install -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/lib
 pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
 ```
 
-- `jaxlib` NVIDIA GPU (CUDA 11):
-
-  ```bash
-  pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html
-  ```
-
 (install-google-tpu)=
 ## Google Cloud TPU
 
@@ -318,4 +308,4 @@ pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_re
 For specific older GPU wheels, be sure to use the `jax_cuda_releases.html` URL; for example
 ```bash
 pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-```
+```
````

jax/_src/xla_bridge.py (112 additions, 17 deletions)

````diff
@@ -31,6 +31,7 @@
 import os
 import pkgutil
 import platform as py_platform
+import traceback
 import sys
 import threading
 from typing import Any, Callable, Union
@@ -267,33 +268,101 @@ def _check_cuda_compute_capability(devices_to_check):
       RuntimeWarning
     )
 
-def _check_cuda_versions():
+
+def _check_cuda_versions(raise_on_first_error: bool = False,
+                         debug: bool = False):
   assert cuda_versions is not None
+  results: list[dict[str, Any]] = []
+
+  def _make_msg(name: str,
+                runtime_version: int,
+                build_version: int,
+                min_supported: int,
+                debug_msg: bool = False):
+    if debug_msg:
+      return (f"Package: {name}\n"
+              f"Version JAX was built against: {build_version}\n"
+              f"Minimum supported: {min_supported}\n"
+              f"Installed version: {runtime_version}")
+    if min_supported:
+      req_str = (f"The local installation version must be no lower than "
+                 f"{min_supported}.")
+    else:
+      req_str = ("The local installation must be the same version as "
+                 "the version against which JAX was built.")
+    msg = (f"Outdated {name} installation found.\n"
+           f"Version JAX was built against: {build_version}\n"
+           f"Minimum supported: {min_supported}\n"
+           f"Installed version: {runtime_version}\n"
+           f"{req_str}")
+    return msg
+
+  def _version_check(name: str,
+                     get_version,
+                     get_build_version,
+                     scale_for_comparison: int = 1,
+                     min_supported_version: int = 0):
+    """Checks the runtime CUDA component version against the JAX one.
+
+    Args:
+      name: Of the CUDA component.
+      get_version: A function to get the local runtime version of the component.
+      get_build_version: A function to get the build version of the component.
+      scale_for_comparison: For rounding down a version to ignore patch/minor.
+      min_supported_version: An absolute minimum version required. Must be
+        passed without rounding down.
+
+    Raises:
+      RuntimeError: If the component is not found, or is of unsupported version,
+        and if raising the error is not deferred till later.
+    """
 
-  def _version_check(name, get_version, get_build_version,
-                     scale_for_comparison=1):
     build_version = get_build_version()
     try:
       version = get_version()
     except Exception as e:
-      raise RuntimeError(f"Unable to load {name}. Is it installed?") from e
-    if build_version // scale_for_comparison > version // scale_for_comparison:
-      raise RuntimeError(
-          f"Found {name} version {version}, but JAX was built against version "
-          f"{build_version}, which is newer. The copy of {name} that is "
-          "installed must be at least as new as the version against which JAX "
-          "was built."
-      )
+      err_msg = f"Unable to load {name}. Is it installed?"
+      if raise_on_first_error:
+        raise RuntimeError(err_msg) from e
+      err_msg += f"\n{traceback.format_exc()}"
+      results.append({"name": name, "installed": False, "msg": err_msg})
+      return
+
+    if not min_supported_version:
+      min_supported_version = build_version // scale_for_comparison
+    passed = min_supported_version <= version
+
+    if not passed or debug:
+      msg = _make_msg(name=name,
+                      runtime_version=version,
+                      build_version=build_version,
+                      min_supported=min_supported_version,
+                      debug_msg=passed)
+      if not passed and raise_on_first_error:
+        raise RuntimeError(msg)
+      else:
+        record = {"name": name,
+                  "installed": True,
+                  "msg": msg,
+                  "passed": passed,
+                  "build_version": build_version,
+                  "version": version,
+                  "minimum_supported": min_supported_version}
+        results.append(record)
 
   _version_check("CUDA", cuda_versions.cuda_runtime_get_version,
-                 cuda_versions.cuda_runtime_build_version)
+                 cuda_versions.cuda_runtime_build_version,
+                 scale_for_comparison=10,
+                 min_supported_version=12010)
   _version_check(
     "cuDNN",
     cuda_versions.cudnn_get_version,
     cuda_versions.cudnn_build_version,
     # NVIDIA promise both backwards and forwards compatibility for cuDNN patch
-    # versions: https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat
+    # versions:
+    # https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat
     scale_for_comparison=100,
+    min_supported_version=8900
   )
   _version_check("cuFFT", cuda_versions.cufft_get_version,
                  cuda_versions.cufft_build_version,
@@ -302,20 +371,42 @@ def _version_check(name, get_version, get_build_version,
   _version_check("cuSOLVER", cuda_versions.cusolver_get_version,
                  cuda_versions.cusolver_build_version,
                  # Ignore patch versions.
-                 scale_for_comparison=100)
+                 scale_for_comparison=100,
+                 min_supported_version=11400)
   _version_check("cuPTI", cuda_versions.cupti_get_version,
-                 cuda_versions.cupti_build_version)
+                 cuda_versions.cupti_build_version,
+                 min_supported_version=18)
   # TODO(jakevdp) remove these checks when minimum jaxlib is v0.4.21
   if hasattr(cuda_versions, "cublas_get_version"):
     _version_check("cuBLAS", cuda_versions.cublas_get_version,
                    cuda_versions.cublas_build_version,
                    # Ignore patch versions.
-                   scale_for_comparison=100)
+                   scale_for_comparison=100,
+                   min_supported_version=120100)
   if hasattr(cuda_versions, "cusparse_get_version"):
     _version_check("cuSPARSE", cuda_versions.cusparse_get_version,
                    cuda_versions.cusparse_build_version,
                    # Ignore patch versions.
-                   scale_for_comparison=100)
+                   scale_for_comparison=100,
+                   min_supported_version=12100)
+
+  errors = []
+  debug_results = []
+  for result in results:
+    message: str = result['msg']
+    if not result['installed'] or not result['passed']:
+      errors.append(message)
+    else:
+      debug_results.append(message)
+
+  join_str = f'\n{"-" * 50}\n'
+  if debug_results:
+    print(f'CUDA components status (debug):\n'
+          f'{join_str.join(debug_results)}')
+  if errors:
+    raise RuntimeError(f'Unable to use CUDA because of the '
+                       f'following issues with CUDA components:\n'
+                       f'{join_str.join(errors)}')
 
 
 def make_gpu_client(
@@ -335,6 +426,10 @@ def make_gpu_client(
   if platform_name == "cuda":
     if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"):
       _check_cuda_versions()
+    else:
+      print('Skipped CUDA versions constraints check due to the '
+            'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var being set.')
+
   # TODO(micky774): remove this check when minimum jaxlib is v0.4.26
   if jaxlib.version.__version_info__ >= (0, 4, 26):
     devices_to_check = (allowed_devices if allowed_devices else
````
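The reworked `_check_cuda_versions` above switches from fail-fast to collect-then-report: each component check appends a result record, and all failures are raised together in one `RuntimeError` at the end. A self-contained sketch of that pattern under simplified assumptions (the component tuples are illustrative, not read from a real CUDA runtime):

```python
# Minimal sketch of the collect-then-raise version checking above.
# Component names and versions below are illustrative placeholders.

def check_components(components, raise_on_first_error=False):
    """components: iterable of (name, installed_version, min_supported)."""
    errors = []
    for name, installed, min_supported in components:
        if installed < min_supported:
            msg = (f"Outdated {name} installation found.\n"
                   f"Minimum supported: {min_supported}\n"
                   f"Installed version: {installed}")
            if raise_on_first_error:
                raise RuntimeError(msg)  # old fail-fast behavior
            errors.append(msg)  # new behavior: defer and keep checking
    if errors:
        join_str = f'\n{"-" * 50}\n'
        raise RuntimeError("Unable to use CUDA because of the following "
                           "issues with CUDA components:\n"
                           + join_str.join(errors))


# Both outdated components are reported in a single error:
try:
    check_components([("CUDA", 11080, 12010), ("cuDNN", 8600, 8900)])
except RuntimeError as e:
    print(str(e).count("Outdated"))  # 2
```

Deferring the raise lets a user see every outdated component at once instead of fixing them one `RuntimeError` at a time.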

jax_plugins/cuda/__init__.py (1 addition, 1 deletion)

````diff
@@ -25,7 +25,7 @@
 
 # cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
 # preinstalled jax cuda plugin packages.
-for pkg_name in ['jax_cuda12_plugin', 'jax_cuda11_plugin', 'jaxlib']:
+for pkg_name in ['jax_cuda12_plugin', 'jaxlib']:
   try:
     cuda_plugin_extension = importlib.import_module(
         f'{pkg_name}.cuda_plugin_extension'
````
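The loop trimmed here (and in the two jaxlib files below) is a first-available-import fallback: candidate packages are tried in order and the first one that imports wins. A generic sketch of the pattern; `import_first_available` is a hypothetical helper, and the demo falls back to the stdlib `json` module since no plugin packages exist here:

```python
import importlib


def import_first_available(module_names, attr):
    """Return the named attribute from the first module that imports."""
    for module_name in module_names:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue  # candidate not installed; try the next one
        return getattr(module, attr)
    return None  # no candidate was importable


# Prefer a (missing) plugin package, fall back to the stdlib for illustration:
json_loads = import_first_available(['no_such_plugin', 'json'], 'loads')
print(json_loads('{"ok": true}'))  # {'ok': True}
```

Dropping `jax_cuda11_plugin` from the candidate list is all this commit needs; the fallback machinery itself is unchanged.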

jaxlib/gpu_linalg.py (1 addition, 1 deletion)

````diff
@@ -24,7 +24,7 @@
 
 from jaxlib import xla_client
 
-for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
+for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
   try:
     _cuda_linalg = importlib.import_module(
         f"{cuda_module_name}._linalg", package="jaxlib"
````

jaxlib/gpu_prng.py (1 addition, 1 deletion)

````diff
@@ -27,7 +27,7 @@
 from .hlo_helpers import custom_call
 from .gpu_common_utils import GpuLibNotLinkedError
 
-for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
+for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
   try:
     _cuda_prng = importlib.import_module(
         f"{cuda_module_name}._prng", package="jaxlib"
````
