Skip to content

Commit 16b4f69

Browse files
Jieying Luojax authors
authored andcommitted
Rename arg in build script to be more clear.
The flag means skips GPU plugin extension in jaxlib. PiperOrigin-RevId: 627203738
1 parent 47c9495 commit 16b4f69

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

build/build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,7 @@ def main():
674674
f"--jaxlib_git_hash={get_githash()}",
675675
f"--cpu={wheel_cpu}"])
676676
if args.build_gpu_plugin:
677-
command.append("--include_gpu_plugin_extension")
677+
command.append("--skip_gpu_kernels")
678678
if args.editable:
679679
command += ["--editable"]
680680
print(" ".join(command))

jaxlib/tools/build_wheel.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,11 @@
5757
help="Create an 'editable' jaxlib build instead of a wheel.",
5858
)
5959
parser.add_argument(
60-
"--include_gpu_plugin_extension",
61-
# args.include_gpu_plugin_extension is True when
62-
# --include_gpu_plugin_extension is in the command
60+
"--skip_gpu_kernels",
61+
# args.skip_gpu_kernels is True when
62+
# --skip_gpu_kernels is in the command
6363
action="store_true",
64-
help="Whether to include gpu plugin extension.",
64+
help="Whether to skip gpu kernels in jaxlib.",
6565
)
6666
args = parser.parse_args()
6767

@@ -169,7 +169,7 @@ def write_setup_cfg(sources_path, cpu):
169169
)
170170

171171

172-
def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extension):
172+
def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
173173
"""Assembles a source tree for the wheel in `sources_path`."""
174174
copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r)
175175

@@ -222,7 +222,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
222222
],
223223
)
224224

225-
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}") and not include_gpu_plugin_extension:
225+
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}") and not skip_gpu_kernels:
226226
copy_runfiles(
227227
dst_dir=jaxlib_dir / "cuda" / "nvvm" / "libdevice",
228228
src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"],
@@ -413,7 +413,7 @@ def if_has_mosaic_gpu(extras):
413413
prepare_wheel(
414414
pathlib.Path(sources_path),
415415
cpu=args.cpu,
416-
include_gpu_plugin_extension=args.include_gpu_plugin_extension,
416+
skip_gpu_kernels=args.skip_gpu_kernels,
417417
)
418418
package_name = "jaxlib"
419419
if args.editable:

0 commit comments

Comments
 (0)