[llama4][auxiliary-loss-free load balancing] update expert_bias without backward hooks #1304


Open · wants to merge 2 commits into `main`

Conversation

hann-wang (Contributor)

Changes:

Reasons:

  • Friendly to `torch.compile` and activation checkpointing.
  • The original implementation updates `expert_bias` on each microbatch during gradient accumulation (see the sketch below).

cc @tianyu-l
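For context, here is a minimal sketch of the auxiliary-loss-free balancing update under gradient accumulation (illustrative only; `bias_update_speed` and the buffer handling are assumptions, not torchtitan's actual identifiers): per-expert token counts are accumulated across microbatches and `expert_bias` is adjusted once per optimizer step.

```python
import torch

@torch.no_grad()
def update_expert_bias(expert_bias: torch.Tensor,
                       tokens_per_expert: torch.Tensor,
                       bias_update_speed: float = 1e-3) -> None:
    # Auxiliary-loss-free load balancing sketch: nudge under-loaded experts up
    # and over-loaded experts down. `tokens_per_expert` should hold counts
    # accumulated over all microbatches of the current optimizer step.
    load = tokens_per_expert.float()
    expert_bias.add_(bias_update_speed * torch.sign(load.mean() - load))

# Usage sketch with gradient accumulation:
#   per microbatch:          tokens_per_expert_buffer += per-expert routing counts
#   after optimizer.step():  update_expert_bias(expert_bias, tokens_per_expert_buffer)
#                            tokens_per_expert_buffer.zero_()
```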

@facebook-github-bot added the **CLA Signed** label on Jun 16, 2025
@tianyu-l mentioned this pull request on Jun 29, 2025
@tianyu-l added a commit that referenced this pull request on Jul 8, 2025
**Overview**

Previously I demonstrated Expert Parallel for expert-choice MoE in a
stack of PRs #732.

This PR adds initial support for dp2ep Expert Parallel for
token-choice MoE, non-intrusive to model code and composable with
other parallelisms. In particular:
- FSDP/HSDP + TP + EP is unblocked by
pytorch/pytorch#157216
- fused optimizer for dp2ep EP is unblocked by
pytorch/pytorch#157682

This PR also fixes the issue between auxiliary-loss-free load balancing
and gradient accumulation, partly inspired by the solution from @hann-wang
in #1304 which originally pointed out the issue. This PR does the expert
bias update in an optimizer hook, instead of adding another entry in
`TrainSpec`.
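As a rough illustration of the optimizer-hook approach, the sketch below uses `torch.optim.Optimizer.register_step_post_hook`; the `moe_layers`, `tokens_per_expert`, and `expert_bias` attributes are hypothetical names, not torchtitan's actual code.

```python
import torch

def register_expert_bias_update(optimizer: torch.optim.Optimizer,
                                model: torch.nn.Module,
                                bias_update_speed: float = 1e-3):
    # Runs once per optimizer.step(), i.e. once per full gradient-accumulation
    # cycle, so the bias update no longer depends on backward hooks or on the
    # number of microbatches.
    def _update_expert_bias(opt, args, kwargs):
        with torch.no_grad():
            for moe in getattr(model, "moe_layers", []):   # hypothetical attribute
                counts = moe.tokens_per_expert.float()     # accumulated over microbatches
                moe.expert_bias.add_(bias_update_speed * torch.sign(counts.mean() - counts))
                moe.tokens_per_expert.zero_()
    return optimizer.register_step_post_hook(_update_expert_bias)
```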

While working on this PR, I also identified numerical issues between
AdamW and Tensor Parallel, which I will post in a separate issue to
track.


**What is dp2ep Expert Parallel**

Here are two diagrams illustrating the communication / computation
pattern happening in dp2ep Expert Parallel. Basically, the Expert
Parallel degree needed for MoE routed experts is borrowed from the Data
Parallel (including Context Parallel) degree for non-MoE params (e.g.
Attention layers, MLP layers) and other params in MoE layers (including
the router's gate and shared experts).

without TP

![image](https://github.com/user-attachments/assets/fa4f6d42-8885-4536-b887-6234f7b4c638)

with TP

![image](https://github.com/user-attachments/assets/1ee35414-2e07-4d57-952b-cdfaeec0b494)

**Note:** In the current implementation, the all-to-all communication
is duplicated across all TP ranks, causing unnecessary communication
overhead. As the next step, I'm going to implement "Sequence Parallel"
for the all-to-all, reducing the communication volume to `1 / tp_degree`.
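To make the degree-borrowing concrete, here is a rough `DeviceMesh` sketch (sizes are illustrative; this is not the actual torchtitan mesh construction): with 8 ranks and an EP degree of 4, non-MoE parameters are data-parallel over all 8 ranks, while routed experts view the same 8 ranks as a 2-D mesh.

```python
import torch
from torch.distributed.device_mesh import init_device_mesh

# Assumes torch.distributed is already initialized with 8 ranks.
# Non-MoE params (attention, MLP, router gate, shared experts): DP over all 8 ranks.
dp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp",))

# Routed experts: the EP degree is borrowed from DP, so the same 8 ranks are viewed
# as a 2-D mesh; experts are sharded over the "ep" dimension (size 4) and
# data-parallel over the remaining "dp_mod_ep" dimension (size 2).
ep_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_mod_ep", "ep"))
```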


**Design**

The EP utilizes DTensor's
[`parallelize_module`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/api.py#L16)
API to shard MoE routed experts on the `num_expert` dimension, and
inserts a pair of hooks before and after forward to perform all-to-all
collectives.
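The sketch below shows these two ingredients in simplified form (manual sharding plus forward hooks, assuming a 1-D EP mesh and equal splits for brevity; the real implementation goes through `parallelize_module` with a custom `ParallelStyle` and variable splits):

```python
import torch
import torch.distributed as dist
from torch.distributed.tensor import distribute_tensor, Shard

def apply_expert_parallel_sketch(experts: torch.nn.Module, ep_mesh) -> None:
    # 1) Shard every expert parameter on dim 0, i.e. the num_experts dimension.
    for name, param in list(experts.named_parameters(recurse=False)):
        sharded = distribute_tensor(param, ep_mesh, [Shard(0)])
        setattr(experts, name, torch.nn.Parameter(sharded))

    ep_group = ep_mesh.get_group()  # assumes a 1-D EP mesh

    # 2) Pre-forward hook: dispatch each token to the rank owning its routed expert.
    def token_dispatch(module, args):
        tokens = args[0]
        out = torch.empty_like(tokens)
        dist.all_to_all_single(out, tokens, group=ep_group)  # equal splits for brevity
        return (out, *args[1:])

    #    Post-forward hook: combine expert outputs back to the originating ranks.
    def token_combine(module, args, output):
        combined = torch.empty_like(output)
        dist.all_to_all_single(combined, output, group=ep_group)
        return combined

    experts.register_forward_pre_hook(token_dispatch)
    experts.register_forward_hook(token_combine)
```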

In addition, this PR creates an `expert_parallel` wrapper applied to
the GroupedExperts computation, serving the following three purposes:
1. Convert parameters from DTensors to plain Tensors, to work with
dynamic-shape inputs which cannot be easily expressed as DTensors.
2. In Expert Parallel, apply the `generate_permute_indices` kernel to
permute the inputs to be ordered by local experts (see the
`_token_dispatch` function in `ExpertParallel`) and permute the outputs
back.
3. In order to use `torch._grouped_mm`, we need to make sure the number
of tokens each expert gets is a multiple of `ALIGN_SIZE_M`. The
`generate_permute_indices` kernel also helps achieve this via padding,
without incurring synchronization between device and host. Note that
this will create side effects when wrapping the for-loop implementation
of GroupedExperts, as it does not need padding. (A rough padding sketch
follows this list.)

Among the above:
- 1 and 2 are needed only when `expert_parallel_degree > 1`.
- 3 is needed even for single-device computation.
- 2 can be moved to `ExpertParallel`'s `_token_dispatch` if not coupled
  with 3.
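As referenced in point 3, a rough padding sketch (pure PyTorch, not the actual `generate_permute_indices` kernel; `ALIGN_SIZE_M = 16` is an illustrative value):

```python
import torch

ALIGN_SIZE_M = 16  # illustrative alignment requirement for torch._grouped_mm

def padded_offsets(tokens_per_expert: torch.Tensor, align: int = ALIGN_SIZE_M) -> torch.Tensor:
    # Round each expert's token count up to a multiple of `align` and return the
    # cumulative offsets used to lay out the permuted (padded) inputs.
    # Everything stays on-device, so no device-to-host sync is required.
    padded = (tokens_per_expert + align - 1) // align * align
    return torch.cumsum(padded, dim=0)
```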

Due to the inhomogeneity of the `DeviceMesh`es for EP parameters and non-EP
parameters, this PR adds the following special treatment to enable TP:
- `DeviceMesh` creation: when EP is enabled, create a special
`DeviceMesh` shared between DP/CP (for non-EP parameters) and EP (for
EP parameters).
- gradient norm clipping: when EP is enabled, separately compute the
norms of EP parameters and non-EP parameters, combine them into the
global norm, then separately perform grad norm clipping with the global
norm (a sketch follows this list).
- ~~fused optimizer step: created a new optimizer container class
`ExpertParallelOptimizersContainer` which does fused optimizer steps on
EP parameters and non-EP parameters separately.~~ (tackled in
pytorch/pytorch#157682)
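As referenced in the gradient norm clipping item, here is a simplified single-device sketch of the two-group scheme (in the actual EP setup each group's norm is reduced over its own mesh; assumes each group has gradients):

```python
import torch

def clip_grad_norm_two_groups(ep_params, non_ep_params, max_norm: float, eps: float = 1e-6):
    def group_norm(params):
        grads = [p.grad for p in params if p.grad is not None]
        return torch.linalg.vector_norm(
            torch.stack([torch.linalg.vector_norm(g) for g in grads]))

    # Separately compute each group's grad norm, then combine into the global norm.
    total_norm = torch.linalg.vector_norm(
        torch.stack([group_norm(ep_params), group_norm(non_ep_params)]))

    # Clip both groups with the same global scaling factor.
    clip_coef = (max_norm / (total_norm + eps)).clamp(max=1.0)
    for p in list(ep_params) + list(non_ep_params):
        if p.grad is not None:
            p.grad.mul_(clip_coef)
    return total_norm
```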

For `DeviceMesh`, we'll need to improve the way we can express
non-homogeneous meshes. For gradient norm clipping ~~and fused
optimizer~~, since there are at most two groups of parameters, I expect
the approach to be fine until we find a better way to support this.
Things could change if LLM / MoE architectures evolve to be more dynamic.


**Communication Trace Verification**


![image](https://github.com/user-attachments/assets/68182c67-91ad-41df-b46a-1fff0b5a6f48)

One can see that in order to call EP all-to-all `_token_dispatch` and
`_token_combine` with correct `input_splits` and `output_splits`, we
need to generate the size data via another `dist.all_to_all_single` (in
the default stream) and do a **device-to-host sync**. This can be
avoided by utilizing SymmetricMemory-based `all-to-all-v`, which we will
work on soon.
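A minimal sketch of that split-size exchange and the resulting D2H sync (illustrative; assumes tokens are already permuted so that those destined for the same EP rank are contiguous):

```python
import torch
import torch.distributed as dist

def dispatch_with_size_exchange(tokens: torch.Tensor, input_splits: torch.Tensor, ep_group):
    # 1) Exchange split sizes so every rank knows how many tokens it will receive.
    output_splits = torch.empty_like(input_splits)
    dist.all_to_all_single(output_splits, input_splits, group=ep_group)

    # 2) Device-to-host sync: all_to_all_single takes split sizes as Python ints.
    in_sizes = input_splits.tolist()
    out_sizes = output_splits.tolist()

    # 3) The actual token all-to-all (the `_token_dispatch` step).
    recv = tokens.new_empty((sum(out_sizes), tokens.shape[1]))
    dist.all_to_all_single(recv, tokens,
                           output_split_sizes=out_sizes,
                           input_split_sizes=in_sizes,
                           group=ep_group)
    return recv
```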


**DCP Resharding Correctness and Numerical Verification**

Note: I used `--optimizer.name="Adam"` instead of `"AdamW"`, which seems
to cause numerical issues when TP is enabled.

To verify, I created a seed checkpoint of the debug model, fixed the
seed, and ran the same training under different parallelism configs for
100 steps on at most 8 GPUs:
- FSDP 2
- FSDP 2 (EP 2), TP 2, PP 2
- HSDP 4 (DP 2, CP 2, EP 4), TP 2

![image](https://github.com/user-attachments/assets/609f057c-0e6a-430a-89dc-5f2070ecb135)


**Next Steps**
- Sequence Parallel for the all-to-all communication collectives, when TP is
enabled (at the cost of another pair of TP all-gather and reduce-scatter)
- adopt SymmetricMemory-based all-to-all and avoid D2H syncs (cc @kwen2501)
- enable EP in torchtitan's DeepSeekV3 @wwwjn
- FSDP2 non-dim-0 sharding (cc @weifengpy)
- `torch.compile` support @xmfan
  - which blocks torchao quantization enablement
- computation / communication overlapping
  - either via inductor passes to overlap all-to-all with shared expert
    computation @xmfan
  - or via fine-grained Pipeline Parallel splitting & scheduling @H-Huang
- float8 + MoE TP integration @danielvegamyhre
  - Previously float8 works with TP by having specialized
    `ColwiseParallel` and `RowwiseParallel` (see
    [code](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/infra/parallelize.py#L167)).
    For MoE, I'm creating new ad hoc `ParallelStyle`s, including
    `TensorParallel`, `ExpertParallel`, and `ExpertTensorParallel`.
- better `DeviceMesh` support and general "ETP" support (where experts
TP and attention/mlp TP don't have to have the same TP degree) @fduwjj
mori360 pushed a commit to mori360/torchtitan that referenced this pull request Jul 8, 2025