Commit 3c59038

Add a multi-slice E2E test and update torch_xla pin to 20250313 (#146)
* Add a multi-slice E2E test and update torch_xla to 20250313. The 20250313 docker image contains the fixes for multi-slice training, and we add an E2E test to make sure this doesn't regress.
* Add more help.
1 parent f15f22a · commit 3c59038

5 files changed (+52, −5 lines)


.github/workflows/cpu_test.yml

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ jobs:
       matrix:
         python-version: ["3.10", "3.11"]
     container:
-      image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_${{ matrix.python-version }}_tpuvm_cxx11_20250312
+      image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_${{ matrix.python-version }}_tpuvm_cxx11_20250313
     steps:
       - uses: actions/checkout@v4
       - name: Install torchax
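
For context, the image tag is parameterized by the job's Python version matrix, so only the trailing date changes here. A rough sketch of how such a tag string is assembled (the helper below is illustrative only, not part of the repository):

def nightly_xla_image(python_version: str, date: str) -> str:
  # Mirrors the tag format used in the workflow above.
  return (
    "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:"
    f"nightly_{python_version}_tpuvm_cxx11_{date}"
  )

assert nightly_xla_image("3.11", "20250313").endswith("_tpuvm_cxx11_20250313")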

.github/workflows/e2e_test.yml

Lines changed: 32 additions & 0 deletions
@@ -17,6 +17,7 @@ jobs:
     outputs:
       llama-3-8b-name: ${{ steps.run-llama-3-8b.outputs.name }}
       llama-3-8b-2d-name: ${{ steps.run-llama-3-8b-2d.outputs.name }}
+      llama-3-8b-2-slice-name: ${{ steps.run-llama-3-8b-2-slice.outputs.name }}
       mixtral-8x7b-name: ${{ steps.run-mixtral-8x7b.outputs.name }}
       artifact-dir: ${{ steps.artifacts.outputs.artifact_dir }}
     steps:
@@ -102,6 +103,28 @@ jobs:
             profile_step=3 \
             max_steps=15
 
+      - name: Run Llama 3.0 8B (2 slice)
+        id: run-llama-3-8b-2-slice
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          XLA_IR_DEBUG: 1
+          XLA_HLO_DEBUG: 1
+        run: |
+          name=$(e2e_testing/gen_name.py llama-3-8b-2-slice)
+          echo "name=$name" >> "$GITHUB_OUTPUT"
+          tp run \
+            --name $name \
+            --num-slices 2 \
+            torchprime/torch_xla_models/train.py \
+            model=llama-3-8b \
+            model/scaling=llama-fsdp \
+            global_batch_size=16 \
+            dcn_mesh.fsdp=2 \
+            ici_mesh.fsdp=4 \
+            dataset_config_name=wikitext-2-raw-v1 \
+            profile_step=3 \
+            max_steps=15
+
   llama-3-8b:
     name: Llama 3.0 8B
     needs: tp-run
@@ -120,6 +143,15 @@ jobs:
       artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
     secrets: inherit
 
+  llama-3-8b-2-slice:
+    name: Llama 3.0 8B (2 slice)
+    needs: tp-run
+    uses: ./.github/workflows/reusable_e2e_check.yml
+    with:
+      jobset_name: ${{ needs.tp-run.outputs.llama-3-8b-2-slice-name }}
+      artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
+    secrets: inherit
+
   mixtral-8x7b:
     name: Mixtral 8x7B
     needs: tp-run
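
The new 2-slice step exercises FSDP across slices: dcn_mesh.fsdp=2 spans the two slices over the data-center network, while ici_mesh.fsdp=4 shards within each slice. A quick sanity check of that mesh arithmetic, as an illustrative sketch (the helper is hypothetical, not torchprime code):

import math

def mesh_device_count(dcn_mesh: dict[str, int], ici_mesh: dict[str, int]) -> int:
  # Total devices implied by the inter-slice (DCN) and intra-slice (ICI) meshes.
  return math.prod(dcn_mesh.values()) * math.prod(ici_mesh.values())

# Values from the workflow step above: 2-way FSDP across slices, 4-way FSDP within each.
assert mesh_device_count({"fsdp": 2}, {"fsdp": 4}) == 8  # e.g. two 4-device slices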

README.md

Lines changed: 1 addition & 0 deletions
@@ -99,6 +99,7 @@ tp run torchprime/experimental/torchax_models/run.py global_batch_size=256
 
 `tp run` will broadcast the specified command to all VMs in the XPK cluster,
 which is the convention for running SPMD distributed workloads.
+See `tp run --help` for more advanced features.
 
 #### Env vars passed to the workload

torchprime/launcher/Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # syntax=docker/dockerfile:experimental
 # Use torch_xla Python 3.10 as the base image
-FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_cxx11_20250312
+FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_cxx11_20250313
 
 ARG USE_TRANSFORMERS=false
 ARG USE_LOCAL_WHEEL=false

torchprime/launcher/cli.py

Lines changed: 17 additions & 3 deletions
@@ -70,8 +70,9 @@ def cli(ctx, interactive):
 @click.option(
   "--num-slices",
   required=False,
+  type=int,
   default=1,
-  help="Number of TPU slice to use. Defaults to 1",
+  help="Number of TPU slice to use by default. Defaults to 1",
 )
 @click.option(
   "--tpu-type",
@@ -207,14 +208,24 @@ def create_and_activate_gcloud(gcloud_config_name, config: Config):
   "defaults to one based on the date and time.",
   default=None,
 )
+@click.option(
+  "--num-slices",
+  required=False,
+  type=int,
+  default=None,
+  help="Temporarily override the number of TPU slice to use for this run. "
+  "If unspecified, `tp run` will use the slice count configured in `tp use`.",
+)
 @click.option("--use-hf", is_flag=True, help="Use HuggingFace transformer")
 @click.option(
   "--use-local-wheel",
   is_flag=True,
   help="Use local torch and torch_xla wheels under folder local_dist/",
 )
 @interactive
-def run(args, name: str | None, use_hf: bool, use_local_wheel: bool):
+def run(
+  args, name: str | None, num_slices: int | None, use_hf: bool, use_local_wheel: bool
+):
   """
   Runs the provided SPMD training command as an xpk job on a GKE cluster.
   """
@@ -258,6 +269,9 @@ def run(args, name: str | None, use_hf: bool, use_local_wheel: bool):
       f"TORCHPRIME_JOBSET_NAME={workload_name}",
   ]
 
+  if num_slices is None:
+    num_slices = config.num_slices
+
   ensure_command("xpk")
   xpk_command = (
     [
@@ -273,7 +287,7 @@ def run(args, name: str | None, use_hf: bool, use_local_wheel: bool):
       "--tpu-type",
       config.tpu_type,
       "--num-slices",
-      str(config.num_slices),
+      str(num_slices),
       "--zone",
       config.zone,
       "--project",
