|
3 | 3 | import subprocess
|
4 | 4 | import sys
|
5 | 5 | import shutil
|
| 6 | +from dataclasses import dataclass |
| 7 | +import functools |
| 8 | + |
| 9 | +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) |
| 10 | + |
| 11 | + |
| 12 | +@functools.lru_cache |
| 13 | +def get_pinned_packages(): |
| 14 | + """Gets the versions of important pinned dependencies of torch_xla.""" |
| 15 | + return PinnedPackages( |
| 16 | + use_nightly=True, |
| 17 | + date='20250320', |
| 18 | + raw_libtpu_version='0.0.12', |
| 19 | + raw_jax_version='0.5.4', |
| 20 | + raw_jaxlib_version='0.5.4', |
| 21 | + ) |
| 22 | + |
| 23 | + |
| 24 | +@functools.lru_cache |
| 25 | +def get_build_version(): |
| 26 | + xla_git_sha, _torch_git_sha = get_git_head_sha(BASE_DIR) |
| 27 | + version = os.getenv('TORCH_XLA_VERSION', '2.8.0') |
| 28 | + if check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'): |
| 29 | + try: |
| 30 | + version += '+git' + xla_git_sha[:7] |
| 31 | + except Exception: |
| 32 | + pass |
| 33 | + return version |
| 34 | + |
| 35 | + |
| 36 | +@functools.lru_cache |
| 37 | +def get_git_head_sha(base_dir): |
| 38 | + xla_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], |
| 39 | + cwd=base_dir).decode('ascii').strip() |
| 40 | + if os.path.isdir(os.path.join(base_dir, '..', '.git')): |
| 41 | + torch_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], |
| 42 | + cwd=os.path.join( |
| 43 | + base_dir, |
| 44 | + '..')).decode('ascii').strip() |
| 45 | + else: |
| 46 | + torch_git_sha = '' |
| 47 | + return xla_git_sha, torch_git_sha |
| 48 | + |
| 49 | + |
| 50 | +def get_jax_cuda_requirements(): |
| 51 | + """Get a list of JAX CUDA requirements for use in setup.py without extra package registries.""" |
| 52 | + pinned_packages = get_pinned_packages() |
| 53 | + if not pinned_packages.use_nightly: |
| 54 | + # Stable versions of JAX can be directly installed from PyPI. |
| 55 | + return [ |
| 56 | + f'jaxlib=={pinned_packages.jaxlib_version}', |
| 57 | + f'jax=={pinned_packages.jax_version}', |
| 58 | + f'jax[cuda12]=={pinned_packages.jax_version}', |
| 59 | + ] |
| 60 | + |
| 61 | + # Install nightly JAX libraries from the JAX package registries. |
| 62 | + jax = f'jax @ https://storage.googleapis.com/jax-releases/nightly/jax/jax-{pinned_packages.jax_version}-py3-none-any.whl' |
| 63 | + jaxlib = [] |
| 64 | + for python_minor_version in [9, 10, 11]: |
| 65 | + jaxlib.append( |
| 66 | + f'jaxlib @ https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-{pinned_packages.jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"' |
| 67 | + ) |
| 68 | + |
| 69 | + # Install nightly JAX CUDA libraries. |
| 70 | + jax_cuda = [ |
| 71 | + f'jax-cuda12-plugin @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_pjrt-{pinned_packages.jax_version}-py3-none-manylinux2014_x86_64.whl' |
| 72 | + ] |
| 73 | + for python_minor_version in [9, 10, 11]: |
| 74 | + jax_cuda.append( |
| 75 | + f'jax-cuda12-pjrt @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_plugin-{pinned_packages.jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"' |
| 76 | + ) |
| 77 | + |
| 78 | + return [jax] + jaxlib + jax_cuda |
| 79 | + |
| 80 | + |
| 81 | +@dataclass(eq=True, frozen=True) |
| 82 | +class PinnedPackages: |
| 83 | + use_nightly: bool |
| 84 | + """Whether to use nightly or stable libtpu and JAX""" |
| 85 | + |
| 86 | + date: str |
| 87 | + raw_libtpu_version: str |
| 88 | + raw_jax_version: str |
| 89 | + raw_jaxlib_version: str |
| 90 | + |
| 91 | + @property |
| 92 | + def libtpu_version(self) -> str: |
| 93 | + if self.use_nightly: |
| 94 | + return f'{self.raw_libtpu_version}.dev{self.date}' |
| 95 | + else: |
| 96 | + return self.raw_libtpu_version |
| 97 | + |
| 98 | + @property |
| 99 | + def jax_version(self) -> str: |
| 100 | + if self.use_nightly: |
| 101 | + return f'{self.raw_jax_version}.dev{self.date}' |
| 102 | + else: |
| 103 | + return self.raw_jax_version |
| 104 | + |
| 105 | + @property |
| 106 | + def jaxlib_version(self) -> str: |
| 107 | + if self.use_nightly: |
| 108 | + return f'{self.raw_jaxlib_version}.dev{self.date}' |
| 109 | + else: |
| 110 | + return self.raw_jaxlib_version |
| 111 | + |
| 112 | + @property |
| 113 | + def libtpu_storage_directory(self) -> str: |
| 114 | + if self.use_nightly: |
| 115 | + return 'libtpu-nightly-releases' |
| 116 | + else: |
| 117 | + return 'libtpu-lts-releases' |
| 118 | + |
| 119 | + @property |
| 120 | + def libtpu_wheel_name(self) -> str: |
| 121 | + if self.use_nightly: |
| 122 | + return f'libtpu-{self.libtpu_version}+nightly' |
| 123 | + else: |
| 124 | + return f'libtpu-{self.libtpu_version}' |
| 125 | + |
| 126 | + @property |
| 127 | + def libtpu_storage_path(self) -> str: |
| 128 | + return f'https://storage.googleapis.com/{self.libtpu_storage_directory}/wheels/libtpu/{self.libtpu_wheel_name}-py3-none-linux_x86_64.whl' |
6 | 129 |
|
7 | 130 |
|
8 | 131 | def check_env_flag(name: str, default: str = '') -> bool:
|
@@ -60,7 +183,7 @@ def bazel_build(bazel_target: str,
|
60 | 183 | ]
|
61 | 184 |
|
62 | 185 | # Remove duplicated flags because they confuse bazel
|
63 |
| - flags = set(bazel_options_from_env() + options) |
| 186 | + flags = set(list(bazel_options_from_env()) + list(options)) |
64 | 187 | bazel_argv.extend(flags)
|
65 | 188 |
|
66 | 189 | print(' '.join(bazel_argv), flush=True)
|
|
0 commit comments