Skip to content

Commit 00d5f39

Browse files
authored
Llama 3.1 405B 2 pod support (#117)
* Llama 3.1 405B 2 pod support This PR adds 2 pod support for Llama 3.1 405B. Because the best performing Llama configuration does FSDP over DCN links, we need to re-adjust the virtual device mesh to support FSDP over DCN. I've extended the mesh to a pair of two dicts, `mesh` and `dcn_mesh`, each of which may contain `data`, `fsdp`, etc. To get better performance, we also need a very specific device ID arrangement in the virtual device mesh. Specifically, we need each group of 4 TPUs used in tensor parallelism to form a ring. On a 16x16 Trillium pod, the device IDs will look like: ``` 0, 1, 17, 16, 2, 3, 19, 18, 4, 5, 21, 20, ... ``` and so on. Note that the group `0, 1, 17, 16` will end up forming a ring because the TPU with ID 17 is in practice physically located at the first column and the second row. This replicates the MaxText feature at [1]. Tested: ``` export LIBTPU_INIT_ARGS='--xla_enable_async_all_gather=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_decompose_all_gather_einsum=true --xla_tpu_decompose_einsum_reduce_scatter=true --xla_tpu_scoped_vmem_limit_kib=98304 --xla_tpu_spmd_rng_bit_generator_unsafe=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_use_enhanced_launch_barrier=true' tp run torchprime/torch_xla_models/train.py model=llama-3.1-405b global_batch_size=128 dcn_mesh.fsdp=2 mesh.fsdp=64 mesh.tensor=4 dataset_config_name=wikitext-103-raw-v1 profile_step=15 profile_duration=200000 max_steps=50 logging_steps=10 ``` The step time is 30.933s (see [2]), which is better than the 31.822s in the Hugging Face fork of https://github.com/pytorch-tpu/transformers/tree/flash_attention_405b. Fixes #112. [1]: https://github.com/AI-Hypercomputer/maxtext/pull/972/files [2]: http://shortn/_Yy2QuIpkus * Rename mesh to ici_mesh
1 parent ed323d6 commit 00d5f39

File tree

16 files changed

+647
-139
lines changed

16 files changed

+647
-139
lines changed

.github/workflows/e2e_test.yml

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ jobs:
1616
ARTIFACT_DIR: gs://torchprime-e2e-tests/${{ github.job }}/${{ github.run_id }}-${{ github.run_attempt }}
1717
outputs:
1818
llama-3-8b-name: ${{ steps.run-llama-3-8b.outputs.name }}
19+
llama-3-8b-2d-name: ${{ steps.run-llama-3-8b-2d.outputs.name }}
1920
mixtral-8x7b-name: ${{ steps.run-mixtral-8x7b.outputs.name }}
2021
artifact-dir: ${{ steps.artifacts.outputs.artifact_dir }}
2122
steps:
@@ -55,7 +56,28 @@ jobs:
5556
torchprime/torch_xla_models/train.py \
5657
model=llama-3-8b \
5758
global_batch_size=8 \
58-
mesh.fsdp=4 \
59+
ici_mesh.fsdp=4 \
60+
dataset_config_name=wikitext-2-raw-v1 \
61+
profile_step=3 \
62+
max_steps=15
63+
64+
- name: Run Llama 3.0 8B (2D sharding)
65+
id: run-llama-3-8b-2d
66+
env:
67+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
68+
XLA_IR_DEBUG: 1
69+
XLA_HLO_DEBUG: 1
70+
run: |
71+
name=$(e2e_testing/gen_name.py llama-3-8b-2d)
72+
echo "name=$name" >> "$GITHUB_OUTPUT"
73+
tp run \
74+
--name $name \
75+
torchprime/torch_xla_models/train.py \
76+
model=llama-3-8b \
77+
model/scaling=llama-fsdp-tp \
78+
global_batch_size=8 \
79+
ici_mesh.fsdp=2 \
80+
ici_mesh.tensor=2 \
5981
dataset_config_name=wikitext-2-raw-v1 \
6082
profile_step=3 \
6183
max_steps=15
@@ -75,7 +97,7 @@ jobs:
7597
model=mixtral-8x7b \
7698
model.num_hidden_layers=16 \
7799
global_batch_size=8 \
78-
mesh.fsdp=4 \
100+
ici_mesh.fsdp=4 \
79101
dataset_config_name=wikitext-2-raw-v1 \
80102
profile_step=3 \
81103
max_steps=15
@@ -89,6 +111,15 @@ jobs:
89111
artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
90112
secrets: inherit
91113

114+
llama-3-8b-2d:
115+
name: Llama 3.0 8B (2D sharding)
116+
needs: tp-run
117+
uses: ./.github/workflows/reusable_e2e_check.yml
118+
with:
119+
jobset_name: ${{ needs.tp-run.outputs.llama-3-8b-2d-name }}
120+
artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
121+
secrets: inherit
122+
92123
mixtral-8x7b:
93124
name: Mixtral 8x7B
94125
needs: tp-run

README.md

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,13 @@ In both `torch_xla_models` and `torchax_models` directories, you'll find
4848
a `configs/default.yaml`. That specifies the default configuration for the
4949
trainer. You may override configs on the command line with a `key=value`
5050
syntax. For example, the following command will train Mixtral 8x7B with a
51-
global batch size of 256, and set the FSDP SPMD mesh axis length to 64:
51+
global batch size of 256, and set the FSDP SPMD ICI mesh axis length to 64:
5252

5353
```sh
5454
python3 torchprime/torch_xla_models/train.py \
5555
model=mixtral-8x7b \
5656
global_batch_size=256 \
57-
mesh.fsdp=64
57+
ici_mesh.fsdp=64
5858
```
5959

6060
You may refer to the hydra docs for other ways to specify configs.
@@ -81,11 +81,23 @@ tp use \
8181
Then prepend `tp run` to a particular Python file you would like to
8282
run remotely, including arguments, e.g.
8383

84+
`torch_xla` example:
85+
86+
```sh
87+
# Train Llama 3.0 8B on 256 chips
88+
tp run torchprime/torch_xla_models/train.py \
89+
model=llama-3-8b \
90+
global_batch_size=256 \
91+
ici_mesh.fsdp=256
92+
```
93+
94+
`torchax` example:
95+
8496
```sh
8597
tp run torchprime/experimental/torchax_models/run.py global_batch_size=256
8698
```
8799

88-
`tp run` will broadcast this command to all VMs in the XPK cluster,
100+
`tp run` will broadcast the specified command to all VMs in the XPK cluster,
89101
which is the convention for running SPMD distributed workloads.
90102

91103
#### Env var passed to the workload

torchprime/experimental/torchax_models/custom_mesh.py

Lines changed: 0 additions & 75 deletions
This file was deleted.

torchprime/experimental/torchax_models/run.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import functools
22
import math
33

4-
import custom_mesh
54
import hydra
65
import jax
76
import numpy as np
@@ -19,6 +18,8 @@
1918
from omegaconf import DictConfig, OmegaConf
2019
from torchax import interop
2120

21+
from torchprime.mesh import custom_mesh
22+
2223
sharding_map_original = {
2324
"freqs_cis": (), # torch.complex64 (2048, 64)
2425
"tok_embeddings.weight": (
@@ -208,7 +209,7 @@ def main(config: DictConfig):
208209
tp = 4
209210
if len(jax.devices()) == 512:
210211
dev_array = custom_mesh.create_custom_64x4_device_mesh(
211-
(64, 4), (2, 1), jax.devices()
212+
(64, tp), (2, 1), jax.devices()
212213
)
213214
else:
214215
assert len(jax.devices()) == 256

torchprime/launcher/thunk.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@
4343
[
4444
os.getenv("XLA_FLAGS", ""),
4545
f"--xla_dump_to={xla_dump_path}/",
46-
"--xla_dump_hlo_as_proto",
46+
"--xla_dump_hlo_as_proto", # Save HLO protobuf files
47+
"--xla_dump_hlo_as_text", # Save HLO text files
4748
]
4849
)
4950
print(f"Dumping XLA compiler outputs to {xla_dump_path}", flush=True)

torchprime/mesh/__init__.py

Whitespace-only changes.

torchprime/mesh/custom_mesh.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
"""
2+
`custom_mesh` implements virtual device meshes with better performance than
3+
the default device mesh generated by torch_xla or torchax.
4+
"""
5+
6+
import collections
7+
import dataclasses
8+
from collections.abc import Sequence
9+
from typing import Any
10+
11+
import numpy as np
12+
from torch.utils._pytree import tree_map
13+
14+
15+
def maybe_get_custom_mesh(
16+
ici_mesh_shape: Sequence[int],
17+
dcn_mesh_shape: Sequence[int],
18+
num_devices: int,
19+
num_slices: int,
20+
) -> np.ndarray | None:
21+
"""
22+
Get a more performant custom mesh given the mesh shape if applicable.
23+
24+
The dimensions in mesh shapes should be ordered from least communication intensive
25+
to most communication intensive.
26+
"""
27+
non_trivial_ici_mesh_shape = list(ici_mesh_shape)
28+
while non_trivial_ici_mesh_shape:
29+
if non_trivial_ici_mesh_shape[-1] == 1:
30+
non_trivial_ici_mesh_shape.pop()
31+
else:
32+
break
33+
34+
# Pattern matching for 64x4 custom mesh inside a granule.
35+
# When there exists a 4 chip group that is more communication intensive
36+
# (e.g. tensor parallelism), we should reshape those groups of 4 devices
37+
# into a ring to improve collectives performance.
38+
if (
39+
len(non_trivial_ici_mesh_shape) >= 2
40+
and non_trivial_ici_mesh_shape[-1] == 4
41+
and non_trivial_ici_mesh_shape[-2] == 64
42+
):
43+
return get_64x4_hybrid_ring_mesh(
44+
ici_mesh_shape=non_trivial_ici_mesh_shape,
45+
dcn_mesh_shape=dcn_mesh_shape,
46+
num_devices=num_devices,
47+
num_slices=num_slices,
48+
)
49+
return None
50+
51+
52+
def create_custom_64x4_device_mesh(
53+
mesh_shape: Sequence[int],
54+
dcn_mesh_shape: Sequence[int],
55+
devices: Sequence[Any],
56+
) -> np.ndarray:
57+
"""
58+
Custom device mesh for 64x4 ICI parallelism.
59+
60+
Arranges every group of 4 devices into a ring, to improve collectives performance for those groups
61+
of 4 devices.
62+
63+
This function is a simplified variation of [1].
64+
65+
[1]: https://github.com/jax-ml/jax/blame/1079dc4477d41fd25397c8d0b78a32bdc5fa48da/jax/_src/mesh_utils.py#L790
66+
"""
67+
68+
from jax.experimental import mesh_utils
69+
70+
assert (
71+
len(devices) % 256 == 0
72+
), f"This custom mesh is not valid for {len(devices)} devices"
73+
attr = "slice_index"
74+
if not hasattr(devices[0], attr):
75+
raise ValueError(
76+
f"Device {devices[0]} does not have attribute {attr}. See"
77+
" `process_is_granule` option."
78+
)
79+
granule_dict = collections.defaultdict(list)
80+
for dev in devices:
81+
granule_dict[getattr(dev, attr)].append(dev)
82+
granules = [granule_dict[key] for key in sorted(granule_dict.keys())]
83+
if np.prod(dcn_mesh_shape) != len(granules):
84+
raise ValueError(
85+
f"Number of slices {len(granules)} must equal the product of "
86+
f"dcn_mesh_shape {dcn_mesh_shape}"
87+
)
88+
per_granule_meshes = [
89+
mesh_utils.create_device_mesh(
90+
[16, 16],
91+
granule,
92+
allow_split_physical_axes=False,
93+
)
94+
for granule in granules
95+
]
96+
97+
def reshape_mesh_to_rings(a):
98+
b = []
99+
for i in range(8):
100+
b.append([])
101+
for j in range(8):
102+
a_i = i * 2
103+
a_j = j * 2
104+
# forms a ring of size 4
105+
b[i].append(
106+
[
107+
a[a_i, a_j],
108+
a[a_i, a_j + 1],
109+
a[a_i + 1, a_j + 1],
110+
a[a_i + 1, a_j],
111+
]
112+
)
113+
b = np.array(b)
114+
b = np.reshape(b, (64, 4))
115+
return b
116+
117+
per_granule_meshes = [
118+
np.reshape(reshape_mesh_to_rings(x), mesh_shape) for x in per_granule_meshes
119+
]
120+
granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape)
121+
blocks = np.vectorize(lambda i: per_granule_meshes[i], otypes=[object])(granule_mesh)
122+
device_mesh = np.block(blocks.tolist())
123+
return device_mesh
124+
125+
126+
@dataclasses.dataclass
127+
class Device:
128+
process_index: int
129+
slice_index: int
130+
uid: int
131+
device_kind: str = ""
132+
platform: str = "cpu"
133+
134+
135+
def get_64x4_hybrid_ring_mesh(
136+
ici_mesh_shape: Sequence[int],
137+
dcn_mesh_shape: Sequence[int],
138+
num_devices: int,
139+
num_slices: int,
140+
) -> np.ndarray:
141+
num_devices_per_granule = num_devices // num_slices
142+
devices = [
143+
Device(i // num_devices_per_granule, i // num_devices_per_granule, i)
144+
for i in range(num_devices)
145+
]
146+
devices = (
147+
create_custom_64x4_device_mesh(ici_mesh_shape, dcn_mesh_shape, devices)
148+
.reshape(-1)
149+
.tolist()
150+
)
151+
devices = np.array(tree_map(lambda d: d.uid, devices))
152+
return devices

0 commit comments

Comments
 (0)