Description
🐛 Bug
In v2.8, JAX is added as a pip dependency in https://github.com/pytorch/xla/blob/master/setup.py#L121. This causes JAX to be installed in users' environments even when they use neither torchax nor JAX. Additionally, we now see the error reported in #9243.
Also, current_accelerator() now returns device(type='jax') instead of device(type='cuda'), which unexpectedly changes some behavior. For example, in the parallel loader, pin_memory must now be set to False for it to work as before.
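Until this is fixed, one defensive workaround is to derive pin_memory from the accelerator type PyTorch reports instead of hard-coding True. The sketch below assumes PyTorch 2.6+ (where torch.accelerator.current_accelerator() exists); the helper names are hypothetical. It only pins host memory when the reported type is "cuda", so a "jax" report falls back to pin_memory=False as described above.

```python
from typing import Optional


def current_accelerator_type() -> Optional[str]:
    """Return the device type PyTorch reports as the current accelerator, or None."""
    try:
        import torch
    except ImportError:
        return None  # PyTorch is not installed in this environment
    accel = getattr(torch, "accelerator", None)  # torch.accelerator exists in PyTorch 2.6+
    if accel is None:
        return None
    device = accel.current_accelerator()
    return None if device is None else device.type


def should_pin_memory(device_type: Optional[str]) -> bool:
    # Hypothetical policy: pin host memory only for CUDA; a "jax" (or missing)
    # accelerator report falls back to pin_memory=False.
    return device_type == "cuda"


if __name__ == "__main__":
    kind = current_accelerator_type()
    print(f"current accelerator: {kind!r}, pin_memory={should_pin_memory(kind)}")
```

The result can then be passed to a loader, e.g. DataLoader(dataset, pin_memory=should_pin_memory(current_accelerator_type())).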
Finally, because of this dependency, torch-xla now also drops Python 3.10 support, even though Python 3.10 remains supported until October 2026. Please restore this support for customers who still use Ubuntu 22.04, whose default Python is 3.10.
To improve the customer experience, please make this dependency optional, for example by making JAX a dependency of torchax only.
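A minimal sketch of what "optional" could look like, assuming JAX moves from the core requirements into an opt-in extra (the extra name "jax" and the helper require_jax are hypothetical, not torch-xla's actual setup.py). Code that needs JAX, such as torchax, imports it lazily and fails with an actionable message when it is missing, so plain torch-xla users never pull it in.

```python
import importlib
import importlib.util

# Hypothetical packaging metadata: the existing core requirements stay here
# (elided), but jax is no longer listed among them.
INSTALL_REQUIRES: list = []
# JAX becomes opt-in, e.g. `pip install 'torch_xla[jax]'` (extra name assumed).
EXTRAS_REQUIRE = {"jax": ["jax", "jaxlib"]}
# These would be passed to setuptools.setup(install_requires=..., extras_require=...).


def require_jax(feature: str):
    """Import JAX lazily, raising a helpful ImportError if it is not installed."""
    if importlib.util.find_spec("jax") is None:
        raise ImportError(
            f"{feature} requires JAX, which is an optional dependency; "
            "install it with: pip install 'torch_xla[jax]'"  # hypothetical extra
        )
    return importlib.import_module("jax")
```

Keeping the import inside require_jax means importing torch_xla itself never touches JAX; only the code path that actually needs it does.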
To Reproduce
Install the torch-xla 2.8 release candidate, run pipdeptree to list its dependencies, and observe that jax is part of the torch-xla dependency tree.
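As a stdlib-only alternative to pipdeptree, the sketch below reads the installed package's declared requirements via importlib.metadata and reports whether jax is required unconditionally (i.e. not gated behind an extra). It assumes the installed distribution is named torch_xla, and the "marker mentions extra" test is a simple heuristic rather than a full PEP 508 parser.

```python
import re
from importlib.metadata import PackageNotFoundError, requires
from typing import Iterable, Optional, Set


def _normalize(name: str) -> str:
    # PEP 503 normalization: runs of "-", "_", "." become "-", lowercased.
    return re.sub(r"[-_.]+", "-", name).lower()


def unconditional_requirements(reqs: Optional[Iterable[str]]) -> Set[str]:
    """Names of requirements installed regardless of which extras are selected."""
    names = set()
    for req in reqs or []:
        spec, _, marker = req.partition(";")
        # Heuristic: a marker mentioning "extra" means the requirement is opt-in.
        if "extra" in marker:
            continue
        match = re.match(r"[A-Za-z0-9][A-Za-z0-9._-]*", spec.strip())
        if match:
            names.add(_normalize(match.group(0)))
    return names


def jax_is_hard_dependency(dist: str = "torch_xla") -> Optional[bool]:
    """True/False if `dist` is installed, None if it is not installed."""
    try:
        reqs = requires(dist)
    except PackageNotFoundError:
        return None
    return "jax" in unconditional_requirements(reqs)


if __name__ == "__main__":
    print("jax is a hard dependency of torch_xla:", jax_is_hard_dependency())
```

With the behavior described in this issue, this would be expected to report True for the 2.8 release candidate; after the requested change it should report False.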
Expected behavior
JAX is not installed by default with torch-xla.
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: any
- torch_xla version: 2.8