Skip to content

Commit e83d9b5

Browse files
committed
Make GPU CUDA plugin require JAX
Some XLA GPU features require JAX. Rather than only installing the latest version of JAX in CI, we'll just make the CUDA plugin depend on a version of JAX that's the same as what's used by PyTorch/XLA on TPU. (Except the JAX CUDA wheels).
1 parent cba9ff9 commit e83d9b5

File tree

4 files changed

+137
-62
lines changed

4 files changed

+137
-62
lines changed

.github/workflows/_test_requiring_torch_cuda.yml

-8
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,6 @@ jobs:
8787
uses: actions/checkout@v4
8888
with:
8989
path: pytorch/xla
90-
- name: Extra CI deps
91-
shell: bash
92-
run: |
93-
set -x
94-
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
95-
pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html
96-
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
97-
if: ${{ matrix.run_triton_tests }}
9890
- name: Install Triton
9991
shell: bash
10092
run: |

build_util.py

+124-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,129 @@
33
import subprocess
44
import sys
55
import shutil
6+
from dataclasses import dataclass
7+
import functools
8+
9+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
10+
11+
12+
@functools.lru_cache
13+
def get_pinned_packages():
14+
"""Gets the versions of important pinned dependencies of torch_xla."""
15+
return PinnedPackages(
16+
use_nightly=True,
17+
date='20250320',
18+
raw_libtpu_version='0.0.12',
19+
raw_jax_version='0.5.4',
20+
raw_jaxlib_version='0.5.4',
21+
)
22+
23+
24+
@functools.lru_cache
25+
def get_build_version():
26+
xla_git_sha, _torch_git_sha = get_git_head_sha(BASE_DIR)
27+
version = os.getenv('TORCH_XLA_VERSION', '2.8.0')
28+
if check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'):
29+
try:
30+
version += '+git' + xla_git_sha[:7]
31+
except Exception:
32+
pass
33+
return version
34+
35+
36+
@functools.lru_cache
37+
def get_git_head_sha(base_dir):
38+
xla_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
39+
cwd=base_dir).decode('ascii').strip()
40+
if os.path.isdir(os.path.join(base_dir, '..', '.git')):
41+
torch_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
42+
cwd=os.path.join(
43+
base_dir,
44+
'..')).decode('ascii').strip()
45+
else:
46+
torch_git_sha = ''
47+
return xla_git_sha, torch_git_sha
48+
49+
50+
def get_jax_cuda_requirements():
51+
"""Get a list of JAX CUDA requirements for use in setup.py without extra package registries."""
52+
pinned_packages = get_pinned_packages()
53+
if not pinned_packages.use_nightly:
54+
# Stable versions of JAX can be directly installed from PyPI.
55+
return [
56+
f'jaxlib=={pinned_packages.jaxlib_version}',
57+
f'jax=={pinned_packages.jax_version}',
58+
f'jax[cuda12]=={pinned_packages.jax_version}',
59+
]
60+
61+
# Install nightly JAX libraries from the JAX package registries.
62+
jax = f'jax @ https://storage.googleapis.com/jax-releases/nightly/jax/jax-{pinned_packages.jax_version}-py3-none-any.whl'
63+
jaxlib = []
64+
for python_minor_version in [9, 10, 11]:
65+
jaxlib.append(
66+
f'jaxlib @ https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-{pinned_packages.jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"'
67+
)
68+
69+
# Install nightly JAX CUDA libraries.
70+
jax_cuda = [
71+
f'jax-cuda12-plugin @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_pjrt-{pinned_packages.jax_version}-py3-none-manylinux2014_x86_64.whl'
72+
]
73+
for python_minor_version in [9, 10, 11]:
74+
jax_cuda.append(
75+
f'jax-cuda12-pjrt @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_plugin-{pinned_packages.jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"'
76+
)
77+
78+
return [jax] + jaxlib + jax_cuda
79+
80+
81+
@dataclass(eq=True, frozen=True)
82+
class PinnedPackages:
83+
use_nightly: bool
84+
"""Whether to use nightly or stable libtpu and JAX"""
85+
86+
date: str
87+
raw_libtpu_version: str
88+
raw_jax_version: str
89+
raw_jaxlib_version: str
90+
91+
@property
92+
def libtpu_version(self) -> str:
93+
if self.use_nightly:
94+
return f'{self.raw_libtpu_version}.dev{self.date}'
95+
else:
96+
return self.raw_libtpu_version
97+
98+
@property
99+
def jax_version(self) -> str:
100+
if self.use_nightly:
101+
return f'{self.raw_jax_version}.dev{self.date}'
102+
else:
103+
return self.raw_jax_version
104+
105+
@property
106+
def jaxlib_version(self) -> str:
107+
if self.use_nightly:
108+
return f'{self.raw_jaxlib_version}.dev{self.date}'
109+
else:
110+
return self.raw_jaxlib_version
111+
112+
@property
113+
def libtpu_storage_directory(self) -> str:
114+
if self.use_nightly:
115+
return 'libtpu-nightly-releases'
116+
else:
117+
return 'libtpu-lts-releases'
118+
119+
@property
120+
def libtpu_wheel_name(self) -> str:
121+
if self.use_nightly:
122+
return f'libtpu-{self.libtpu_version}+nightly'
123+
else:
124+
return f'libtpu-{self.libtpu_version}'
125+
126+
@property
127+
def libtpu_storage_path(self) -> str:
128+
return f'https://storage.googleapis.com/{self.libtpu_storage_directory}/wheels/libtpu/{self.libtpu_wheel_name}-py3-none-linux_x86_64.whl'
6129

