
Commit 2887250

[Async TP] Add new async TP benchmarks (#1299)
## Summary

- Add benchmarks measuring the perf improvement of async TP over the vanilla TP baseline, for Llama 3.1 8B and 70B.
- Benchmarks were performed on the Grand Teton platform (H100s) using pytorch/torchtitan/torchao builds from 6/13 - 6/14 (see doc for more details).
1 parent 1c5be28 commit 2887250

1 file changed: 44 additions & 0 deletions

The following performance benchmarks were run by the PyTorch team in June 2025 to measure the performance improvement of async TP over the vanilla TP baseline.
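
For context on what "async TP" toggles, the sketch below shows a minimal, illustrative enablement path in PyTorch: symmetric memory enabled on the TP process group, inductor's micro-pipelining flag, and `torch.compile`. torchtitan drives this through its own job config, so the names here (`tp_group`, `model`) are placeholders rather than its exact code.

```python
# Illustrative sketch, not torchtitan's exact code: enabling async TP on top
# of an existing tensor-parallel setup. Assumes launch via torchrun ("nccl").
import torch
import torch.distributed as dist
from torch._inductor import config as inductor_config
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

dist.init_process_group("nccl")
# Placeholder: in a real job this is the TP process group, typically carved
# out of a 2-D DeviceMesh rather than the world group used here.
tp_group = dist.group.WORLD

# Async TP needs symmetric memory enabled on the TP process group ...
enable_symm_mem_for_group(tp_group.group_name)
# ... plus inductor's micro-pipelining flag; the vanilla TP baseline is the
# same run with this flag left off.
inductor_config._micro_pipeline_tp = True

# torch.compile is required: inductor overlaps the TP collectives with the
# matmuls they feed, which is where the measured speedups come from.
# model = torch.compile(model)  # `model` is a placeholder module
```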

### Models

Llama 3.1 8B, 70B

### Hardware

We ran our performance benchmarks on the [Grand Teton platform](https://engineering.fb.com/2022/10/18/open-source/ocp-summit-2022-grand-teton/), where:

- Each host has 8 NVIDIA H100 GPUs fully connected with NVLink.
- Each H100 GPU is equipped with 96GB HBM2e with 2.4 TB/sec peak memory bandwidth.
- Hosts are interconnected with a backend RDMA network with 400 Gb/s per GPU.
- We used the default 500W power limit, although raising it to the 700W TDP can potentially provide further speedups.

### Results

Detailed performance results and training configurations can be found in the tables below. The "Async TP speedup" column is the ratio of async TP tokens/sec to vanilla TP tokens/sec.

#### Llama 3.1 70B on 256 H100s with FSDP=32, TP=8, torch.compile, full AC, local batch size 16

| Quantization | Vanilla TP tokens/sec | Async TP tokens/sec | Async TP speedup |
| :---------------- | :---- | :---- | :--- |
| None (bfloat16) | 597.3 | 652.4 | 1.09 |
| float8 tensorwise | 809.8 | 942.4 | 1.16 |
| float8 rowwise | 599.6 | 624.8 | 1.04 |
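
For reference, the run layout in the heading composes as FSDP degree × TP degree = GPU count (32 × 8 = 256), with TP=8 matching the 8 NVLink-connected GPUs per Grand Teton host. A minimal sketch of such a 2-D mesh using PyTorch's `init_device_mesh` (torchtitan builds its mesh internally from the job config):

```python
# Hypothetical sketch of the 70B run's 2-D layout: 256 GPUs = FSDP 32 x TP 8.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (32, 8), mesh_dim_names=("dp", "tp"))
dp_mesh = mesh["dp"]  # FSDP shards model/optimizer states across 32 ranks
tp_mesh = mesh["tp"]  # TP (and async TP) runs across 8 ranks, one host's GPUs
```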

#### Llama 3.1 8B on 64 H100s with FSDP=8, TP=8, torch.compile, per-op SAC, local batch size 12

| Quantization | Vanilla TP tokens/sec | Async TP tokens/sec | Async TP speedup |
| :---------------- | :----- | :----- | :--- |
| None (bfloat16) | 4378 | 4809.4 | 1.10 |
| float8 tensorwise | 5078.1 | 5570.1 | 1.10 |
| float8 rowwise | 3708.5 | 3914.9 | 1.06 |

**Note**: The low baseline performance of vanilla TP float8 rowwise training is being addressed here: https://github.com/pytorch/torchtitan/issues/1207
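
For the quantization column, a hedged sketch of how the two float8 recipes map onto torchao's float8 training API, assuming the `convert_to_float8_training` / `Float8LinearConfig` names (torchtitan selects the recipe through its own config): tensorwise keeps one scale per tensor, rowwise one scale per row.

```python
# Illustrative only; torchtitan configures this via its job config, and the
# torchao names below are assumptions based on torchao's float8 training API.
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))

# "float8 tensorwise": one scale per tensor (the default recipe).
convert_to_float8_training(model, config=Float8LinearConfig())

# "float8 rowwise": one scale per row, selected via a named recipe; use this
# instead of the call above:
# convert_to_float8_training(
#     model, config=Float8LinearConfig.from_recipe_name("rowwise")
# )
```
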
### Versions and Dates

| repo | commit | date |
| --- | --- | --- |
| torch | [38410cf9](https://github.com/pytorch/pytorch/commit/38410cf9b57079f3360c1e79601973a01cb2588c) | 2025/06/14 |
| torchao | [6243040](https://github.com/pytorch/ao/commit/6243040807b9ceee889a58cba8e68c5fc4e2ebd8) | 2025/06/13 |
| torchtitan | [820504e](https://github.com/pytorch/torchtitan/commit/820504e20d1149fbf0b98c567af24c4b0433b22d) | 2025/06/13 |
