Skip to content

Commit ae991f3

Browse files
committed
Upgrade TPU Pytorch to 1.13
1 parent 0f422b7 commit ae991f3

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

tpu/Dockerfile

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ FROM python:3.8
44
# See: https://docs.docker.com/engine/reference/builder/#understand-how-arg-and-from-interact
55
ARG TORCH_VERSION
66
ARG TENSORFLOW_VERSION
7+
ARG JAX_VERSION
8+
ARG TORCHVISION_VERSION
9+
ARG TORCHTEXT_VERSION
10+
ARG TORCHAUDIO_VERSION
711

812
ENV ISTPUVM=1
913

@@ -27,14 +31,14 @@ ENV JAX_LIBTPU=/lib/jax-libtpu.so
2731

2832
# Install JAX & related packages
2933
# https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#install_jax_on_your_cloud_tpu_vm
30-
RUN pip install "jax[tpu]==0.3.10" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html trax flax optax elegy git+https://github.com/deepmind/dm-haiku jraph distrax
34+
RUN pip install jax[tpu]==${JAX_VERSION} -f https://storage.googleapis.com/jax-releases/libtpu_releases.html trax flax optax elegy git+https://github.com/deepmind/dm-haiku jraph distrax
3135

3236
RUN cp $PIP_LIBTPU $JAX_LIBTPU
3337

3438
# Install Pytorch & related packages
3539
# https://cloud.google.com/tpu/docs/pytorch-xla-ug-tpu-vm#changing_pytorch_version
3640
# The URL doesn't include patch version. i.e. must use 1.11 instead of 1.11.0
37-
RUN pip install torch==${TORCH_VERSION} torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-cp38-cp38-linux_x86_64.whl torchvision==0.12.0 torchtext==0.12.0 torchaudio==0.11.0
41+
RUN pip install torch==${TORCH_VERSION} torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-cp38-cp38-linux_x86_64.whl torchvision==${TORCHVISION_VERSION} torchtext==${TORCHTEXT_VERSION} torchaudio==${TORCHAUDIO_VERSION}
3842

3943
RUN cp $PIP_LIBTPU $PYTORCH_LIBTPU
4044

tpu/config.txt

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,10 @@
11
TENSORFLOW_VERSION=2.9.1
2-
TORCH_VERSION=1.11.0
2+
JAX_VERSION=0.3.10
3+
# Supports nightly
4+
TORCH_VERSION=1.13.0
5+
# https://github.com/pytorch/audio supports nightly
6+
TORCHAUDIO_VERSION=0.13.0
7+
# https://github.com/pytorch/text supports main
8+
TORCHTEXT_VERSION=0.14.0
9+
# https://github.com/pytorch/vision supports nightly
10+
TORCHVISION_VERSION=0.14.0

0 commit comments

Comments
 (0)