diff --git a/README.md b/README.md index 19e8120..6913da0 100644 --- a/README.md +++ b/README.md @@ -57,3 +57,14 @@ pip install jax-ai-stack[tfds] will install a compatible version of [tensorflow](https://github.com/tensorflow/tensorflow) and [tensorflow-datasets](https://github.com/tensorflow/datasets). + +### Hardware support + +To install `jax-ai-stack` with hardware-specific JAX support, add the JAX installation +command in the same `pip install` invocation. For example: +``` +pip install jax-ai-stack "jax[cuda]" # JAX + AI stack with GPU/CUDA support +pip install jax-ai-stack "jax[tpu]" # JAX + AI stack with TPU support +``` +For more information on available options for hardware-specific JAX installation, refer +to [JAX installation](https://docs.jax.dev/en/latest/installation.html). diff --git a/docs/source/install.md b/docs/source/install.md index 8ba5645..5d54719 100644 --- a/docs/source/install.md +++ b/docs/source/install.md @@ -35,3 +35,15 @@ command: pip install jax-ai-stack==2024.11.1 ``` For the full list of released versions and the pinned packages, refer to the [Change log](https://github.com/jax-ml/jax-ai-stack/blob/main/CHANGELOG.md). + + +## Hardware support + +To install `jax-ai-stack` with hardware-specific JAX support, add the JAX installation +command in the same `pip install` invocation. For example: +``` +pip install jax-ai-stack "jax[cuda]" # JAX + AI stack with GPU/CUDA support +pip install jax-ai-stack "jax[tpu]" # JAX + AI stack with TPU support +``` +For more information on available options for hardware-specific JAX installation, refer +to [JAX installation](https://docs.jax.dev/en/latest/installation.html).