|
57 | 57 | help="Create an 'editable' jaxlib build instead of a wheel.",
|
58 | 58 | )
|
59 | 59 | parser.add_argument(
|
60 |
| - "--include_gpu_plugin_extension", |
61 |
| - # args.include_gpu_plugin_extension is True when |
62 |
| - # --include_gpu_plugin_extension is in the command |
| 60 | + "--skip_gpu_kernels", |
| 61 | + # args.skip_gpu_kernels is True when |
| 62 | + # --skip_gpu_kernels is in the command |
63 | 63 | action="store_true",
|
64 |
| - help="Whether to include gpu plugin extension.", |
| 64 | + help="Whether to skip gpu kernels in jaxlib.", |
65 | 65 | )
|
66 | 66 | args = parser.parse_args()
|
67 | 67 |
|
@@ -169,7 +169,7 @@ def write_setup_cfg(sources_path, cpu):
|
169 | 169 | )
|
170 | 170 |
|
171 | 171 |
|
172 |
| -def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extension): |
| 172 | +def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): |
173 | 173 | """Assembles a source tree for the wheel in `sources_path`."""
|
174 | 174 | copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r)
|
175 | 175 |
|
@@ -222,7 +222,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
|
222 | 222 | ],
|
223 | 223 | )
|
224 | 224 |
|
225 |
| - if exists(f"__main__/jaxlib/cuda/_solver.{pyext}") and not include_gpu_plugin_extension: |
| 225 | + if exists(f"__main__/jaxlib/cuda/_solver.{pyext}") and not skip_gpu_kernels: |
226 | 226 | copy_runfiles(
|
227 | 227 | dst_dir=jaxlib_dir / "cuda" / "nvvm" / "libdevice",
|
228 | 228 | src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"],
|
@@ -413,7 +413,7 @@ def if_has_mosaic_gpu(extras):
|
413 | 413 | prepare_wheel(
|
414 | 414 | pathlib.Path(sources_path),
|
415 | 415 | cpu=args.cpu,
|
416 |
| - include_gpu_plugin_extension=args.include_gpu_plugin_extension, |
| 416 | + skip_gpu_kernels=args.skip_gpu_kernels, |
417 | 417 | )
|
418 | 418 | package_name = "jaxlib"
|
419 | 419 | if args.editable:
|
|
0 commit comments