Commit 69a1fc1

Document model status (#138)
* Document model status
* Update
* Move hf instruction
1 parent f23d3cf commit 69a1fc1


4 files changed: +225, -64 lines


README.md

Lines changed: 54 additions & 27 deletions
@@ -100,15 +100,51 @@ tp run torchprime/experimental/torchax_models/run.py global_batch_size=256
`tp run` will broadcast the specified command to all VMs in the XPK cluster,
which is the convention for running SPMD distributed workloads.

-#### Env var passed to the workload
+#### Env vars passed to the workload

`tp run` will pick up these environment variables locally and proxy them
to the distributed workload, if found:

- `HF_TOKEN`: HuggingFace token
- `XLA_IR_DEBUG`: [torch_xla debugging flag][torch_xla_debug_env]
- `XLA_HLO_DEBUG`: [torch_xla debugging flag][torch_xla_debug_env]
-- `LIBTPU_INIT_ARGS`: xla flag
+- `LIBTPU_INIT_ARGS`: XLA flags that affect compilation and execution behavior
+
+## Model status
+
+Here is the status of various models. In general, there are five stages for
+each model:
+
+- **TODO**: We need to implement the model.
+- **Implemented**: The model runs either a training or an inference step.
+- **Optimized**: We found the best scaling configuration for the model on one or
+more hardware platforms. One-off performance data is available.
+- **Convergence**: We tested that the training loss converges to a reasonable
+value, or that the loss curve tracks an existing reference if one exists.
+- **Production**: Not only is the model optimized and converging, its performance
+is also continuously monitored. This is a good state for using the model in
+production.
+
+All implemented models will at least have unit tests to verify basic numerical
+correctness, and the convergence verification stage serves as an additional
+correctness guarantee.
+
+If a model is at least implemented, you'll also find a training recipe linked
+from the checkmark emoji in the table. If a model is optimized, you'll also find
+MFU numbers linked from the table. Note that a model may continue to receive
+ongoing optimization thereafter.
+
+| **Model**            | **Implemented** | **Optimized** | **Converges** |
+| -------------------- | --------------- | ------------- | ------------- |
+| Llama 3.0 8B         | [✅](torchprime/torch_xla_models/README.md#llama-30-8b-on-v6e-256) | [✅](torchprime/torch_xla_models/README.md#llama-30-8b-on-v6e-256) | [TODO](https://github.com/AI-Hypercomputer/torchprime/issues/90) |
+| Llama 3.1 8B         | [✅](torchprime/torch_xla_models/README.md#llama-31-8b-on-v6e-256) | [TODO](https://github.com/AI-Hypercomputer/torchprime/issues/133) | TODO |
+| Llama 3.1 70B        | [TODO](https://github.com/AI-Hypercomputer/torchprime/issues/17) | TODO | TODO |
+| Llama 3.1 405B       | [✅](torchprime/torch_xla_models/README.md#llama-31-405b-on-v6e-256) | [TODO](https://github.com/AI-Hypercomputer/torchprime/milestone/2) | TODO |
+| Mixtral 8x7B         | [✅](torchprime/torch_xla_models/README.md#mixtral-8x7b-on-v6e-256) | [TODO](https://github.com/AI-Hypercomputer/torchprime/issues/44) | TODO |
+| Mixtral 8x22B        | [TODO](https://github.com/AI-Hypercomputer/torchprime/issues/45) | TODO | TODO |
+| DeepSeek V3/R1       | TODO | TODO | TODO |
+| Stable Diffusion 2.0 | [TODO](https://github.com/AI-Hypercomputer/torchprime/issues/87) | TODO | TODO |
+| Stable Diffusion 2.1 | [TODO](https://github.com/AI-Hypercomputer/torchprime/issues/88) | TODO | TODO |

## Structure

@@ -133,31 +169,6 @@ and attributes where this model code came from, if any. This also helps to
showcase what changes we have done to make it performant on TPU. The original
version is not expected to be run.

-## Run huggingface transformer models
-Torchprime supports run with huggingface models by taking advantage of `tp run`.
-To use huggingface models, you can clone
-[huggingface/transformers](https://github.com/huggingface/transformers) under
-torchprime and name it as `local_transformers`. This allows you to pick any
-branch or make code modifications in transformers for experiment:
-```
-git clone https://github.com/huggingface/transformers.git local_transformers
-```
-If huggingface transformer doesn't exist, torchprime will automatically clone
-the repo and build the docker for experiment. To switch to huggingface models,
-add flag `--use-hf` to `tp run` command:
-```
-tp run --use-hf torchprime/hf_models/train.py
-```
-
-## Run with local torch/torch_xla wheel
-Torchprime supports run with user specified torch and torch_xla wheels placed
-under `local_dist/` directory. The wheel will be automatically installed in the
-docker image when use `tp run` command. To use the wheel, add flag
-`--use-local-wheel` to `tp run` command:
-```
-tp run --use-local-wheel torchprime/hf_models/train.py
-```
-
## Contributing

Contributions are welcome! Please feel free to submit a pull request.
@@ -192,6 +203,21 @@ ruff check [--fix]
You can install a Ruff VSCode plugin to check errors and format files from
the editor.

+## Run distributed training with local torch/torch_xla wheel
+
+Torchprime supports running with user-specified torch and torch_xla wheels placed
+under the `local_dist/` directory. The wheels will be automatically installed in the
+docker image when using the `tp run` command. To use the wheels, add the
+`--use-local-wheel` flag to the `tp run` command:
+
+```sh
+tp run --use-local-wheel torchprime/hf_models/train.py
+```
+
+The wheels should be built inside a
+[PyTorch/XLA development docker image][torch_xla_dev_docker] or the PyTorch/XLA
+VSCode Dev Container to minimize compatibility issues.
+
## License

This project is licensed under the New BSD License - see the [LICENSE](LICENSE)
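To make the local-wheel workflow added in this hunk concrete, here is a hedged sketch of staging the wheels; the source path is a placeholder and the wheel filenames are whatever your PyTorch/XLA dev-container build produced:

```sh
# Stage locally built wheels where `tp run --use-local-wheel` expects them.
mkdir -p local_dist
cp /path/to/your/build/torch-*.whl /path/to/your/build/torch_xla-*.whl local_dist/

# The wheels under local_dist/ are installed into the docker image that tp run builds.
tp run --use-local-wheel torchprime/hf_models/train.py
```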
@@ -205,3 +231,4 @@ For more information on PyTorch/XLA, visit the
[xpk]: https://github.com/AI-Hypercomputer/xpk
[torch_xla_debug_env]: https://github.com/pytorch/xla/blob/master/docs/source/learn/troubleshoot.md#environment-variables
[hydra]: https://hydra.cc/docs/intro/
+[torch_xla_dev_docker]: https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md#manually-build-in-docker-container
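As an editorial aside on the env-var proxying documented near the top of this README diff, here is a minimal sketch of the intended workflow; the `LIBTPU_INIT_ARGS` value is only an example lifted from the recipes later in this commit, not a published setting:

```sh
# Exported locally; `tp run` forwards these to the workers if they are set.
export HF_TOKEN='... hugging face token ...'
export XLA_IR_DEBUG=1
export XLA_HLO_DEBUG=1
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=98304'

# The command itself is broadcast to all VMs in the XPK cluster.
tp run torchprime/torch_xla_models/train.py
```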

torchprime/experimental/torchax_models/README.md

Lines changed: 26 additions & 13 deletions
@@ -40,30 +40,40 @@ pip install optax tensorflow tensorboard-plugin-profile
pip install -e .[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
```

-## Running locally
+## Running locally on a TPU VM

-```bash
-python run.py --model_impl=<orig|scan|scan_manual>
+Set up the environment as per the [README][README-examples].
+
+```sh
+python run.py model_impl=<orig|scan|scan_manual>
```

-## Run on XPK
+### Llama 3.1 8B on v6e-8
+
+Recipe for global batch size 8, sequence length 8192.
+Expected step duration: 1.7s. MFU: 30%.
+
+```sh
+export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=98304 --xla_enable_async_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true"

-Follow the guide in `tp use` to setup the cluster information.
+python run.py model_impl=scan tp=1 global_batch_size=8 seqlen=8192
+```

-Run `tp run <loal command>` to run the training command on the XPK cluster.
+## Running on an XPK cluster

-## Benchmarks (WIP)
+First follow the [distributed training][distributed-training] guide to set up the
+cluster information.

-|device| Model size | Batch size | seq length | step time | MFU | NOTEs|
-|-------| ----- | ----- | ----- | ----- | ----- | ---|
-|TPU v6e-8| 8B | 8 | 8192 | 1.7s | 30% | Scan, fsdp, host-offload|
-|TPU v6e-256 x 2| 405B | 256 | 8192 | 46.12s | 28.7% | Scan, fsdp + tp, host-offload|
+Run `tp run <local command>` to run the training command on the XPK cluster.

-<!-- TODO: support specifying different XLA flags -->
+### Llama 3.1 405B on 2 pods of v6e-256

-Llama 3.1 405B on v6e-256 x 2 command:
+Recipe for global batch size 256, sequence length 8192.
+Expected step duration: 46.12s. MFU: 28.7%.

```sh
+export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=98304 --xla_enable_async_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true"
+
tp run torchprime/experimental/torchax_models/run.py \
global_batch_size=256 \
model_type=405B \
@@ -74,3 +84,6 @@ tp run torchprime/experimental/torchax_models/run.py \
tp=4 \
unroll_layers=1
```
+
+[README-examples]: ../../README.md#examples
+[distributed-training]: ../../README.md#distributed-training
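Purely as an illustration of the `model_impl=<orig|scan|scan_manual>` override shown above, a local comparison of the three implementations on a single TPU VM might look like the following sketch (default settings assumed for everything else):

```sh
# Run the three layer implementations back to back and compare step times.
for impl in orig scan scan_manual; do
  echo "=== model_impl=$impl ==="
  python run.py model_impl="$impl"
done
```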

torchprime/hf_models/README.md

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+# Run huggingface transformer models
+
+For contributors to torchprime, `tp run` also supports running the huggingface
+trainer, for debugging and comparison. This module implements an adapter over
+the huggingface trainer.
+
+To run the huggingface trainer, you can clone
+[huggingface/transformers][hf-transformers] under the root directory of
+torchprime and name it `local_transformers`. This allows you to pick any
+branch or make code modifications in transformers for experiments:
+
+```sh
+git clone https://github.com/huggingface/transformers.git local_transformers
+```
+
+If the local huggingface transformers checkout doesn't exist, torchprime will
+automatically clone the repo and build the docker image for the experiment. To
+switch to huggingface models, add the `--use-hf` flag to the `tp run` command:
+
+```sh
+tp run --use-hf torchprime/hf_models/train.py
+```
+
+[hf-transformers]: https://github.com/huggingface/transformers
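To make the `local_transformers` workflow above concrete, here is a hedged sketch of pinning the checkout to a specific tag before building the experiment image; the tag is only an example, and any branch or fork works the same way:

```sh
# Shallow-clone a specific transformers tag into the directory tp run looks for.
git clone --branch v4.48.0 --depth 1 \
  https://github.com/huggingface/transformers.git local_transformers

# Build and launch with the pinned checkout baked into the docker image.
tp run --use-hf torchprime/hf_models/train.py
```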

torchprime/torch_xla_models/README.md

Lines changed: 121 additions & 24 deletions
@@ -1,50 +1,147 @@
# torch_xla models

-## Features
+These models use the [torch_xla][1] framework.

-- Optimized for PyTorch/XLA
-- Demonstrates GSPMD parallelism
-- Supports large language models tasks
+## Running locally on a TPU VM

-## Running locally
+1. Set up the environment as per the [README][README-examples].

-1. Clone the repository:
+1. Export key environment variables:

-```
-git clone https://github.com/AI-Hypercomputer/torchprime.git
-cd torchprime
+```sh
+export HF_TOKEN='... hugging face token ...'
+export XLA_IR_DEBUG=1
+export XLA_HLO_DEBUG=1
```

-2. Install the package:
+1. Run the trainer. The default is to train Llama 3.0 8B sharded over 4 chips.

+```sh
+python3 torchprime/torch_xla_models/train.py
```
-pip install -e .
-```
-
-3. Run the training script:

-```
-XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 python3 torchprime/torch_xla_models/train.py
-```
+## Running on an XPK cluster

-## Running on XPK
+First follow the [distributed training][distributed-training] guide to set up the
+cluster information.

-Follow the guide in `tp use` to setup the cluster information.
+Then export key environment variables in your local environment:

```sh
export HF_TOKEN='... hugging face token ...'
export XLA_IR_DEBUG=1
-export XLA_HLO_DEBUG=1
+export XLA_HLO_DEBUG=1
+```
+
+Finally, pick one of these recipes, and it will build a docker image and
+launch it on XPK.
+
+### Llama 3.0 8B on v6e-256
+
+Recipe for global batch size 256, sequence length 8192.
+Expected step duration: 1.625s. MFU: 33.53%.

-tp run torchprime/torch_xla_models/train.py
+```sh
+tp run torchprime/torch_xla_models/train.py \
+model=llama-3-8b \
+global_batch_size=256 \
+block_size=8192 \
+profile_step=5 \
+ici_mesh.fsdp=256
+```
+
+Recipe for global batch size 512, sequence length 8192.
+Expected step duration: 2.991s. MFU: 36.43%.
+
+```sh
+tp run torchprime/torch_xla_models/train.py \
+model=llama-3-8b \
+global_batch_size=512 \
+block_size=8192 \
+profile_step=5 \
+ici_mesh.fsdp=256
```

-This will build the dockerfile and launch it on XPK.
+### Llama 3.1 8B on v6e-256

+<!-- TODO(https://github.com/AI-Hypercomputer/torchprime/issues/135): publish perf data. -->
+
+Recipe for global batch size 512, sequence length 8192:
+
+```sh
+tp run torchprime/torch_xla_models/train.py \
+model=llama-3.1-8b \
+global_batch_size=512 \
+block_size=8192 \
+profile_step=5 \
+ici_mesh.fsdp=256
+```
+
+### Llama 3.1 405B on v6e-256
+
+Recipe for global batch size 64, sequence length 8192.
+Expected step duration: 27.349s. MFU: 21.48%.
+
+```sh
+export LIBTPU_INIT_ARGS='--xla_tpu_enable_flash_attention=false --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_scoped_vmem_limit_kib=98304'
+
+tp run torchprime/torch_xla_models/train.py \
+model=llama-3.1-405b \
+global_batch_size=64 \
+block_size=8192 \
+profile_step=5 \
+ici_mesh.fsdp=64 \
+ici_mesh.tensor=4
+```
+
+### Llama 3.1 405B on 2 pods of v6e-256
+
+Recipe for global batch size 128, sequence length 8192. We need to use a larger
+dataset and profile later for longer for the DCN performance to stabilize.
+
+Expected step duration: 30.933s. MFU: 18.99%.
+
+```sh
+export LIBTPU_INIT_ARGS='--xla_enable_async_all_gather=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_decompose_all_gather_einsum=true --xla_tpu_decompose_einsum_reduce_scatter=true --xla_tpu_scoped_vmem_limit_kib=98304 --xla_tpu_spmd_rng_bit_generator_unsafe=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_use_enhanced_launch_barrier=true'
+
+tp run torchprime/torch_xla_models/train.py \
+model=llama-3.1-405b \
+global_batch_size=128 \
+dcn_mesh.fsdp=2 \
+ici_mesh.fsdp=64 \
+ici_mesh.tensor=4 \
+dataset_config_name=wikitext-103-raw-v1 \
+profile_step=15 \
+profile_duration=240000 \
+max_steps=50 \
+logging_steps=10
+```
+
+### Mixtral 8x7B on v6e-256
+
+<!-- TODO(https://github.com/AI-Hypercomputer/torchprime/issues/137): publish perf data -->
+
+Recipe for global batch size 512, sequence length 8192.
+
+```sh
+tp run torchprime/torch_xla_models/train.py \
+model=mixtral-8x7b \
+global_batch_size=512 \
+ici_mesh.fsdp=256 \
+dataset_config_name=wikitext-103-raw-v1 \
+profile_step=5
+```

## Key Components

- `train.py`: Main training script that sets up the model, data, and training loop
- `configs/base.yaml`: Configuration file for the training script
-- `configs/model`: Configuration files for the training models
-- `llama/model.py`: Implementation of the Llama model
+- `configs/model`: Configuration files for models
+- `configs/model/scaling`: Configuration files for scaling the training of a model, e.g.
+rematerialization, sharding tensors.
+- `llama/model.py`: Implementation of the Llama model family
+- `mixtral/model.py`: Implementation of the Mixtral model family
+
+[1]: https://github.com/pytorch/xla
+[README-examples]: ../../README.md#examples
+[distributed-training]: ../../README.md#distributed-training
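The recipes above compose plain `key=value` overrides on the training entrypoint. As a hedged example (the values below are assumptions for a quick local sanity check, not a published recipe), a short profiled run on the default 4-chip TPU VM setup could look like:

```sh
# Small, fast run for sanity-checking a config change locally.
python3 torchprime/torch_xla_models/train.py \
  model=llama-3-8b \
  global_batch_size=8 \
  ici_mesh.fsdp=4 \
  profile_step=5 \
  max_steps=20 \
  logging_steps=5
```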