7130

8131
def check_env_flag(name: str, default: str = '') -> bool:
@@ -60,7 +183,7 @@ def bazel_build(bazel_target: str,
60183
]
61184

62185
# Remove duplicated flags because they confuse bazel
63-
flags = set(bazel_options_from_env() + options)
186+
flags = set(list(bazel_options_from_env()) + list(options))
64187
bazel_argv.extend(flags)
65188

66189
print(' '.join(bazel_argv), flush=True)

plugins/cuda/setup.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import datetime
21
import os
32
import sys
43

@@ -12,6 +11,6 @@
1211
'torch_xla_cuda_plugin/lib', ['--config=cuda'])
1312

1413
setuptools.setup(
15-
# TODO: Use a common version file
16-
version=os.getenv('TORCH_XLA_VERSION',
17-
f'2.5.0.dev{datetime.date.today().strftime("%Y%m%d")}'))
14+
version=build_util.get_build_version(),
15+
install_requires=build_util.get_jax_cuda_requirements(),
16+
)

setup.py

+10-49
Original file line numberDiff line numberDiff line change
@@ -55,33 +55,14 @@
5555
import os
5656
import requests
5757
import shutil
58-
import subprocess
5958
import sys
6059
import tempfile
6160
import zipfile
6261

6362
import build_util
6463

6564
base_dir = os.path.dirname(os.path.abspath(__file__))
66-
67-
USE_NIGHTLY = True # whether to use nightly or stable libtpu and jax
68-
69-
_date = '20250320'
70-
_libtpu_version = '0.0.12'
71-
_jax_version = '0.5.4'
72-
_jaxlib_version = '0.5.4'
73-
74-
_libtpu_wheel_name = f'libtpu-{_libtpu_version}'
75-
_libtpu_storage_directory = 'libtpu-lts-releases'
76-
77-
if USE_NIGHTLY:
78-
_libtpu_version += f".dev{_date}"
79-
_jax_version += f".dev{_date}"
80-
_jaxlib_version += f".dev{_date}"
81-
_libtpu_wheel_name += f".dev{_date}+nightly"
82-
_libtpu_storage_directory = 'libtpu-nightly-releases'
83-
84-
_libtpu_storage_path = f'https://storage.googleapis.com/{_libtpu_storage_directory}/wheels/libtpu/{_libtpu_wheel_name}-py3-none-linux_x86_64.whl'
65+
pinned_packages = build_util.get_pinned_packages()
8566

8667

8768
def _get_build_mode():
@@ -90,29 +71,6 @@ def _get_build_mode():
9071
return sys.argv[i]
9172

9273

93-
def get_git_head_sha(base_dir):
94-
xla_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
95-
cwd=base_dir).decode('ascii').strip()
96-
if os.path.isdir(os.path.join(base_dir, '..', '.git')):
97-
torch_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
98-
cwd=os.path.join(
99-
base_dir,
100-
'..')).decode('ascii').strip()
101-
else:
102-
torch_git_sha = ''
103-
return xla_git_sha, torch_git_sha
104-
105-
106-
def get_build_version(xla_git_sha):
107-
version = os.getenv('TORCH_XLA_VERSION', '2.8.0')
108-
if build_util.check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'):
109-
try:
110-
version += '+git' + xla_git_sha[:7]
111-
except Exception:
112-
pass
113-
return version
114-
115-
11674
def create_version_files(base_dir, version, xla_git_sha, torch_git_sha):
11775
print('Building torch_xla version: {}'.format(version))
11876
print('XLA Commit ID: {}'.format(xla_git_sha))
@@ -151,7 +109,7 @@ def maybe_bundle_libtpu(base_dir):
151109
print('No installed libtpu found. Downloading...')
152110

153111
with tempfile.NamedTemporaryFile('wb') as whl:
154-
resp = requests.get(_libtpu_storage_path)
112+
resp = requests.get(pinned_packages.libtpu_storage_path)
155113
resp.raise_for_status()
156114

157115
whl.write(resp.content)
@@ -194,8 +152,8 @@ def run(self):
194152
distutils.command.clean.clean.run(self)
195153

196154

197-
xla_git_sha, torch_git_sha = get_git_head_sha(base_dir)
198-
version = get_build_version(xla_git_sha)
155+
xla_git_sha, torch_git_sha = build_util.get_git_head_sha(base_dir)
156+
version = build_util.get_build_version()
199157

200158
build_mode = _get_build_mode()
201159
if build_mode not in ['clean']:
@@ -226,7 +184,7 @@ class BuildBazelExtension(build_ext.build_ext):
226184
def run(self):
227185
for ext in self.extensions:
228186
self.bazel_build(ext)
229-
command.build_ext.build_ext.run(self)
187+
command.build_ext.build_ext.run(self) # type: ignore
230188

231189
def bazel_build(self, ext):
232190
if not os.path.exists(self.build_temp):
@@ -328,11 +286,14 @@ def run(self):
328286
# On Cloud TPU VM install with:
329287
# pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html
330288
'tpu': [
331-
f'libtpu=={_libtpu_version}',
289+
f'libtpu=={pinned_packages.libtpu_version}',
332290
'tpu-info',
333291
],
334292
# pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
335-
'pallas': [f'jaxlib=={_jaxlib_version}', f'jax=={_jax_version}'],
293+
'pallas': [
294+
f'jaxlib=={pinned_packages.jaxlib_version}',
295+
f'jax=={pinned_packages.jax_version}'
296+
],
336297
},
337298
cmdclass={
338299
'build_ext': BuildBazelExtension,

0 commit comments

Comments
 (0)