
Remove dependency on JAX #9494

@jeffhataws

Description

🐛 Bug

In v2.8, JAX was added as a pip dependency in https://github.com/pytorch/xla/blob/master/setup.py#L121. As a result, JAX is installed in the user's environment even if they do not use torchax or JAX. Additionally, we now see the error reported in #9243.

Also, current_accelerator() now returns device(type='jax') instead of device(type='cuda'), which unexpectedly changes some behavior. For example, in the parallel loader, pin_memory must now be set to False for it to work as before.

Finally, because of this dependency, torch-xla also drops Python 3.10 support, even though Python 3.10 is supported until Oct 2026. It would be best to restore this support for customers who still use Ubuntu 22.

To improve the customer experience, please make this dependency optional, or perhaps make it a dependency of torchax only.
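On the packaging side, making JAX optional would mean moving it from setuptools' install_requires into an extras_require group (the group name would be the maintainers' choice). Code that needs JAX would then have to import it lazily. The sketch below shows one common pattern for that runtime half; jax_available and require_jax are hypothetical helper names, not existing torch_xla or torchax functions.

```python
import importlib
import importlib.util
from types import ModuleType


def jax_available() -> bool:
    """Return True if the jax package can be found without importing it."""
    return importlib.util.find_spec("jax") is not None


def require_jax() -> ModuleType:
    """Import jax on demand (hypothetical helper).

    Raises a clear ImportError when JAX is absent, instead of failing at
    `import torch_xla` time for users who never touch torchax or JAX.
    """
    if not jax_available():
        raise ImportError(
            "This feature requires JAX, which is an optional dependency. "
            "Install it separately, e.g. `pip install jax`."
        )
    return importlib.import_module("jax")
```

With this pattern, only the torchax/JAX code paths call require_jax(), so a plain torch-xla install works without JAX present.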

To Reproduce

Install the torch-xla 2.8 release candidate, run pipdeptree to get the dependency list, and observe that jax is part of the torch-xla dependency tree.
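As a lighter-weight check than pipdeptree, the direct requirements of an installed distribution can be read with the standard library's importlib.metadata.requires. The sketch below lists only *direct* requirements (pipdeptree also shows transitive ones), and it treats any requirement whose environment marker mentions "extra" as optional. That marker check is a simple heuristic assumed for illustration, not a full PEP 508 marker evaluator.

```python
import re
from importlib.metadata import PackageNotFoundError, requires

# Leading project name of a PEP 508 requirement string, e.g. "jax" in "jax>=0.4".
_NAME_RE = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)")


def requirement_name(req: str) -> str:
    """Return the PEP 503-normalized project name of a requirement string."""
    m = _NAME_RE.match(req)
    if m is None:
        raise ValueError(f"cannot parse requirement: {req!r}")
    # Lowercase and collapse runs of '-', '_' and '.' into a single '-'.
    return re.sub(r"[-_.]+", "-", m.group(1)).lower()


def is_unconditional(req: str) -> bool:
    """Heuristic: a requirement is a default dependency unless its marker
    gates it behind an extra (e.g. `jax; extra == "torchax"`)."""
    _, _, marker = req.partition(";")
    return "extra" not in marker


def unconditional_deps(dist: str):
    """Names of direct default dependencies of `dist`, or None if not installed."""
    try:
        reqs = requires(dist) or []
    except PackageNotFoundError:
        return None
    return {requirement_name(r) for r in reqs if is_unconditional(r)}


if __name__ == "__main__":
    deps = unconditional_deps("torch-xla")
    if deps is None:
        print("torch-xla is not installed")
    else:
        print("jax is a default dependency:", "jax" in deps)
```

Under the expected behavior below, the script would report False for jax.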

Expected behavior

JAX is not installed by default with torch-xla.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: any
  • torch_xla version: 2.8

Additional context

Metadata

  • Assignees: No one assigned
  • Labels: install (PyTorch/XLA installation related issues), usability (Bugs/features related to improving the usability of PyTorch/XLA)
  • Type: No type
  • Projects: No projects
  • Milestone: No milestone
  • Relationships: None yet
  • Development: No branches or pull requests
