Skip to content

Commit b458b0e

Browse files
authored
Update installation to mention hardware support (#203)
1 parent 07f58d7 commit b458b0e

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,14 @@ pip install jax-ai-stack[tfds]
5757
will install a compatible version of
5858
[tensorflow](https://github.com/tensorflow/tensorflow)
5959
and [tensorflow-datasets](https://github.com/tensorflow/datasets).
60+
61+
### Hardware support
62+
63+
To install `jax-ai-stack` with hardware-specific JAX support, add the JAX installation
64+
command in the same `pip install` invocation. For example:
65+
```
66+
pip install jax-ai-stack "jax[cuda]" # JAX + AI stack with GPU/CUDA support
67+
pip install jax-ai-stack "jax[tpu]" # JAX + AI stack with TPU support
68+
```
69+
For more information on available options for hardware-specific JAX installation, refer
70+
to [JAX installation](https://docs.jax.dev/en/latest/installation.html).

docs/source/install.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,15 @@ command:
3535
pip install jax-ai-stack==2024.11.1
3636
```
3737
For the full list of released versions and the pinned packages, refer to the [Change log](https://github.com/jax-ml/jax-ai-stack/blob/main/CHANGELOG.md).
38+
39+
40+
## Hardware support
41+
42+
To install `jax-ai-stack` with hardware-specific JAX support, add the JAX installation
43+
command in the same `pip install` invocation. For example:
44+
```
45+
pip install jax-ai-stack "jax[cuda]" # JAX + AI stack with GPU/CUDA support
46+
pip install jax-ai-stack "jax[tpu]" # JAX + AI stack with TPU support
47+
```
48+
For more information on available options for hardware-specific JAX installation, refer
49+
to [JAX installation](https://docs.jax.dev/en/latest/installation.html).

0 commit comments

Comments
 (0)