File tree Expand file tree Collapse file tree 2 files changed +6
-3
lines changed Expand file tree Collapse file tree 2 files changed +6
-3
lines changed Original file line number Diff line number Diff line change 1
1
ARG BASE_IMAGE_TAG
2
2
ARG LIBTPU_IMAGE_TAG
3
3
ARG TENSORFLOW_VERSION
4
+ ARG TORCH_VERSION
4
5
5
6
FROM gcr.io/cloud-tpu-v2-images/libtpu:${LIBTPU_IMAGE_TAG} as libtpu
6
7
FROM gcr.io/kaggle-images/python-tpu-tensorflow-whl:python-${BASE_IMAGE_TAG}-${TENSORFLOW_VERSION} AS tensorflow_whl
@@ -17,8 +18,8 @@ RUN pip install /tmp/tensorflow_pkg/tensorflow*.whl && \
17
18
18
19
# https://cloud.google.com/tpu/docs/pytorch-xla-ug-tpu-vm#changing_pytorch_version
19
20
RUN pip uninstall -y torch && \
20
- pip install torch==1.10 && \
21
- pip install torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.10 -cp37-cp37m-linux_x86_64.whl && \
21
+ pip install torch==${TORCH_VERSION} && \
22
+ pip install torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION} -cp37-cp37m-linux_x86_64.whl && \
22
23
/tmp/clean-layer.sh
23
24
24
25
# https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#install_jax_on_your_cloud_tpu_vm
Original file line number Diff line number Diff line change
1
+ # TODO(b/213335159): Use ci-pretest for BASE_IMAGE_TAG once stable.
1
2
BASE_IMAGE_TAG=v108
2
3
LIBTPU_IMAGE_TAG=libtpu_1.1.0_RC00
3
- TENSORFLOW_VERSION=2.8.0
4
+ TENSORFLOW_VERSION=2.8.0
5
+ TORCH_VERSION=1.11.0
You can’t perform that action at this time.
0 commit comments