
[DSv3] Compile support for single GPU #1364


Closed
wants to merge 6 commits

Conversation

@xmfan (Member) commented Jul 2, 2025

Tested on a single-GPU repro: `NGPU=1 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" tlp ./run_train.sh --training.compile`

Eager: https://gist.github.com/xmfan/6ae0f9db6ebb5b1809f136a2f46f55d3
Compile: https://gist.github.com/xmfan/30ed647ddda29531e412bd9ad800d503
17.4% less memory (8.51 GiB vs 10.31 GiB eager)
1.063x speedup (22.97 vs 21.59 eager)

  1. Fullgraph support for non-MoE layers, which worked out of box.
  2. MoE with graph breaks. There were a few issues around inactive experts using zero-shaped tensors, causing recompilations and dynamic-shape errors. There is also a graph break due to the full_backward_hook requirement, but I hear this may be removed soon.

I'm proposing to rewrite the MoE expert token splitting function to hide the zero-shaped expert tokens from the compiler, and to explicitly ask the compiler to treat the expert token shapes as dynamic.
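
A minimal sketch of the idea (not the actual PR diff; the names `split_tokens_by_expert`, `sorted_tokens`, and `tokens_per_expert` are illustrative): contain the data-dependent split behind a deliberate graph break, and mark the surviving chunks' token dimension as dynamic so varying token counts don't recompile.

```python
import torch

# Sketch only: keep the data-dependent token split out of the compiled graph.
@torch._dynamo.disable  # deliberate, contained graph break
def split_tokens_by_expert(sorted_tokens, tokens_per_expert):
    # tensor_split wants CPU indices; the cumsum gives the split offsets.
    offsets = tokens_per_expert.cumsum(0)[:-1].cpu()
    chunks = sorted_tokens.tensor_split(offsets)
    out = []
    for chunk in chunks:
        if chunk.shape[0] == 0:
            # Inactive expert: never hand a zero-shaped tensor to the
            # compiled region (each empty-chunk pattern would re-specialize).
            out.append(None)
        else:
            # Tell the compiler the token count varies run to run.
            torch._dynamo.mark_dynamic(chunk, 0)
            out.append(chunk)
    return out
```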

Nothing obviously wrong in the trace: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/xmfan/1ed53a34-a2b4-4eb1-8bb4-738b62be0bfa/custom/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=1000

cc @wwwjn @tianyu-l

wwwjn and others added 5 commits July 1, 2025 16:05
## Contents
1. Attention module
2. MoE module (note: I only implemented the naive routing, not the "node limit routing" strategy)
3. Deepseek-V3 model

Reference:
1. Deepseek-ai:
https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
2. Huggingface:
https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py
3. torchtitan/experiment/deepseek-v3
4. torchtitan/experiment/llama4

## TODO
- [ ] Further clean up the DeepseekV3ModelArgs class, remove unused model args
- [ ] Test forward pass w/ torchtitan. Command to run: `NGPU=1 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh`

## Context
1. Added model args for 4 model settings, and a training config for the debug model
2. Debugged the forward pass; the backward pass works out of the box.
3. Reused the c4-test dataset and the tiktokenizer from the llama3 model for current testing

![Screenshot 2025-06-20 at 11 52 49 AM](https://github.com/user-attachments/assets/81d938a2-9a85-4e8c-b8e1-7f9510d785c2)
…6B model (pytorch#1330)

## Context
1. Introduced a basic DSV3-16B model training config
2. Enabled FSDP/HSDP on DSV3-16B model training

## Performance
The current profiler trace looks like this: the `to_copy` takes too long and needs to be optimized. The copy comes from the dtype conversion in class `MoE`:
```python
routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(x.dtype)
```

With FSDP only:
<img width="1544" alt="Screenshot 2025-06-23 at 2 10 20 PM" src="https://github.com/user-attachments/assets/bcd698dc-3899-46e0-ae53-e7f8b0db13fc" />
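
One possible way to trim that `_to_copy` overhead (a sketch, not what the PR does, and only valid if the scores tolerate the lower-precision multiply): cast `top_scores` down to the activation dtype instead of round-tripping the routed output through float32.

```python
# Sketch: multiply in the activation dtype to avoid the fp32 round trip.
routed_output = routed_output * top_scores.to(routed_output.dtype).unsqueeze(-1)
```
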
@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) Jul 2, 2025
@tianyu-l (Contributor) commented Jul 2, 2025

Thanks, @xmfan!

> There is also a graph break due to the full_backward_hook requirement, but I hear this may be removed soon.

Although dsv3 is receiving more interest today, the MoE and EP parts will follow what I developed for llama4 in #1324. Over there I have removed the backward hooks and made several other adjustments in the MoE layers and EP. I'd suggest we work with that version first, since it sounds like the non-MoE layers in dsv3 have been de-risked from the compile perspective?

> I'm proposing to rewrite the MoE expert token splitting function to hide the zero-shaped expert tokens from the compiler, and to explicitly ask the compiler to treat the expert token shapes as dynamic.

This sounds interesting. Does it remove all graph breaks (excluding the backward hook ones)?

@wwwjn (Contributor) commented Jul 2, 2025

@xmfan I updated the base branch (deepseek-v3) after the tokenizer diff (#1318) landed; you could update your diff accordingly. Thank you!

@xmfan (Member, Author) commented Jul 2, 2025

> > I'm proposing to rewrite the MoE expert token splitting function to hide the zero-shaped expert tokens from the compiler, and to explicitly ask the compiler to treat the expert token shapes as dynamic.
>
> This sounds interesting. Does it remove all graph breaks (excluding the backward hook ones)?

No, it's moving all compiler-unfriendly logic into a small function and manually graph breaking on it. So it's a workaround to prevent the compiler from graph breaking at two awkward spots:

  • different zero-shaped tensors depending on which expert is inactive: this recompiles every time
  • tensor_split: this always graph breaks unless we turn on capture_scalar_outputs, but if we do, it runs into unbacked symint problems again

This PR just adds the initial support to identify all the issues, and it's only this token-split logic that's problematic (an illustration of the failure modes is below).
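
An illustrative repro of the two failure modes (assumed names and shapes, not the torchtitan code):

```python
import torch

def problematic_split(x, tokens_per_expert):
    # tokens_per_expert comes from a data-dependent histogram of router
    # choices, so its values are unknown at trace time.
    sizes = tokens_per_expert.tolist()
    # ^ scalar reads: Dynamo graph-breaks here unless
    # torch._dynamo.config.capture_scalar_outputs = True, and with it on,
    # the sizes become unbacked symints that fail downstream shape checks.
    chunks = torch.split(x, sizes)
    # Inactive experts yield 0-row chunks; each distinct pattern of empty
    # chunks specializes the graph and forces a recompile.
    return chunks
```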

> I'd suggest we work with that version first, since it sounds like the non-MoE layers in dsv3 have been de-risked from the compile perspective?

Yes, there were no issues fullgraphing the regular feed-forward transformer blocks; the only issues should be about the MoE token splitting.
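
For reference, a hedged sketch of what that split looks like (illustrative, not the torchtitan compile path; `moe_enabled` is a hypothetical flag): dense blocks can take fullgraph=True, while MoE blocks keep their contained graph breaks.

```python
import torch

def compile_blocks(model):
    for name, block in model.layers.named_children():
        # Check however your model distinguishes dense from MoE blocks.
        if getattr(block, "moe_enabled", False):
            compiled = torch.compile(block)  # graph breaks allowed
        else:
            compiled = torch.compile(block, fullgraph=True)
        model.layers.register_module(name, compiled)
```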

I will close this PR then, and focus on making changes to the llama4 MoE, since it looks like DSv3 doesn't have compile errors specific to it.

@xmfan closed this Jul 2, 2025