Skip to content

Commit ba96747

Browse files
committed
add TORCH_VERSION and TODO
1 parent 2105bc6 commit ba96747

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

tpu/Dockerfile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
ARG BASE_IMAGE_TAG
22
ARG LIBTPU_IMAGE_TAG
33
ARG TENSORFLOW_VERSION
4+
ARG TORCH_VERSION
45

56
FROM gcr.io/cloud-tpu-v2-images/libtpu:${LIBTPU_IMAGE_TAG} as libtpu
67
FROM gcr.io/kaggle-images/python-tpu-tensorflow-whl:python-${BASE_IMAGE_TAG}-${TENSORFLOW_VERSION} AS tensorflow_whl
@@ -17,8 +18,8 @@ RUN pip install /tmp/tensorflow_pkg/tensorflow*.whl && \
1718

1819
# https://cloud.google.com/tpu/docs/pytorch-xla-ug-tpu-vm#changing_pytorch_version
1920
RUN pip uninstall -y torch && \
20-
pip install torch==1.10 && \
21-
pip install torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.10-cp37-cp37m-linux_x86_64.whl && \
21+
pip install torch==${TORCH_VERSION} && \
22+
pip install torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION}-cp37-cp37m-linux_x86_64.whl && \
2223
/tmp/clean-layer.sh
2324

2425
# https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#install_jax_on_your_cloud_tpu_vm

tpu/config.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# TODO(b/213335159): Use ci-pretest for BASE_IMAGE_TAG once stable.
12
BASE_IMAGE_TAG=v108
23
LIBTPU_IMAGE_TAG=libtpu_1.1.0_RC00
3-
TENSORFLOW_VERSION=2.8.0
4+
TENSORFLOW_VERSION=2.8.0
5+
TORCH_VERSION=1.11.0

0 commit comments

Comments
 (0)