File tree Expand file tree Collapse file tree 2 files changed +15
-3
lines changed Expand file tree Collapse file tree 2 files changed +15
-3
lines changed Original file line number Diff line number Diff line change @@ -4,6 +4,10 @@ FROM python:3.8
4
4
# See: https://docs.docker.com/engine/reference/builder/#understand-how-arg-and-from-interact
5
5
ARG TORCH_VERSION
6
6
ARG TENSORFLOW_VERSION
7
+ ARG JAX_VERSION
8
+ ARG TORCHVISION_VERSION
9
+ ARG TORCHTEXT_VERSION
10
+ ARG TORCHAUDIO_VERSION
7
11
8
12
ENV ISTPUVM=1
9
13
@@ -27,14 +31,14 @@ ENV JAX_LIBTPU=/lib/jax-libtpu.so
27
31
28
32
# Install JAX & related packages
29
33
# https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#install_jax_on_your_cloud_tpu_vm
30
- RUN pip install " jax[tpu]==0.3.10" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html trax flax optax elegy git+https://github.com/deepmind/dm-haiku jraph distrax
34
+ RUN pip install jax[tpu]==${JAX_VERSION} -f https://storage.googleapis.com/jax-releases/libtpu_releases.html trax flax optax elegy git+https://github.com/deepmind/dm-haiku jraph distrax
31
35
32
36
RUN cp $PIP_LIBTPU $JAX_LIBTPU
33
37
34
38
# Install Pytorch & related packages
35
39
# https://cloud.google.com/tpu/docs/pytorch-xla-ug-tpu-vm#changing_pytorch_version
36
40
# The URL doesn't include patch version. i.e. must use 1.11 instead of 1.11.0
37
- RUN pip install torch==${TORCH_VERSION} torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-cp38-cp38-linux_x86_64.whl torchvision==0.12.0 torchtext==0.12.0 torchaudio==0.11.0
41
+ RUN pip install torch==${TORCH_VERSION} torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-cp38-cp38-linux_x86_64.whl torchvision==${TORCHVISION_VERSION} torchtext==${TORCHTEXT_VERSION} torchaudio==${TORCHAUDIO_VERSION}
38
42
39
43
RUN cp $PIP_LIBTPU $PYTORCH_LIBTPU
40
44
Original file line number Diff line number Diff line change 1
1
TENSORFLOW_VERSION=2.9.1
2
- TORCH_VERSION=1.11.0
2
+ JAX_VERSION=0.3.10
3
+ # Supports nightly
4
+ TORCH_VERSION=1.13.0
5
+ # https://github.com/pytorch/audio supports nightly
6
+ TORCHAUDIO_VERSION=0.13.0
7
+ # https://github.com/pytorch/text supports main
8
+ TORCHTEXT_VERSION=0.14.0
9
+ # https://github.com/pytorch/vision supports nightly
10
+ TORCHVISION_VERSION=0.14.0
You can’t perform that action at this time.
0 commit comments