[moe training] Add TP support for routed experts #2473

Merged: 1 commit into main on Jul 2, 2025

Conversation

@danielvegamyhre (Contributor) commented on Jul 2, 2025

Stack

Summary

  • Add a TP integration test and a testing utils file for shared functions used in the single-GPU / FSDP / TP training tests.
  • Update _scaled_grouped_mm to support a 3D A tensor, which is needed for the shared expert.
  • Make offs optional, to handle the shared-expert case where num_experts=1 (no group offsets are needed since there is only one token group); see the shape sketch after this list.
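
To make the 2D-vs-3D shape handling concrete, here is a minimal reference sketch in plain PyTorch. It is not the torchao kernel: the real `_scaled_grouped_mm` does float8 rowwise scaling, and its exact signature and shape conventions may differ, so treat the layouts below (2D A plus `offs` for routed experts, 3D A with `offs=None` for the shared expert) as assumptions for illustration.

```python
import torch


def reference_grouped_mm(A: torch.Tensor, B: torch.Tensor, offs: torch.Tensor | None = None) -> torch.Tensor:
    # Reference (non-fp8) semantics only; shape conventions are assumptions.
    if A.dim() == 3:
        # 3D A: (num_groups, tokens_per_group, dim) -> plain batched matmul
        # against B of shape (num_groups, dim, out_dim); no offsets needed.
        return torch.bmm(A, B)
    # 2D A: (total_tokens, dim), with offs holding the cumulative end index
    # of each token group (one group per routed expert).
    out = A.new_empty(A.shape[0], B.shape[-1])
    start = 0
    for g, end in enumerate(offs.tolist()):
        out[start:end] = A[start:end] @ B[g]
        start = end
    return out


# Routed experts: 2D A with per-expert group offsets.
A2d = torch.randn(6, 4)
B = torch.randn(3, 4, 8)          # one weight matrix per expert
offs = torch.tensor([2, 5, 6])    # groups: rows 0-1, 2-4, 5
out_routed = reference_grouped_mm(A2d, B, offs)

# Shared expert: 3D A with num_experts=1 and offs=None.
A3d = torch.randn(1, 6, 4)
out_shared = reference_grouped_mm(A3d, B[:1])
```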

Test plan

  • ./test/prototype/moe_training/test_tp.sh

@danielvegamyhre added the topic: not user facing label on Jul 2, 2025

pytorch-bot bot commented Jul 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2473


❌ 1 New Failure, 1 Pending

As of commit b9e58fe with merge base 01f7352:

@facebook-github-bot added the CLA Signed label on Jul 2, 2025
@danielvegamyhre marked this pull request as draft on July 2, 2025 01:45
@danielvegamyhre marked this pull request as ready for review on July 2, 2025 02:08
@danielvegamyhre requested review from vkuzo and drisspg on July 2, 2025 03:18
from torch.nn import functional as F

# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
Contributor:

is the bf16 group gemm working on A100? If yes, I would vote for adding an emulation mode and running this test in emulation mode. We do this for float8 and MX training.

Contributor Author (@danielvegamyhre):

I will look into this and add emulation if the bf16 grouped gemm builds on A100.
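
A hedged sketch of what such a capability gate plus emulation toggle could look like in the test. The `emulate` parametrization and the commented-out conversion helper below are hypothetical names for illustration; the actual torchao MoE training API may differ. Only `torch.cuda.get_device_capability()` and the standard pytest calls are relied on here.

```python
import pytest
import torch


def _fp8_grouped_mm_supported() -> bool:
    # The native float8 grouped GEMM path needs CUDA with SM89+ (Ada/Hopper).
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


@pytest.mark.parametrize("emulate", [False, True])
def test_moe_tp_training(emulate: bool):
    if not emulate and not _fp8_grouped_mm_supported():
        pytest.skip("non-emulated path requires CUDA SM89+")
    # With emulate=True the scaling math would run in high precision instead of
    # dispatching to the fp8 grouped GEMM, so the test can also run on A100 (SM80).
    # convert_moe_training(model, emulate=emulate)  # hypothetical conversion helper
    ...
```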

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor


def _validate_model_conversion(
Contributor:

nit: why do we need recursion to check this?

Contributor Author (@danielvegamyhre):

It's just a generic way of checking that all target FQNs were converted properly and verifying that all non-target FQNs were correctly not converted. It can easily be applied when we extend the tests to other MoE models beyond the torchtitan llama4 one I started with.
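
A rough sketch of that recursive check, assuming converted parameters carry `ScaledGroupedMMTensor` data (the import is shown in the excerpt above); the FQN-matching details are an approximation of the validation described, not the PR's exact code.

```python
from torch import nn

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor


def _validate_model_conversion(model: nn.Module, target_fqns: list[str], fqn: str = "") -> None:
    # Walk the module tree recursively so the same check works for any MoE
    # model layout, not just the torchtitan llama4 one.
    is_target = any(t in fqn for t in target_fqns)
    for name, param in model.named_parameters(recurse=False):
        converted = isinstance(param.data, ScaledGroupedMMTensor)
        param_fqn = f"{fqn}.{name}" if fqn else name
        if is_target:
            assert converted, f"{param_fqn} should have been converted"
        else:
            assert not converted, f"{param_fqn} should not have been converted"
    for child_name, child in model.named_children():
        child_fqn = f"{fqn}.{child_name}" if fqn else child_name
        _validate_model_conversion(child, target_fqns, child_fqn)
```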

@danielvegamyhre force-pushed the dtype branch 2 times, most recently from 105689f to bb9626e on July 2, 2025 16:13
@danielvegamyhre force-pushed the tp branch 2 times, most recently from cb1eae9 to 7fed93e on July 2, 2025 16:48
@danielvegamyhre changed the base branch from dtype to main on July 2, 2025 17:19
@danielvegamyhre force-pushed the tp branch 3 times, most recently from e92f92d to 04a3d2f on July 2, 2025 17:29
@danielvegamyhre merged commit 6821971 into main on Jul 2, 2025 (18 of 19 checks passed)
tianyu-l pushed a commit to pytorch/torchtitan that referenced this pull request Jul 9, 2025
We can remove this assertion; TP support for float8 rowwise MoE training was added in this PR stack: pytorch/ao#2473