Skip to content

Commit ed323d6

Browse files
authored
Pin torch_xla base image to 20250217 (#113)
* Pin torch_xla base image to 20250217
* Update Dockerfile
* Update Dockerfile
* Add the workaround flag in thunk
1 parent ccff555 commit ed323d6

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

.github/workflows/cpu_test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
matrix:
1616
python-version: ["3.10", "3.11"]
1717
container:
18-
image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_${{ matrix.python-version }}_tpuvm
18+
image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_${{ matrix.python-version }}_tpuvm_20250217
1919
steps:
2020
- uses: actions/checkout@v4
2121
- name: Install torchax

torchprime/launcher/Dockerfile

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# syntax=docker/dockerfile:experimental
22
# Use torch_xla Python 3.10 as the base image
3-
# TODO(https://github.com/pytorch/xla/issues/8683): Go back to nightly once the linked segfault is fixed.
4-
FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11
3+
FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20250217
54

65
ARG USE_TRANSFORMERS=false
76
# Install system dependencies

torchprime/launcher/thunk.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44
from datetime import datetime
55
from pathlib import Path
66

7+
# Workaround for MegaScale crash
8+
#
9+
# TODO(https://github.com/pytorch/xla/issues/8683): Remove the
10+
# `--megascale_grpc_enable_xor_tracer=false` flag when libtpu is updated
11+
xla_flags = os.environ.get("LIBTPU_INIT_ARGS", "")
12+
xla_flags = f"{xla_flags} --megascale_grpc_enable_xor_tracer=false"
13+
os.environ["LIBTPU_INIT_ARGS"] = xla_flags
14+
715
# Get the artifact dir from env var.
816
gcs_artifact_dir = os.environ["TORCHPRIME_ARTIFACT_DIR"]
917
assert gcs_artifact_dir.startswith(

0 commit comments

Comments
 (0)