Skip to content

dp2ep Expert Parallel #1324

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 8, 2025
Merged

dp2ep Expert Parallel #1324

merged 1 commit into from
Jul 8, 2025

Conversation

tianyu-l
Copy link
Contributor

@tianyu-l tianyu-l commented Jun 21, 2025

Overview

Previously I demonstrated Expert Parallel for expert-choice MoE in a stack of PRs #732.

This PR adds the initial support of dp2ep Expert Parallel for token-choice MoE, being non-intrusive to model code and composable with other parallelisms. In particular:

This PR also fixes the issue between auxiliary-loss-free load balancing and gradient accumulation, partly inspired by the solution of @hann-wang in #1304 which originally pointed out the issue. This PR does the expert bias update in an optimizer hook, instead of adding another entry in TrainSpec.

While working on this PR, I also identified numerical issues between AdamW and Tensor Parallel, which I will post in a separate issue to track.

What is dp2ep Expert Parallel

Here are two diagrams illustrating the communication / computation pattern happening in dp2ep Expert Parallel. Basically, the Expert Parallel degree needed for MoE routed experts is borrowed from the Data Parallel (including Context Parallel) degree for non-MoE params (e.g. Attention layers, MLP layers) and other params in MoE layers (including the router's gate and shared experts).

without TP
image

with TP
image

Note: In the current implementation, the all-to-all communication across all TP ranks are duplicate, causing unnecessary communication overhead. As the next step, I'm going to implement the "Sequence Parallel" for the all-to-all, reducing the communication volume to 1 / tp_degree.

Design

The EP utilizes DTensor's parallelize_module API to shard MoE routed experts on the num_expert dimension, and inserts a pair of hooks before and after forward to perform all-to-all collectives.

In additional, this PR creates an expert_parallel wrapper applied to the GroupedExperts computation, serving
the following three purposes:

  1. Convert parameters from DTensors to plain Tensors, to work with dynamic-shape inputs which cannot be easily expressed as DTensors.
  2. In Expert Parallel, apply the generate_permute_indices kernel to permute the inputs to be ordered by local experts (see the _token_dispatch function in ExpertParallel) and permute the outputs back.
  3. In order to use torch._grouped_mm, we need to make sure the number of tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices kernel also helps achieve this via padding, without incurring synchronization between device and host. Note that this will create side effects when wrapping the for-loop implementation of GroupedExperts, as it does not need padding.
  4. Among the above:
    • 1 and 2 are needed only when expert_parallel_degree > 1.
    • 3 is needed even for single-device computation.
    • 2 can be moved to ExpertParallel's _token_dispatch if not coupled with 3.

Due to the inhomogeneity of DeviceMeshes from EP parameters and non-EP parameters, this PR adds the following special treatment to enable TP

  • DeviceMesh creation: when EP is enabled, create a special DeviceMesh to share between DP/CP (for non-EP parameters) and EP (for EP parameters).
  • gradient norm clipping: when EP is enabled, separately compute the norm of EP parameters and non-EP parameters -> compute the global norm -> separately perform grad norm clipping with the global norm.
  • fused optimizer step: created a new optimizer container class ExpertParallelOptimizersContainer which does fused optimizer steps on EP parameters and non-EP parameters separately. (tackled in [dtensor] add support for fused optimizer with parameters across multiple meshes pytorch#157682)

For DeviceMesh, we'll need to improve the way we can express non-homogeneous meshes. For gradient norm clipping and fused optimizer, since there are up two groups of parameters, I expect the approach to be fine, until we find better way of support. Things could change if LLM / MoE architecture evolves to be more dynamic.

Communication Trace Verification

image

One can see that in order to call EP all-to-all _token_dispatch and _token_combine with correct input_splits and output_splits, we need to generate the size data via another dist.all_to_all_single (in the default stream) and do a device-to-host sync. This can be avoided by utilizing SymmetricMemory-based all-to-all-v, which we will work on soon.

DCP Resharding Correctness and Numerical Verification

Note: I used --optimizer.name="Adam" instead of "AdamW" which seems to cause numerical issues when TP is enabled.

To verify, I created a seed checkpoint of the debug model, fixed the seed, and ran the same training under different parallelism configs for 100 steps on at most 8 GPUs

  • FSDP 2
  • FSDP 2 (EP 2), TP 2, PP 2
  • HSDP 4 (DP 2, CP 2, EP 4), TP 2
image

Next Steps

  • Sequence Parallel for all-to-all communication collectives, when TP is enabled (at the cost of another pair of TP all-gather and reduce-scatter)
  • adopt SymmetricMemory-based all-to-all and avoid D2H syncs (cc @kwen2501)
  • enable EP in torchtitan's DeepSeekV3 @wwwjn
  • FSDP2 non-dim-0 sharding (cc @weifengpy)
  • torch.compile support @xmfan
    • which blocks torchao quantization enablement
  • computation / communication overlapping
    • either via inductor passes to overlap all-to-all with shared expert computation @xmfan
    • or via fine-grained Pipeline Parallel splitting & scheduling @H-Huang
  • float8 + MoE TP integration @danielvegamyhre
    • Previously float8 works with TP by having specialized ColwiseParallel and RowwiseParallel (see code). For MoE, I'm creating new ad hoc ParallelStyles, including TensorParallel, ExpertParallel, and ExpertTensorParallel.
  • better DeviceMesh support and general "ETP" support (where EP borrows degree from DP/CP and TP so that experts TP and attention/mlp TP don't have to have the same TP degrees) @fduwjj

@tianyu-l tianyu-l requested review from fegin and wwwjn as code owners June 21, 2025 01:07
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 21, 2025
@tianyu-l tianyu-l marked this pull request as draft June 21, 2025 01:07
@tianyu-l tianyu-l force-pushed the ep branch 2 times, most recently from 547ecae to 792f7a8 Compare June 26, 2025 05:51
@tianyu-l tianyu-l force-pushed the ep branch 3 times, most recently from 0f975fa to b517001 Compare June 29, 2025 07:54
@tianyu-l tianyu-l requested a review from wanchaol June 29, 2025 08:07
@tianyu-l tianyu-l marked this pull request as ready for review June 29, 2025 08:08
@tianyu-l tianyu-l requested a review from wconstab as a code owner June 29, 2025 08:08
@tianyu-l tianyu-l changed the title [WIP] expert parallel dp2ep dp2ep Expert Parallel Jun 29, 2025
@@ -238,9 +242,85 @@ def zero_grad(self, *args, **kwargs) -> None:
super().zero_grad(*args, **kwargs)


class ExpertParallelOptimizersContainer(OptimizersContainer):
"""
This class is created to support fused optimizer implementation for Expert Parallel.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm do we really need this container? I thought after my PR pytorch/pytorch#147869 earlier this year, we should be able to run fused/foreach optimizer on DTensors that lives on different device mesh. Are you hitting a similar issue in pytorch/pytorch#153268?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After another look, I think indeed I'm hitting the same issue as pytorch/pytorch#153268 -- the error is on aten._fused_adam_.default (sorry I thought it was more elementary ops like the gradient norm clipping ones).



@torch.no_grad()
def _clip_grad_norm_with_ep(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain/document why we need this? Is it a similar issue to the optimizer? If so, IMO we should fix DTensor instead of adding all those wrappers

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is similar issue to the optimizer, but not exactly the same.
The cross mesh problem first happens at aten.stack https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/utils/clip_grad.py#L102
Do you think we should support cross-mesh computation by DTensor for these more "elementary" ops? It might be easier if gradient norm computing / clipping come with fused ops.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aten.stack looks very tricky indeed... it produces one single tensor, and the inputs are from different mesh? Do you roughly have what the input sharding be look like?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if it's unavoidable. then we can make a separate clip_grad_norm

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have 2 tensors on different meshes but the meshes are just different views of the same root mesh, is it possible to canonicalize them all in terms of the root and use the root for the output of stack?

return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)

if parallel_dims.ep_enabled and fused:
if ft_manager.enabled:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious what's the requirement to be compatible with torchft?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we support fused optimizer using DTensor instead of the wrapper code, it should be compatible with torchft.

@@ -24,40 +26,6 @@

# implementation of Tensor Parallel for the GroupedExperts in MoE
class TensorParallel(ParallelStyle):
def __init__(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why those methods are deleted?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In ExpertParallel and ExpertTensorParallel classes:

  1. I didn't support specifying input (output) placements, as in dp2ep EP they are always expect data sharded on batch dim and perform all-to-all's to dispatch (combine, respectively) tokens to (from, respectively) the corresponding experts.
  2. I had to convert parameters from DTensor to plain tensor before doing computation, where the inputs always stay as plain tensors.

Since TensorParallel is created only for the experts rather than with general purpose, I deliberately make the style consistent with ExpertParallel and ExpertTensorParallel, meaning

  1. it doesn't require input_layouts/output_layouts annotation -- it always expect inputs to be Replicate and outputs to be Partial;
  2. it doesn't convert inputs to DTensors in the forward pre hook, o/w it won't be consistent with plain Tensor parameters during computation.

Basically TensorParallel is a specialized ParallelStyle (combining Colwise and Rowwise into one), just like ExpertParallel and ExpertTensorParallel. Let me know what you think.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that make sense: one thing I want to ask, The TensorParallel does not support the "Sequence Parallel" type?

Copy link
Contributor Author

@tianyu-l tianyu-l Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wanchaol
Outside MoE module it always supports Sequence Parallel, e.g. on the RMSNorm before MoE. The all-gather along seq_len dim happens on the PrepareModuleInputOutput for MoE module.

Inside MoE, there is opportunity to do Sequence Parallel comms (and maybe also compute) outside routed experts at the cost of an extra pair of (AG, RS). It is not available in this PR -- it will be the next step.

Copy link
Contributor Author

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @wanchaol for the comments. I think for the fuse optimizer step, we may let DTensor support it in the "foreach" way. For others I'd love to hear more thoughts from you.

@@ -238,9 +242,85 @@ def zero_grad(self, *args, **kwargs) -> None:
super().zero_grad(*args, **kwargs)


class ExpertParallelOptimizersContainer(OptimizersContainer):
"""
This class is created to support fused optimizer implementation for Expert Parallel.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After another look, I think indeed I'm hitting the same issue as pytorch/pytorch#153268 -- the error is on aten._fused_adam_.default (sorry I thought it was more elementary ops like the gradient norm clipping ones).

return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)

if parallel_dims.ep_enabled and fused:
if ft_manager.enabled:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we support fused optimizer using DTensor instead of the wrapper code, it should be compatible with torchft.



@torch.no_grad()
def _clip_grad_norm_with_ep(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is similar issue to the optimizer, but not exactly the same.
The cross mesh problem first happens at aten.stack https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/utils/clip_grad.py#L102
Do you think we should support cross-mesh computation by DTensor for these more "elementary" ops? It might be easier if gradient norm computing / clipping come with fused ops.

@@ -24,40 +26,6 @@

# implementation of Tensor Parallel for the GroupedExperts in MoE
class TensorParallel(ParallelStyle):
def __init__(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In ExpertParallel and ExpertTensorParallel classes:

  1. I didn't support specifying input (output) placements, as in dp2ep EP they are always expect data sharded on batch dim and perform all-to-all's to dispatch (combine, respectively) tokens to (from, respectively) the corresponding experts.
  2. I had to convert parameters from DTensor to plain tensor before doing computation, where the inputs always stay as plain tensors.

Since TensorParallel is created only for the experts rather than with general purpose, I deliberately make the style consistent with ExpertParallel and ExpertTensorParallel, meaning

  1. it doesn't require input_layouts/output_layouts annotation -- it always expect inputs to be Replicate and outputs to be Partial;
  2. it doesn't convert inputs to DTensors in the forward pre hook, o/w it won't be consistent with plain Tensor parameters during computation.

Basically TensorParallel is a specialized ParallelStyle (combining Colwise and Rowwise into one), just like ExpertParallel and ExpertTensorParallel. Let me know what you think.

@wwwjn wwwjn mentioned this pull request Jul 1, 2025
wwwjn added a commit that referenced this pull request Jul 1, 2025
wwwjn added a commit that referenced this pull request Jul 2, 2025
wwwjn added a commit that referenced this pull request Jul 2, 2025
## Context
Mostly adapted from llama4, change the TP plan based on the difference
between deepseek-v3 and llama.

Thanks @tianyu-l for the detailed walk through about deepseek-v3
attention model and TP plan! This diff is currently based on #1324 , and
we want to extract the MoE model in DSV3 and llama4 in a shared place.

Now we have:
1. FSDP
2. Activation Checkpointing
3. TP
4. CP in progress (hang due to some reason)

## Next Step:
1. Make CP work

## Verification
There are minor issue with the numerical verification: With
deterministic seed, the loss is not identical. I used `AdamW` optimizer.

1. FSDP degree=4 (blue line)
2. FSDP degree=4, TP degree = 2 (orange line)

<img width="1368" alt="Screenshot 2025-07-01 at 5 38 50 PM"
src="https://github.com/user-attachments/assets/38d96d75-6868-4482-a603-b9e10c692ed9"
/>


With `Adam` optimizer, the loss is **exactly the same**:
<img width="1368" alt="Screenshot 2025-07-02 at 1 26 32 PM"
src="https://github.com/user-attachments/assets/6b501d3c-4841-42b1-95fd-3971b16a5eeb"
/>

---------

Co-authored-by: Tianyu Liu <lty@fb.com>
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jul 8, 2025
)

This is to unblock "dp2ep" Expert Parallel + TP integration in torchtitan pytorch/torchtitan#1324.

It does two things:
1. Slightly modifies the glue code for FSDP/HSDP + TP to work with FSDP/HSDP + EP and FSDP/HSDP + EP + TP. I kept the name `FSDPParam._tp_spec` to make the change minimal. We can consider renaming it in the future if it confuses people, but I heard @wanchaol has a plan to rewrite DTensor strided sharding entirely.
2. Lifts the check of `_validate_tp_mesh_dim` for `torch.distributed.tensor.parallel.parallelize_module`, as in EP or EP+TP this check is too strict. In particular it assumes a DeviceMesh must have `mesh_dim_names` which is not always true. I'm also removing the file `torch/distributed/tensor/parallel/_utils.py` it belongs entirely, as the other check `_deprecate_warnings`, added two years ago, is not used any more.

Pull Request resolved: #157216
Approved by: https://github.com/wanchaol, https://github.com/weifengpy
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jul 8, 2025
…iple meshes (#157682)

We are seeing more and more use cases where parameters in a model (under the same optimizer group) are put on different meshes. E.g.
- when FSDP and TP are both applied, some parameters are sharded only on the FSDP mesh but not TP mesh (see #153268).
- in [dp2ep Expert Parallel](pytorch/torchtitan#1324), the routed experts are sharded on the (global FSDP \ EP) mesh for smaller FSDP and on the EP mesh for EP, whereas other params are sharded on the global FSDP mesh for FSDP.

This PR is, in some sense, a continuation of #147869 to tackle the problem when fused optimizers are used. In such cases, the [`fused_adam`](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml#L15786) / `fused_adamw` has a scalar tensor arg `state_steps` which gets automatically cast to DTensor on the default [`compute_mesh`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_dispatch.py#L350) (one of the multiple meshes), even though the it could correspond to different meshes.

To avoid hitting the cross-mesh propagation exception in `common_pointwise_strategy` and followup redistribute problems, we manually set the target mesh and placements to be the same as input mesh and placements, so that no redistribute will be triggered. This also helps bypass the situation where [`generate_redistribute_costs`](https://github.com/pytorch/pytorch/pull/157682/files#diff-eea32a36dd2d4e58307bc5229402e48048b2ecaef64a7c085495fba1ee10ac89R597) returns infinite cost due to cross mesh redistribute.

Moreover, this PR has minimal scope (restricted to the `fused_ops`) and doesn't need to modify other files such as `_sharding_prop.py`.

Pull Request resolved: #157682
Approved by: https://github.com/wanchaol
H-Huang pushed a commit to H-Huang/torchtitan that referenced this pull request Jul 8, 2025
## Context
Mostly adapted from llama4, change the TP plan based on the difference
between deepseek-v3 and llama.

Thanks @tianyu-l for the detailed walk through about deepseek-v3
attention model and TP plan! This diff is currently based on pytorch#1324 , and
we want to extract the MoE model in DSV3 and llama4 in a shared place.

Now we have:
1. FSDP
2. Activation Checkpointing
3. TP
4. CP in progress (hang due to some reason)

## Next Step:
1. Make CP work

## Verification
There are minor issue with the numerical verification: With
deterministic seed, the loss is not identical. I used `AdamW` optimizer.

1. FSDP degree=4 (blue line)
2. FSDP degree=4, TP degree = 2 (orange line)

<img width="1368" alt="Screenshot 2025-07-01 at 5 38 50 PM"
src="https://github.com/user-attachments/assets/38d96d75-6868-4482-a603-b9e10c692ed9"
/>


With `Adam` optimizer, the loss is **exactly the same**:
<img width="1368" alt="Screenshot 2025-07-02 at 1 26 32 PM"
src="https://github.com/user-attachments/assets/6b501d3c-4841-42b1-95fd-3971b16a5eeb"
/>

---------

Co-authored-by: Tianyu Liu <lty@fb.com>
@tianyu-l tianyu-l merged commit 01f4e50 into main Jul 8, 2025
7 checks passed
@tianyu-l tianyu-l deleted the ep branch July 8, 2025 16:47
wwwjn added a commit that referenced this pull request Jul 8, 2025
Mostly adapted from llama4, change the TP plan based on the difference
between deepseek-v3 and llama.

Thanks @tianyu-l for the detailed walk through about deepseek-v3
attention model and TP plan! This diff is currently based on #1324 , and
we want to extract the MoE model in DSV3 and llama4 in a shared place.

Now we have:
1. FSDP
2. Activation Checkpointing
3. TP
4. CP in progress (hang due to some reason)

1. Make CP work

There are minor issue with the numerical verification: With
deterministic seed, the loss is not identical. I used `AdamW` optimizer.

1. FSDP degree=4 (blue line)
2. FSDP degree=4, TP degree = 2 (orange line)

<img width="1368" alt="Screenshot 2025-07-01 at 5 38 50 PM"
src="https://github.com/user-attachments/assets/38d96d75-6868-4482-a603-b9e10c692ed9"
/>

With `Adam` optimizer, the loss is **exactly the same**:
<img width="1368" alt="Screenshot 2025-07-02 at 1 26 32 PM"
src="https://github.com/user-attachments/assets/6b501d3c-4841-42b1-95fd-3971b16a5eeb"
/>

---------

Co-authored-by: Tianyu Liu <lty@fb.com>
mori360 pushed a commit to mori360/torchtitan that referenced this pull request Jul 8, 2025
**Overview**

Previously I demonstrated Expert Parallel for expert-choice MoE in a
stack of PRs pytorch#732.

This PR adds the initial support of dp2ep Expert Parallel for
token-choice MoE, being non-intrusive to model code and composable with
other parallelisms. In particular:
- FSDP/HSDP + TP + EP is unblocked by
pytorch/pytorch#157216
- fused optimizer for dp2ep EP is unblocked by
pytorch/pytorch#157682

This PR also fixes the issue between auxiliary-loss-free load balancing
and gradient accumulation, partly inspired by the solution of @hann-wang
in pytorch#1304 which originally
pointed out the issue. This PR does the expert bias update in an
optimizer hook, instead of adding another entry in `TrainSpec`.

While working on this PR, I also identified numerical issues between
AdamW and Tensor Parallel, which I will post in a separate issue to
track.


**What is dp2ep Expert Parallel**

Here are two diagrams illustrating the communication / computation
pattern happening in dp2ep Expert Parallel. Basically, the Expert
Parallel degree needed for MoE routed experts is borrowed from the Data
Parallel (including Context Parallel) degree for non-MoE params (e.g.
Attention layers, MLP layers) and other params in MoE layers (including
the router's gate and shared experts).

without TP

![image](https://github.com/user-attachments/assets/fa4f6d42-8885-4536-b887-6234f7b4c638)

with TP

![image](https://github.com/user-attachments/assets/1ee35414-2e07-4d57-952b-cdfaeec0b494)

**Note:** In the current implementation, the all-to-all communication
across all TP ranks are duplicate, causing unnecessary communication
overhead. As the next step, I'm going to implement the "Sequence
Parallel" for the all-to-all, reducing the communication volume to `1 /
tp_degree`.


**Design**

The EP utilizes DTensor's
[`parallelize_module`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/api.py#L16)
API to shard MoE routed experts on the `num_expert` dimension, and
inserts a pair of hooks before and after forward to perform all-to-all
collectives.

In additional, this PR creates an `expert_parallel` wrapper applied to
the GroupedExperts computation, serving
the following three purposes:
1. Convert parameters from DTensors to plain Tensors, to work with
dynamic-shape inputs which cannot be easily expressed as DTensors.
2. In Expert Parallel, apply the `generate_permute_indices` kernel to
permute the inputs to be ordered by local experts (see the
`_token_dispatch` function in `ExpertParallel`) and permute the outputs
back.
3. In order to use `torch._grouped_mm`, we need to make sure the number
of tokens each expert gets is a multiple of `ALIGN_SIZE_M`. The
`generate_permute_indices` kernel also helps achieve this via padding,
without incurring synchronization between device and host. Note that
this will create side effects when wrapping the for-loop implementation
of GroupedExperts, as it does not need padding.
4. Among the above:
    - 1 and 2 are needed only when `expert_parallel_degree` > 1.
    - 3 is needed even for single-device computation.
- 2 can be moved to `ExpertParallel`'s `_token_dispatch` if not coupled
with 3.

Due to the inhomogeneity of `DeviceMesh`es from EP parameters and non-EP
parameters, this PR adds the following special treatment to enable TP
- `DeviceMesh` creation: when EP is enabled, create a special
`DeviceMesh` to share between DP/CP (for non-EP parameters) and EP (for
EP parameters).
- gradient norm clipping: when EP is enabled, separately compute the
norm of EP parameters and non-EP parameters -> compute the global norm
-> separately perform grad norm clipping with the global norm.
- ~~fused optimizer step: created a new optimizer container class
`ExpertParallelOptimizersContainer` which does fused optimizer steps on
EP parameters and non-EP parameters separately.~~ (tackled in
pytorch/pytorch#157682)

For `DeviceMesh`, we'll need to improve the way we can express
non-homogeneous meshes. For gradient norm clipping ~~and fused
optimizer~~, since there are up two groups of parameters, I expect the
approach to be fine, until we find better way of support. Things could
change if LLM / MoE architecture evolves to be more dynamic.


**Communication Trace Verification**


![image](https://github.com/user-attachments/assets/68182c67-91ad-41df-b46a-1fff0b5a6f48)

One can see that in order to call EP all-to-all `_token_dispatch` and
`_token_combine` with correct `input_splits` and `output_splits`, we
need to generate the size data via another `dist.all_to_all_single` (in
the default stream) and do a **device-to-host sync**. This can be
avoided by utilizing SymmetricMemory-based `all-to-all-v`, which we will
work on soon.


**DCP Resharding Correctness and Numerical Verification**

Note: I used `--optimizer.name="Adam"` instead of `"AdamW"` which seems
to cause numerical issues when TP is enabled.

To verify, I created a seed checkpoint of the debug model, fixed the
seed, and ran the same training under different parallelism configs for
100 steps on at most 8 GPUs
- FSDP 2
- FSDP 2 (EP 2), TP 2, PP 2
- HSDP 4 (DP 2, CP 2, EP 4), TP 2

<img width="1317" alt="image"
src="https://github.com/user-attachments/assets/609f057c-0e6a-430a-89dc-5f2070ecb135"
/>


**Next Steps**
- Sequence Parallel for all-to-all communication collectives, when TP is
enabled (at the cost of another pair of TP all-gather and
reduce-scatter)
- adopt SymmetricMemory-based all-to-all and avoid D2H syncs (cc
@kwen2501)
- enable EP in torchtitan's DeepSeekV3 @wwwjn 
- FSDP2 non-dim-0 sharding (cc @weifengpy)
- `torch.compile` support @xmfan
  - which blocks torchao quantization enablement
- computation / communication overlapping
- either via inductor passes to overlap all-to-all with shared expert
computation @xmfan
- or via fine-grained Pipeline Parallel splitting & scheduling @H-Huang
- float8 + MoE TP integration @danielvegamyhre
- Previously float8 works with TP by having specialized
`ColwiseParallel` and `RowwiseParallel` (see
[code](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/infra/parallelize.py#L167)).
For MoE, I'm creating new ad hoc `ParallelStyle`s, including
`TensorParallel`, `ExpertParallel`, and `ExpertTensorParallel`.
- better `DeviceMesh` support and general "ETP" support (where experts
TP and attention/mlp TP don't have to have the same TP degree) @fduwjj
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jul 9, 2025
)

This is to unblock "dp2ep" Expert Parallel + TP integration in torchtitan pytorch/torchtitan#1324.

It does two things:
1. Slightly modifies the glue code for FSDP/HSDP + TP to work with FSDP/HSDP + EP and FSDP/HSDP + EP + TP. I kept the name `FSDPParam._tp_spec` to make the change minimal. We can consider renaming it in the future if it confuses people, but I heard @wanchaol has a plan to rewrite DTensor strided sharding entirely.
2. Lifts the check of `_validate_tp_mesh_dim` for `torch.distributed.tensor.parallel.parallelize_module`, as in EP or EP+TP this check is too strict. In particular it assumes a DeviceMesh must have `mesh_dim_names` which is not always true. I'm also removing the file `torch/distributed/tensor/parallel/_utils.py` it belongs entirely, as the other check `_deprecate_warnings`, added two years ago, is not used any more.

Pull Request resolved: #157216
Approved by: https://github.com/wanchaol, https://github.com/weifengpy
wwwjn added a commit that referenced this pull request Jul 10, 2025
Mostly adapted from llama4, change the TP plan based on the difference
between deepseek-v3 and llama.

Thanks @tianyu-l for the detailed walk through about deepseek-v3
attention model and TP plan! This diff is currently based on #1324 , and
we want to extract the MoE model in DSV3 and llama4 in a shared place.

Now we have:
1. FSDP
2. Activation Checkpointing
3. TP
4. CP in progress (hang due to some reason)

1. Make CP work

There are minor issue with the numerical verification: With
deterministic seed, the loss is not identical. I used `AdamW` optimizer.

1. FSDP degree=4 (blue line)
2. FSDP degree=4, TP degree = 2 (orange line)

<img width="1368" alt="Screenshot 2025-07-01 at 5 38 50 PM"
src="https://github.com/user-attachments/assets/38d96d75-6868-4482-a603-b9e10c692ed9"
/>

With `Adam` optimizer, the loss is **exactly the same**:
<img width="1368" alt="Screenshot 2025-07-02 at 1 26 32 PM"
src="https://github.com/user-attachments/assets/6b501d3c-4841-42b1-95fd-3971b16a5eeb"
/>

---------

Co-authored-by: Tianyu Liu <lty@fb.com>
tianyu-l added a commit that referenced this pull request Jul 11, 2025
## Supported Features
- FSDP, HSDP
- Activation checkpointing
- Tensor Parallel (TP) from @tianyu-l 
- Expert Parallel (EP)


## To be added
- Modeling
    - Merge DeepSeek-V3 and Llama4 MoE common components
- Parallelism
    - Context Parallel support for DeepSeek-V3
    - PP support for DeepSeek-V3 @H-Huang  is working on #1345 
- torch.compile
- Quantization
- Testing
    - perfomance and loss converging tests
- CI integration - @wwwjn will work on this after PyTorch side diffs
(mentioned in #1324) get into PyTorch nightly


## Test
1. With FSDP=8, EP=2 (['dp_shard_mod_ep', 'dp_shard_in_ep'], [4, 2])
```
[rank0]:[titan] 2025-07-08 15:15:43,068 - root - INFO - step:  1  loss: 12.2616  grad_norm:  0.3918  memory: 65.53GiB(68.98%)  tps: 1,482  tflops: 0.61  mfu: 0.06%
[rank0]:[titan] 2025-07-08 15:15:43,068 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-08 15:15:43,543 - root - INFO - step:  2  loss: 12.0093  grad_norm:  0.5745  memory: 65.54GiB(68.99%)  tps: 69,111  tflops: 28.68  mfu: 2.90%
[rank0]:[titan] 2025-07-08 15:15:43,981 - root - INFO - step:  3  loss: 11.1697  grad_norm:  1.2095  memory: 65.54GiB(68.99%)  tps: 74,931  tflops: 31.09  mfu: 3.14%
[rank0]:[titan] 2025-07-08 15:15:44,015 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-07-08 15:15:44,409 - root - INFO - step:  4  loss: 10.7248  grad_norm:  1.2230  memory: 65.54GiB(68.99%)  tps: 76,668  tflops: 31.81  mfu: 3.22%
[rank0]:[titan] 2025-07-08 15:15:44,838 - root - INFO - step:  5  loss: 10.5484  grad_norm:  1.1633  memory: 65.54GiB(68.99%)  tps: 76,416  tflops: 31.71  mfu: 3.21%
[rank0]:[titan] 2025-07-08 15:15:45,339 - root - INFO - step:  6  loss: 10.3509  grad_norm:  1.1611  memory: 65.54GiB(68.99%)  tps: 65,490  tflops: 27.18  mfu: 2.75%
[rank0]:[titan] 2025-07-08 15:15:45,401 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-07-08 15:15:46,121 - root - INFO - step:  7  loss: 10.2153  grad_norm:  1.1410  memory: 65.54GiB(68.99%)  tps: 41,934  tflops: 17.40  mfu: 1.76%
[rank0]:[titan] 2025-07-08 15:15:46,733 - root - INFO - step:  8  loss: 10.0801  grad_norm:  1.1487  memory: 65.54GiB(68.99%)  tps: 53,599  tflops: 22.24  mfu: 2.25%
[rank0]:[titan] 2025-07-08 15:15:47,137 - root - INFO - step:  9  loss:  9.9781  grad_norm:  1.1257  memory: 65.54GiB(68.99%)  tps: 81,051  tflops: 33.63  mfu: 3.40%
[rank0]:[titan] 2025-07-08 15:15:47,554 - root - INFO - step: 10  loss:  9.9183  grad_norm:  1.1012  memory: 65.54GiB(68.99%)  tps: 78,712  tflops: 32.66  mfu: 3.30%
```

2. With FSDP=4, TP=2
```
[rank0]:[titan] 2025-07-08 15:16:25,927 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-08 15:16:34,993 - root - INFO - step:  1  loss: 12.2768  grad_norm:  0.3836  memory: 41.14GiB(43.31%)  tps: 1,750  tflops: 0.73  mfu: 0.07%
[rank0]:[titan] 2025-07-08 15:16:34,993 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-08 15:16:35,310 - root - INFO - step:  2  loss: 12.0284  grad_norm:  0.5423  memory: 41.29GiB(43.46%)  tps: 51,796  tflops: 21.49  mfu: 2.17%
[rank0]:[titan] 2025-07-08 15:16:35,605 - root - INFO - step:  3  loss: 11.2398  grad_norm:  1.2037  memory: 41.29GiB(43.46%)  tps: 55,575  tflops: 23.06  mfu: 2.33%
[rank0]:[titan] 2025-07-08 15:16:35,912 - root - INFO - step:  4  loss: 10.8246  grad_norm:  1.2360  memory: 41.29GiB(43.46%)  tps: 53,553  tflops: 22.22  mfu: 2.25%
[rank0]:[titan] 2025-07-08 15:16:36,206 - root - INFO - step:  5  loss: 10.6295  grad_norm:  1.1951  memory: 41.29GiB(43.46%)  tps: 55,732  tflops: 23.13  mfu: 2.34%
[rank0]:[titan] 2025-07-08 15:16:36,502 - root - INFO - step:  6  loss: 10.5240  grad_norm:  1.1296  memory: 41.29GiB(43.46%)  tps: 55,564  tflops: 23.06  mfu: 2.33%
[rank0]:[titan] 2025-07-08 15:16:36,793 - root - INFO - step:  7  loss: 10.3426  grad_norm:  1.1630  memory: 41.29GiB(43.46%)  tps: 56,295  tflops: 23.36  mfu: 2.36%
[rank0]:[titan] 2025-07-08 15:16:36,824 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-07-08 15:16:37,081 - root - INFO - step:  8  loss: 10.2127  grad_norm:  1.1499  memory: 41.29GiB(43.46%)  tps: 57,052  tflops: 23.67  mfu: 2.39%
[rank0]:[titan] 2025-07-08 15:16:37,374 - root - INFO - step:  9  loss: 10.0537  grad_norm:  1.1814  memory: 41.29GiB(43.46%)  tps: 56,019  tflops: 23.25  mfu: 2.35%
[rank0]:[titan] 2025-07-08 15:16:37,664 - root - INFO - step: 10  loss: 10.0311  grad_norm:  1.1082  memory: 41.29GiB(43.46%)  tps: 56,504  tflops: 23.45  mfu: 2.37%
```

---------

Co-authored-by: Tianyu Liu <lty@fb.com>
Co-authored-by: Howard Huang <howardhuang96@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants