-
Notifications
You must be signed in to change notification settings - Fork 426
dp2ep Expert Parallel #1324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dp2ep Expert Parallel #1324
Conversation
547ecae
to
792f7a8
Compare
0f975fa
to
b517001
Compare
torchtitan/components/optimizer.py
Outdated
@@ -238,9 +242,85 @@ def zero_grad(self, *args, **kwargs) -> None: | |||
super().zero_grad(*args, **kwargs) | |||
|
|||
|
|||
class ExpertParallelOptimizersContainer(OptimizersContainer): | |||
""" | |||
This class is created to support fused optimizer implementation for Expert Parallel. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmmm do we really need this container? I thought after my PR pytorch/pytorch#147869 earlier this year, we should be able to run fused/foreach optimizer on DTensors that lives on different device mesh. Are you hitting a similar issue in pytorch/pytorch#153268?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After another look, I think indeed I'm hitting the same issue as pytorch/pytorch#153268 -- the error is on aten._fused_adam_.default
(sorry I thought it was more elementary ops like the gradient norm clipping ones).
|
||
|
||
@torch.no_grad() | ||
def _clip_grad_norm_with_ep( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you explain/document why we need this? Is it a similar issue to the optimizer? If so, IMO we should fix DTensor instead of adding all those wrappers
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is similar issue to the optimizer, but not exactly the same.
The cross mesh problem first happens at aten.stack
https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/utils/clip_grad.py#L102
Do you think we should support cross-mesh computation by DTensor for these more "elementary" ops? It might be easier if gradient norm computing / clipping come with fused ops.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
aten.stack
looks very tricky indeed... it produces one single tensor, and the inputs are from different mesh? Do you roughly have what the input sharding be look like?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think if it's unavoidable. then we can make a separate clip_grad_norm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we have 2 tensors on different meshes but the meshes are just different views of the same root mesh, is it possible to canonicalize them all in terms of the root and use the root for the output of stack?
torchtitan/components/optimizer.py
Outdated
return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) | ||
|
||
if parallel_dims.ep_enabled and fused: | ||
if ft_manager.enabled: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious what's the requirement to be compatible with torchft?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we support fused optimizer using DTensor instead of the wrapper code, it should be compatible with torchft.
@@ -24,40 +26,6 @@ | |||
|
|||
# implementation of Tensor Parallel for the GroupedExperts in MoE | |||
class TensorParallel(ParallelStyle): | |||
def __init__( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why those methods are deleted?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In ExpertParallel
and ExpertTensorParallel
classes:
- I didn't support specifying input (output) placements, as in dp2ep EP they are always expect data sharded on batch dim and perform all-to-all's to dispatch (combine, respectively) tokens to (from, respectively) the corresponding experts.
- I had to convert parameters from DTensor to plain tensor before doing computation, where the inputs always stay as plain tensors.
Since TensorParallel
is created only for the experts rather than with general purpose, I deliberately make the style consistent with ExpertParallel
and ExpertTensorParallel
, meaning
- it doesn't require input_layouts/output_layouts annotation -- it always expect inputs to be
Replicate
and outputs to bePartial
; - it doesn't convert inputs to DTensors in the forward pre hook, o/w it won't be consistent with plain Tensor parameters during computation.
Basically TensorParallel
is a specialized ParallelStyle
(combining Colwise
and Rowwise
into one), just like ExpertParallel
and ExpertTensorParallel
. Let me know what you think.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that make sense: one thing I want to ask, The TensorParallel does not support the "Sequence Parallel" type?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wanchaol
Outside MoE module it always supports Sequence Parallel, e.g. on the RMSNorm before MoE. The all-gather along seq_len
dim happens on the PrepareModuleInputOutput
for MoE module.
Inside MoE, there is opportunity to do Sequence Parallel comms (and maybe also compute) outside routed experts at the cost of an extra pair of (AG, RS). It is not available in this PR -- it will be the next step.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @wanchaol for the comments. I think for the fuse optimizer step, we may let DTensor support it in the "foreach" way. For others I'd love to hear more thoughts from you.
torchtitan/components/optimizer.py
Outdated
@@ -238,9 +242,85 @@ def zero_grad(self, *args, **kwargs) -> None: | |||
super().zero_grad(*args, **kwargs) | |||
|
|||
|
|||
class ExpertParallelOptimizersContainer(OptimizersContainer): | |||
""" | |||
This class is created to support fused optimizer implementation for Expert Parallel. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After another look, I think indeed I'm hitting the same issue as pytorch/pytorch#153268 -- the error is on aten._fused_adam_.default
(sorry I thought it was more elementary ops like the gradient norm clipping ones).
torchtitan/components/optimizer.py
Outdated
return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) | ||
|
||
if parallel_dims.ep_enabled and fused: | ||
if ft_manager.enabled: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we support fused optimizer using DTensor instead of the wrapper code, it should be compatible with torchft.
|
||
|
||
@torch.no_grad() | ||
def _clip_grad_norm_with_ep( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is similar issue to the optimizer, but not exactly the same.
The cross mesh problem first happens at aten.stack
https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/utils/clip_grad.py#L102
Do you think we should support cross-mesh computation by DTensor for these more "elementary" ops? It might be easier if gradient norm computing / clipping come with fused ops.
@@ -24,40 +26,6 @@ | |||
|
|||
# implementation of Tensor Parallel for the GroupedExperts in MoE | |||
class TensorParallel(ParallelStyle): | |||
def __init__( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In ExpertParallel
and ExpertTensorParallel
classes:
- I didn't support specifying input (output) placements, as in dp2ep EP they are always expect data sharded on batch dim and perform all-to-all's to dispatch (combine, respectively) tokens to (from, respectively) the corresponding experts.
- I had to convert parameters from DTensor to plain tensor before doing computation, where the inputs always stay as plain tensors.
Since TensorParallel
is created only for the experts rather than with general purpose, I deliberately make the style consistent with ExpertParallel
and ExpertTensorParallel
, meaning
- it doesn't require input_layouts/output_layouts annotation -- it always expect inputs to be
Replicate
and outputs to bePartial
; - it doesn't convert inputs to DTensors in the forward pre hook, o/w it won't be consistent with plain Tensor parameters during computation.
Basically TensorParallel
is a specialized ParallelStyle
(combining Colwise
and Rowwise
into one), just like ExpertParallel
and ExpertTensorParallel
. Let me know what you think.
## Context Mostly adapted from llama4, change the TP plan based on the difference between deepseek-v3 and llama. Thanks @tianyu-l for the detailed walk through about deepseek-v3 attention model and TP plan! This diff is currently based on #1324 , and we want to extract the MoE model in DSV3 and llama4 in a shared place. Now we have: 1. FSDP 2. Activation Checkpointing 3. TP 4. CP in progress (hang due to some reason) ## Next Step: 1. Make CP work ## Verification There are minor issue with the numerical verification: With deterministic seed, the loss is not identical. I used `AdamW` optimizer. 1. FSDP degree=4 (blue line) 2. FSDP degree=4, TP degree = 2 (orange line) <img width="1368" alt="Screenshot 2025-07-01 at 5 38 50 PM" src="https://github.com/user-attachments/assets/38d96d75-6868-4482-a603-b9e10c692ed9" /> With `Adam` optimizer, the loss is **exactly the same**: <img width="1368" alt="Screenshot 2025-07-02 at 1 26 32 PM" src="https://github.com/user-attachments/assets/6b501d3c-4841-42b1-95fd-3971b16a5eeb" /> --------- Co-authored-by: Tianyu Liu <lty@fb.com>
) This is to unblock "dp2ep" Expert Parallel + TP integration in torchtitan pytorch/torchtitan#1324. It does two things: 1. Slightly modifies the glue code for FSDP/HSDP + TP to work with FSDP/HSDP + EP and FSDP/HSDP + EP + TP. I kept the name `FSDPParam._tp_spec` to make the change minimal. We can consider renaming it in the future if it confuses people, but I heard @wanchaol has a plan to rewrite DTensor strided sharding entirely. 2. Lifts the check of `_validate_tp_mesh_dim` for `torch.distributed.tensor.parallel.parallelize_module`, as in EP or EP+TP this check is too strict. In particular it assumes a DeviceMesh must have `mesh_dim_names` which is not always true. I'm also removing the file `torch/distributed/tensor/parallel/_utils.py` it belongs entirely, as the other check `_deprecate_warnings`, added two years ago, is not used any more. Pull Request resolved: #157216 Approved by: https://github.com/wanchaol, https://github.com/weifengpy
…iple meshes (#157682) We are seeing more and more use cases where parameters in a model (under the same optimizer group) are put on different meshes. E.g. - when FSDP and TP are both applied, some parameters are sharded only on the FSDP mesh but not TP mesh (see #153268). - in [dp2ep Expert Parallel](pytorch/torchtitan#1324), the routed experts are sharded on the (global FSDP \ EP) mesh for smaller FSDP and on the EP mesh for EP, whereas other params are sharded on the global FSDP mesh for FSDP. This PR is, in some sense, a continuation of #147869 to tackle the problem when fused optimizers are used. In such cases, the [`fused_adam`](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml#L15786) / `fused_adamw` has a scalar tensor arg `state_steps` which gets automatically cast to DTensor on the default [`compute_mesh`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_dispatch.py#L350) (one of the multiple meshes), even though the it could correspond to different meshes. To avoid hitting the cross-mesh propagation exception in `common_pointwise_strategy` and followup redistribute problems, we manually set the target mesh and placements to be the same as input mesh and placements, so that no redistribute will be triggered. This also helps bypass the situation where [`generate_redistribute_costs`](https://github.com/pytorch/pytorch/pull/157682/files#diff-eea32a36dd2d4e58307bc5229402e48048b2ecaef64a7c085495fba1ee10ac89R597) returns infinite cost due to cross mesh redistribute. Moreover, this PR has minimal scope (restricted to the `fused_ops`) and doesn't need to modify other files such as `_sharding_prop.py`. Pull Request resolved: #157682 Approved by: https://github.com/wanchaol
## Context Mostly adapted from llama4, change the TP plan based on the difference between deepseek-v3 and llama. Thanks @tianyu-l for the detailed walk through about deepseek-v3 attention model and TP plan! This diff is currently based on pytorch#1324 , and we want to extract the MoE model in DSV3 and llama4 in a shared place. Now we have: 1. FSDP 2. Activation Checkpointing 3. TP 4. CP in progress (hang due to some reason) ## Next Step: 1. Make CP work ## Verification There are minor issue with the numerical verification: With deterministic seed, the loss is not identical. I used `AdamW` optimizer. 1. FSDP degree=4 (blue line) 2. FSDP degree=4, TP degree = 2 (orange line) <img width="1368" alt="Screenshot 2025-07-01 at 5 38 50 PM" src="https://github.com/user-attachments/assets/38d96d75-6868-4482-a603-b9e10c692ed9" /> With `Adam` optimizer, the loss is **exactly the same**: <img width="1368" alt="Screenshot 2025-07-02 at 1 26 32 PM" src="https://github.com/user-attachments/assets/6b501d3c-4841-42b1-95fd-3971b16a5eeb" /> --------- Co-authored-by: Tianyu Liu <lty@fb.com>
Mostly adapted from llama4, change the TP plan based on the difference between deepseek-v3 and llama. Thanks @tianyu-l for the detailed walk through about deepseek-v3 attention model and TP plan! This diff is currently based on #1324 , and we want to extract the MoE model in DSV3 and llama4 in a shared place. Now we have: 1. FSDP 2. Activation Checkpointing 3. TP 4. CP in progress (hang due to some reason) 1. Make CP work There are minor issue with the numerical verification: With deterministic seed, the loss is not identical. I used `AdamW` optimizer. 1. FSDP degree=4 (blue line) 2. FSDP degree=4, TP degree = 2 (orange line) <img width="1368" alt="Screenshot 2025-07-01 at 5 38 50 PM" src="https://github.com/user-attachments/assets/38d96d75-6868-4482-a603-b9e10c692ed9" /> With `Adam` optimizer, the loss is **exactly the same**: <img width="1368" alt="Screenshot 2025-07-02 at 1 26 32 PM" src="https://github.com/user-attachments/assets/6b501d3c-4841-42b1-95fd-3971b16a5eeb" /> --------- Co-authored-by: Tianyu Liu <lty@fb.com>
**Overview** Previously I demonstrated Expert Parallel for expert-choice MoE in a stack of PRs pytorch#732. This PR adds the initial support of dp2ep Expert Parallel for token-choice MoE, being non-intrusive to model code and composable with other parallelisms. In particular: - FSDP/HSDP + TP + EP is unblocked by pytorch/pytorch#157216 - fused optimizer for dp2ep EP is unblocked by pytorch/pytorch#157682 This PR also fixes the issue between auxiliary-loss-free load balancing and gradient accumulation, partly inspired by the solution of @hann-wang in pytorch#1304 which originally pointed out the issue. This PR does the expert bias update in an optimizer hook, instead of adding another entry in `TrainSpec`. While working on this PR, I also identified numerical issues between AdamW and Tensor Parallel, which I will post in a separate issue to track. **What is dp2ep Expert Parallel** Here are two diagrams illustrating the communication / computation pattern happening in dp2ep Expert Parallel. Basically, the Expert Parallel degree needed for MoE routed experts is borrowed from the Data Parallel (including Context Parallel) degree for non-MoE params (e.g. Attention layers, MLP layers) and other params in MoE layers (including the router's gate and shared experts). without TP  with TP  **Note:** In the current implementation, the all-to-all communication across all TP ranks are duplicate, causing unnecessary communication overhead. As the next step, I'm going to implement the "Sequence Parallel" for the all-to-all, reducing the communication volume to `1 / tp_degree`. **Design** The EP utilizes DTensor's [`parallelize_module`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/api.py#L16) API to shard MoE routed experts on the `num_expert` dimension, and inserts a pair of hooks before and after forward to perform all-to-all collectives. In additional, this PR creates an `expert_parallel` wrapper applied to the GroupedExperts computation, serving the following three purposes: 1. Convert parameters from DTensors to plain Tensors, to work with dynamic-shape inputs which cannot be easily expressed as DTensors. 2. In Expert Parallel, apply the `generate_permute_indices` kernel to permute the inputs to be ordered by local experts (see the `_token_dispatch` function in `ExpertParallel`) and permute the outputs back. 3. In order to use `torch._grouped_mm`, we need to make sure the number of tokens each expert gets is a multiple of `ALIGN_SIZE_M`. The `generate_permute_indices` kernel also helps achieve this via padding, without incurring synchronization between device and host. Note that this will create side effects when wrapping the for-loop implementation of GroupedExperts, as it does not need padding. 4. Among the above: - 1 and 2 are needed only when `expert_parallel_degree` > 1. - 3 is needed even for single-device computation. - 2 can be moved to `ExpertParallel`'s `_token_dispatch` if not coupled with 3. Due to the inhomogeneity of `DeviceMesh`es from EP parameters and non-EP parameters, this PR adds the following special treatment to enable TP - `DeviceMesh` creation: when EP is enabled, create a special `DeviceMesh` to share between DP/CP (for non-EP parameters) and EP (for EP parameters). - gradient norm clipping: when EP is enabled, separately compute the norm of EP parameters and non-EP parameters -> compute the global norm -> separately perform grad norm clipping with the global norm. - ~~fused optimizer step: created a new optimizer container class `ExpertParallelOptimizersContainer` which does fused optimizer steps on EP parameters and non-EP parameters separately.~~ (tackled in pytorch/pytorch#157682) For `DeviceMesh`, we'll need to improve the way we can express non-homogeneous meshes. For gradient norm clipping ~~and fused optimizer~~, since there are up two groups of parameters, I expect the approach to be fine, until we find better way of support. Things could change if LLM / MoE architecture evolves to be more dynamic. **Communication Trace Verification**  One can see that in order to call EP all-to-all `_token_dispatch` and `_token_combine` with correct `input_splits` and `output_splits`, we need to generate the size data via another `dist.all_to_all_single` (in the default stream) and do a **device-to-host sync**. This can be avoided by utilizing SymmetricMemory-based `all-to-all-v`, which we will work on soon. **DCP Resharding Correctness and Numerical Verification** Note: I used `--optimizer.name="Adam"` instead of `"AdamW"` which seems to cause numerical issues when TP is enabled. To verify, I created a seed checkpoint of the debug model, fixed the seed, and ran the same training under different parallelism configs for 100 steps on at most 8 GPUs - FSDP 2 - FSDP 2 (EP 2), TP 2, PP 2 - HSDP 4 (DP 2, CP 2, EP 4), TP 2 <img width="1317" alt="image" src="https://github.com/user-attachments/assets/609f057c-0e6a-430a-89dc-5f2070ecb135" /> **Next Steps** - Sequence Parallel for all-to-all communication collectives, when TP is enabled (at the cost of another pair of TP all-gather and reduce-scatter) - adopt SymmetricMemory-based all-to-all and avoid D2H syncs (cc @kwen2501) - enable EP in torchtitan's DeepSeekV3 @wwwjn - FSDP2 non-dim-0 sharding (cc @weifengpy) - `torch.compile` support @xmfan - which blocks torchao quantization enablement - computation / communication overlapping - either via inductor passes to overlap all-to-all with shared expert computation @xmfan - or via fine-grained Pipeline Parallel splitting & scheduling @H-Huang - float8 + MoE TP integration @danielvegamyhre - Previously float8 works with TP by having specialized `ColwiseParallel` and `RowwiseParallel` (see [code](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/infra/parallelize.py#L167)). For MoE, I'm creating new ad hoc `ParallelStyle`s, including `TensorParallel`, `ExpertParallel`, and `ExpertTensorParallel`. - better `DeviceMesh` support and general "ETP" support (where experts TP and attention/mlp TP don't have to have the same TP degree) @fduwjj
) This is to unblock "dp2ep" Expert Parallel + TP integration in torchtitan pytorch/torchtitan#1324. It does two things: 1. Slightly modifies the glue code for FSDP/HSDP + TP to work with FSDP/HSDP + EP and FSDP/HSDP + EP + TP. I kept the name `FSDPParam._tp_spec` to make the change minimal. We can consider renaming it in the future if it confuses people, but I heard @wanchaol has a plan to rewrite DTensor strided sharding entirely. 2. Lifts the check of `_validate_tp_mesh_dim` for `torch.distributed.tensor.parallel.parallelize_module`, as in EP or EP+TP this check is too strict. In particular it assumes a DeviceMesh must have `mesh_dim_names` which is not always true. I'm also removing the file `torch/distributed/tensor/parallel/_utils.py` it belongs entirely, as the other check `_deprecate_warnings`, added two years ago, is not used any more. Pull Request resolved: #157216 Approved by: https://github.com/wanchaol, https://github.com/weifengpy
Mostly adapted from llama4, change the TP plan based on the difference between deepseek-v3 and llama. Thanks @tianyu-l for the detailed walk through about deepseek-v3 attention model and TP plan! This diff is currently based on #1324 , and we want to extract the MoE model in DSV3 and llama4 in a shared place. Now we have: 1. FSDP 2. Activation Checkpointing 3. TP 4. CP in progress (hang due to some reason) 1. Make CP work There are minor issue with the numerical verification: With deterministic seed, the loss is not identical. I used `AdamW` optimizer. 1. FSDP degree=4 (blue line) 2. FSDP degree=4, TP degree = 2 (orange line) <img width="1368" alt="Screenshot 2025-07-01 at 5 38 50 PM" src="https://github.com/user-attachments/assets/38d96d75-6868-4482-a603-b9e10c692ed9" /> With `Adam` optimizer, the loss is **exactly the same**: <img width="1368" alt="Screenshot 2025-07-02 at 1 26 32 PM" src="https://github.com/user-attachments/assets/6b501d3c-4841-42b1-95fd-3971b16a5eeb" /> --------- Co-authored-by: Tianyu Liu <lty@fb.com>
## Supported Features - FSDP, HSDP - Activation checkpointing - Tensor Parallel (TP) from @tianyu-l - Expert Parallel (EP) ## To be added - Modeling - Merge DeepSeek-V3 and Llama4 MoE common components - Parallelism - Context Parallel support for DeepSeek-V3 - PP support for DeepSeek-V3 @H-Huang is working on #1345 - torch.compile - Quantization - Testing - perfomance and loss converging tests - CI integration - @wwwjn will work on this after PyTorch side diffs (mentioned in #1324) get into PyTorch nightly ## Test 1. With FSDP=8, EP=2 (['dp_shard_mod_ep', 'dp_shard_in_ep'], [4, 2]) ``` [rank0]:[titan] 2025-07-08 15:15:43,068 - root - INFO - step: 1 loss: 12.2616 grad_norm: 0.3918 memory: 65.53GiB(68.98%) tps: 1,482 tflops: 0.61 mfu: 0.06% [rank0]:[titan] 2025-07-08 15:15:43,068 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-07-08 15:15:43,543 - root - INFO - step: 2 loss: 12.0093 grad_norm: 0.5745 memory: 65.54GiB(68.99%) tps: 69,111 tflops: 28.68 mfu: 2.90% [rank0]:[titan] 2025-07-08 15:15:43,981 - root - INFO - step: 3 loss: 11.1697 grad_norm: 1.2095 memory: 65.54GiB(68.99%) tps: 74,931 tflops: 31.09 mfu: 3.14% [rank0]:[titan] 2025-07-08 15:15:44,015 - root - WARNING - Dataset c4_test is being re-looped [rank0]:[titan] 2025-07-08 15:15:44,409 - root - INFO - step: 4 loss: 10.7248 grad_norm: 1.2230 memory: 65.54GiB(68.99%) tps: 76,668 tflops: 31.81 mfu: 3.22% [rank0]:[titan] 2025-07-08 15:15:44,838 - root - INFO - step: 5 loss: 10.5484 grad_norm: 1.1633 memory: 65.54GiB(68.99%) tps: 76,416 tflops: 31.71 mfu: 3.21% [rank0]:[titan] 2025-07-08 15:15:45,339 - root - INFO - step: 6 loss: 10.3509 grad_norm: 1.1611 memory: 65.54GiB(68.99%) tps: 65,490 tflops: 27.18 mfu: 2.75% [rank0]:[titan] 2025-07-08 15:15:45,401 - root - WARNING - Dataset c4_test is being re-looped [rank0]:[titan] 2025-07-08 15:15:46,121 - root - INFO - step: 7 loss: 10.2153 grad_norm: 1.1410 memory: 65.54GiB(68.99%) tps: 41,934 tflops: 17.40 mfu: 1.76% [rank0]:[titan] 2025-07-08 15:15:46,733 - root - INFO - step: 8 loss: 10.0801 grad_norm: 1.1487 memory: 65.54GiB(68.99%) tps: 53,599 tflops: 22.24 mfu: 2.25% [rank0]:[titan] 2025-07-08 15:15:47,137 - root - INFO - step: 9 loss: 9.9781 grad_norm: 1.1257 memory: 65.54GiB(68.99%) tps: 81,051 tflops: 33.63 mfu: 3.40% [rank0]:[titan] 2025-07-08 15:15:47,554 - root - INFO - step: 10 loss: 9.9183 grad_norm: 1.1012 memory: 65.54GiB(68.99%) tps: 78,712 tflops: 32.66 mfu: 3.30% ``` 2. With FSDP=4, TP=2 ``` [rank0]:[titan] 2025-07-08 15:16:25,927 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-07-08 15:16:34,993 - root - INFO - step: 1 loss: 12.2768 grad_norm: 0.3836 memory: 41.14GiB(43.31%) tps: 1,750 tflops: 0.73 mfu: 0.07% [rank0]:[titan] 2025-07-08 15:16:34,993 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-07-08 15:16:35,310 - root - INFO - step: 2 loss: 12.0284 grad_norm: 0.5423 memory: 41.29GiB(43.46%) tps: 51,796 tflops: 21.49 mfu: 2.17% [rank0]:[titan] 2025-07-08 15:16:35,605 - root - INFO - step: 3 loss: 11.2398 grad_norm: 1.2037 memory: 41.29GiB(43.46%) tps: 55,575 tflops: 23.06 mfu: 2.33% [rank0]:[titan] 2025-07-08 15:16:35,912 - root - INFO - step: 4 loss: 10.8246 grad_norm: 1.2360 memory: 41.29GiB(43.46%) tps: 53,553 tflops: 22.22 mfu: 2.25% [rank0]:[titan] 2025-07-08 15:16:36,206 - root - INFO - step: 5 loss: 10.6295 grad_norm: 1.1951 memory: 41.29GiB(43.46%) tps: 55,732 tflops: 23.13 mfu: 2.34% [rank0]:[titan] 2025-07-08 15:16:36,502 - root - INFO - step: 6 loss: 10.5240 grad_norm: 1.1296 memory: 41.29GiB(43.46%) tps: 55,564 tflops: 23.06 mfu: 2.33% [rank0]:[titan] 2025-07-08 15:16:36,793 - root - INFO - step: 7 loss: 10.3426 grad_norm: 1.1630 memory: 41.29GiB(43.46%) tps: 56,295 tflops: 23.36 mfu: 2.36% [rank0]:[titan] 2025-07-08 15:16:36,824 - root - WARNING - Dataset c4_test is being re-looped [rank0]:[titan] 2025-07-08 15:16:37,081 - root - INFO - step: 8 loss: 10.2127 grad_norm: 1.1499 memory: 41.29GiB(43.46%) tps: 57,052 tflops: 23.67 mfu: 2.39% [rank0]:[titan] 2025-07-08 15:16:37,374 - root - INFO - step: 9 loss: 10.0537 grad_norm: 1.1814 memory: 41.29GiB(43.46%) tps: 56,019 tflops: 23.25 mfu: 2.35% [rank0]:[titan] 2025-07-08 15:16:37,664 - root - INFO - step: 10 loss: 10.0311 grad_norm: 1.1082 memory: 41.29GiB(43.46%) tps: 56,504 tflops: 23.45 mfu: 2.37% ``` --------- Co-authored-by: Tianyu Liu <lty@fb.com> Co-authored-by: Howard Huang <howardhuang96@gmail.com>
Overview
Previously I demonstrated Expert Parallel for expert-choice MoE in a stack of PRs #732.
This PR adds the initial support of dp2ep Expert Parallel for token-choice MoE, being non-intrusive to model code and composable with other parallelisms. In particular:
This PR also fixes the issue between auxiliary-loss-free load balancing and gradient accumulation, partly inspired by the solution of @hann-wang in #1304 which originally pointed out the issue. This PR does the expert bias update in an optimizer hook, instead of adding another entry in
TrainSpec
.While working on this PR, I also identified numerical issues between AdamW and Tensor Parallel, which I will post in a separate issue to track.
What is dp2ep Expert Parallel
Here are two diagrams illustrating the communication / computation pattern happening in dp2ep Expert Parallel. Basically, the Expert Parallel degree needed for MoE routed experts is borrowed from the Data Parallel (including Context Parallel) degree for non-MoE params (e.g. Attention layers, MLP layers) and other params in MoE layers (including the router's gate and shared experts).
without TP

with TP

Note: In the current implementation, the all-to-all communication across all TP ranks are duplicate, causing unnecessary communication overhead. As the next step, I'm going to implement the "Sequence Parallel" for the all-to-all, reducing the communication volume to
1 / tp_degree
.Design
The EP utilizes DTensor's
parallelize_module
API to shard MoE routed experts on thenum_expert
dimension, and inserts a pair of hooks before and after forward to perform all-to-all collectives.In additional, this PR creates an
expert_parallel
wrapper applied to the GroupedExperts computation, servingthe following three purposes:
generate_permute_indices
kernel to permute the inputs to be ordered by local experts (see the_token_dispatch
function inExpertParallel
) and permute the outputs back.torch._grouped_mm
, we need to make sure the number of tokens each expert gets is a multiple ofALIGN_SIZE_M
. Thegenerate_permute_indices
kernel also helps achieve this via padding, without incurring synchronization between device and host. Note that this will create side effects when wrapping the for-loop implementation of GroupedExperts, as it does not need padding.expert_parallel_degree
> 1.ExpertParallel
's_token_dispatch
if not coupled with 3.Due to the inhomogeneity of
DeviceMesh
es from EP parameters and non-EP parameters, this PR adds the following special treatment to enable TPDeviceMesh
creation: when EP is enabled, create a specialDeviceMesh
to share between DP/CP (for non-EP parameters) and EP (for EP parameters).fused optimizer step: created a new optimizer container class(tackled in [dtensor] add support for fused optimizer with parameters across multiple meshes pytorch#157682)ExpertParallelOptimizersContainer
which does fused optimizer steps on EP parameters and non-EP parameters separately.For
DeviceMesh
, we'll need to improve the way we can express non-homogeneous meshes. For gradient norm clippingand fused optimizer, since there are up two groups of parameters, I expect the approach to be fine, until we find better way of support. Things could change if LLM / MoE architecture evolves to be more dynamic.Communication Trace Verification
One can see that in order to call EP all-to-all
_token_dispatch
and_token_combine
with correctinput_splits
andoutput_splits
, we need to generate the size data via anotherdist.all_to_all_single
(in the default stream) and do a device-to-host sync. This can be avoided by utilizing SymmetricMemory-basedall-to-all-v
, which we will work on soon.DCP Resharding Correctness and Numerical Verification
Note: I used
--optimizer.name="Adam"
instead of"AdamW"
which seems to cause numerical issues when TP is enabled.To verify, I created a seed checkpoint of the debug model, fixed the seed, and ran the same training under different parallelism configs for 100 steps on at most 8 GPUs
Next Steps
torch.compile
support @xmfanColwiseParallel
andRowwiseParallel
(see code). For MoE, I'm creating new ad hocParallelStyle
s, includingTensorParallel
,ExpertParallel
, andExpertTensorParallel
.DeviceMesh
support and general "ETP" support (where EP borrows degree from DP/CP and TP so that experts TP and attention/mlp TP don't have to have the same TP degrees) @fduwjj