
[DSV3] Adding deepseek-v3 model into torchtitan #1373


Merged
14 commits merged into main on Jul 11, 2025

Conversation

@wwwjn (Contributor) commented Jul 8, 2025

Supported Features

  • FSDP, HSDP
  • Activation checkpointing
  • Tensor Parallel (TP) from @tianyu-l
  • Expert Parallel (EP)

To be added

  • Modeling
    • Merge DeepSeek-V3 and Llama4 MoE common components
  • Parallelism
  • torch.compile
  • Quantization
  • Testing
    • Performance and loss-convergence tests
    • CI integration - @wwwjn will work on this after PyTorch side diffs (mentioned in dp2ep Expert Parallel #1324) get into PyTorch nightly

Test

  1. With FSDP=8, EP=2 (['dp_shard_mod_ep', 'dp_shard_in_ep'], [4, 2]); see the mesh sketch after the logs
[rank0]:[titan] 2025-07-08 15:15:43,068 - root - INFO - step:  1  loss: 12.2616  grad_norm:  0.3918  memory: 65.53GiB(68.98%)  tps: 1,482  tflops: 0.61  mfu: 0.06%
[rank0]:[titan] 2025-07-08 15:15:43,068 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-08 15:15:43,543 - root - INFO - step:  2  loss: 12.0093  grad_norm:  0.5745  memory: 65.54GiB(68.99%)  tps: 69,111  tflops: 28.68  mfu: 2.90%
[rank0]:[titan] 2025-07-08 15:15:43,981 - root - INFO - step:  3  loss: 11.1697  grad_norm:  1.2095  memory: 65.54GiB(68.99%)  tps: 74,931  tflops: 31.09  mfu: 3.14%
[rank0]:[titan] 2025-07-08 15:15:44,015 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-07-08 15:15:44,409 - root - INFO - step:  4  loss: 10.7248  grad_norm:  1.2230  memory: 65.54GiB(68.99%)  tps: 76,668  tflops: 31.81  mfu: 3.22%
[rank0]:[titan] 2025-07-08 15:15:44,838 - root - INFO - step:  5  loss: 10.5484  grad_norm:  1.1633  memory: 65.54GiB(68.99%)  tps: 76,416  tflops: 31.71  mfu: 3.21%
[rank0]:[titan] 2025-07-08 15:15:45,339 - root - INFO - step:  6  loss: 10.3509  grad_norm:  1.1611  memory: 65.54GiB(68.99%)  tps: 65,490  tflops: 27.18  mfu: 2.75%
[rank0]:[titan] 2025-07-08 15:15:45,401 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-07-08 15:15:46,121 - root - INFO - step:  7  loss: 10.2153  grad_norm:  1.1410  memory: 65.54GiB(68.99%)  tps: 41,934  tflops: 17.40  mfu: 1.76%
[rank0]:[titan] 2025-07-08 15:15:46,733 - root - INFO - step:  8  loss: 10.0801  grad_norm:  1.1487  memory: 65.54GiB(68.99%)  tps: 53,599  tflops: 22.24  mfu: 2.25%
[rank0]:[titan] 2025-07-08 15:15:47,137 - root - INFO - step:  9  loss:  9.9781  grad_norm:  1.1257  memory: 65.54GiB(68.99%)  tps: 81,051  tflops: 33.63  mfu: 3.40%
[rank0]:[titan] 2025-07-08 15:15:47,554 - root - INFO - step: 10  loss:  9.9183  grad_norm:  1.1012  memory: 65.54GiB(68.99%)  tps: 78,712  tflops: 32.66  mfu: 3.30%
  2. With FSDP=4, TP=2
[rank0]:[titan] 2025-07-08 15:16:25,927 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-08 15:16:34,993 - root - INFO - step:  1  loss: 12.2768  grad_norm:  0.3836  memory: 41.14GiB(43.31%)  tps: 1,750  tflops: 0.73  mfu: 0.07%
[rank0]:[titan] 2025-07-08 15:16:34,993 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-08 15:16:35,310 - root - INFO - step:  2  loss: 12.0284  grad_norm:  0.5423  memory: 41.29GiB(43.46%)  tps: 51,796  tflops: 21.49  mfu: 2.17%
[rank0]:[titan] 2025-07-08 15:16:35,605 - root - INFO - step:  3  loss: 11.2398  grad_norm:  1.2037  memory: 41.29GiB(43.46%)  tps: 55,575  tflops: 23.06  mfu: 2.33%
[rank0]:[titan] 2025-07-08 15:16:35,912 - root - INFO - step:  4  loss: 10.8246  grad_norm:  1.2360  memory: 41.29GiB(43.46%)  tps: 53,553  tflops: 22.22  mfu: 2.25%
[rank0]:[titan] 2025-07-08 15:16:36,206 - root - INFO - step:  5  loss: 10.6295  grad_norm:  1.1951  memory: 41.29GiB(43.46%)  tps: 55,732  tflops: 23.13  mfu: 2.34%
[rank0]:[titan] 2025-07-08 15:16:36,502 - root - INFO - step:  6  loss: 10.5240  grad_norm:  1.1296  memory: 41.29GiB(43.46%)  tps: 55,564  tflops: 23.06  mfu: 2.33%
[rank0]:[titan] 2025-07-08 15:16:36,793 - root - INFO - step:  7  loss: 10.3426  grad_norm:  1.1630  memory: 41.29GiB(43.46%)  tps: 56,295  tflops: 23.36  mfu: 2.36%
[rank0]:[titan] 2025-07-08 15:16:36,824 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-07-08 15:16:37,081 - root - INFO - step:  8  loss: 10.2127  grad_norm:  1.1499  memory: 41.29GiB(43.46%)  tps: 57,052  tflops: 23.67  mfu: 2.39%
[rank0]:[titan] 2025-07-08 15:16:37,374 - root - INFO - step:  9  loss: 10.0537  grad_norm:  1.1814  memory: 41.29GiB(43.46%)  tps: 56,019  tflops: 23.25  mfu: 2.35%
[rank0]:[titan] 2025-07-08 15:16:37,664 - root - INFO - step: 10  loss: 10.0311  grad_norm:  1.1082  memory: 41.29GiB(43.46%)  tps: 56,504  tflops: 23.45  mfu: 2.37%
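For readers puzzled by the mesh naming in test 1: with 8 GPUs and EP=2, the FSDP dimension of 8 appears to be factored into 4 ranks outside the expert-parallel group (`dp_shard_mod_ep`) and 2 ranks inside it (`dp_shard_in_ep`), in line with the dp2ep Expert Parallel design from #1324. A rough sketch of such a 2-D mesh using plain `init_device_mesh` (an illustration only, not torchtitan's actual mesh-construction code):

```python
from torch.distributed.device_mesh import init_device_mesh

# Illustration only; run under torchrun with 8 ranks.
# FSDP degree 8 = 4 (dp_shard_mod_ep) x 2 (dp_shard_in_ep); the inner dimension
# is the part of the data-parallel mesh that is reused for expert parallelism.
mesh = init_device_mesh(
    "cuda",
    (4, 2),
    mesh_dim_names=("dp_shard_mod_ep", "dp_shard_in_ep"),
)
print(mesh["dp_shard_mod_ep"].size(), mesh["dp_shard_in_ep"].size())  # -> 4 2
```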

@wwwjn wwwjn requested review from tianyu-l, fegin and wconstab as code owners July 8, 2025 18:59
@facebook-github-bot added the CLA Signed label Jul 8, 2025
@wwwjn wwwjn requested a review from H-Huang July 8, 2025 18:59
@H-Huang (Member) left a comment:

Looks good to me!

@bloc97 commented Jul 8, 2025

It's possible that k_pe.expand(-1, -1, n_local_heads, -1) needs to be a custom tensor parallel operation with a torch.sum(grad.unshard(), dim=2) in the backward pass, otherwise the sharded gradients from k won't flow back properly into the unsharded k_pe gradients.

We actually encountered this issue when implementing TP with DeepSeek v3 models (ccing @jquesnelle who wrote most of the TP impl), and I've made a diagram to help us debug this gradient flow issue, maybe it'll help here too.
[attached diagram: gradient flow through the k_pe expand]
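To make the suggestion concrete, here is a minimal single-device sketch of an expand whose backward collapses the per-head gradients back onto the shared k_pe head. The class name is hypothetical, and the cross-rank part is intentionally left out; a real TP version would still need the gather/reduce across shards that the `torch.sum(grad.unshard(), dim=2)` above refers to:

```python
import torch

class ExpandSharedHead(torch.autograd.Function):
    """Hypothetical sketch: broadcast one shared head to n_local_heads in forward,
    and sum the per-head gradients back onto that shared head in backward."""

    @staticmethod
    def forward(ctx, k_pe, n_local_heads):
        # k_pe: [batch, seq, 1, head_dim] -> [batch, seq, n_local_heads, head_dim]
        return k_pe.expand(-1, -1, n_local_heads, -1)

    @staticmethod
    def backward(ctx, grad_out):
        # Every head contributed a gradient to the shared k_pe; sum them back.
        # In a real TP setup this is also where the unshard + cross-rank sum would go.
        return grad_out.sum(dim=2, keepdim=True), None

# usage: k_pe_expanded = ExpandSharedHead.apply(k_pe, n_local_heads)
```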

Review comment on the float8 section of the training config:

```toml
[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]
```
Suggested change:

```diff
-filter_fqns = ["output"]
+filter_fqns = ["output", "router.gate"]
```

@tianyu-l (Contributor) commented Jul 9, 2025

@bloc97 Thank you very much for bringing this issue to our attention! We haven't looked closely enough at the issue yet -- we will!

However, from some initial numerical testing, it seems our TP gives exactly the same numerics as FSDP
(see the PR summary of #1341).

I wonder whether you have identified the same issue in our implementation, or whether this is more of a warning for us to be cautious when doing TP? We appreciate your feedback!

@bloc97 commented Jul 9, 2025

> @bloc97 Thank you very much for bringing this issue to our attention! We haven't looked closely enough at the issue yet -- we will!
>
> However, from some initial numerical testing, it seems our TP gives exactly the same numerics as FSDP (see the PR summary of #1341).
>
> I wonder whether you have identified the same issue in our implementation, or whether this is more of a warning for us to be cautious when doing TP? We appreciate your feedback!

Because of how similar both implementations are, I'm assuming the same problem will show up. In our case, training divergence only showed up after more than 1000 steps, which makes sense given how small the contribution from k_pe is.

The forward pass is correct, as k_pe is computed from an unsharded tensor and expanded into a sharded tensor (this doesn't matter because all the shards are identical). However, the backward pass is wrong, because it sums a sharded tensor k (whose gradient is not the same across different shards) into an unsharded tensor k_pe. If you use TP across two GPUs, GPU0's k_pe will never see the gradients of GPU1's k_pe, and vice versa!

The best way to verify correctness is to check whether the gradients on GPU0 are bit-identical to the gradients on GPU1 when doing TP.

Pointing this out to hopefully spare you the same headache; this insidious bug was very hard to find.
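A minimal sketch of that check, assuming an initialized process group and a plain-tensor gradient on each TP rank; the helper name is illustrative:

```python
import torch
import torch.distributed as dist

def grad_identical_across_tp(grad: torch.Tensor, tp_group=None) -> bool:
    """Gather k_pe's gradient from every TP rank and check all copies are
    bit-identical. A mismatch means the expand's backward is not reducing
    gradient contributions across the TP shards."""
    world_size = dist.get_world_size(group=tp_group)
    gathered = [torch.empty_like(grad) for _ in range(world_size)]
    dist.all_gather(gathered, grad.contiguous(), group=tp_group)
    return all(torch.equal(g, gathered[0]) for g in gathered)
```

Running this on k_pe's gradient after a backward pass on each rank would surface the divergence described above.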

@wwwjn (Contributor, Author) commented Jul 9, 2025

> > @bloc97 Thank you very much for bringing this issue to our attention! We haven't looked closely enough at the issue yet -- we will!
> > However, from some initial numerical testing, it seems our TP gives exactly the same numerics as FSDP (see the PR summary of #1341).
> > I wonder whether you have identified the same issue in our implementation, or whether this is more of a warning for us to be cautious when doing TP? We appreciate your feedback!
>
> Because of how similar both implementations are, I'm assuming the same problem will show up. In our case, training divergence only showed up after more than 1000 steps, which makes sense given how small the contribution from k_pe is.
>
> The forward pass is correct, as k_pe is computed from an unsharded tensor and expanded into a sharded tensor (this doesn't matter because all the shards are identical). However, the backward pass is wrong, because it sums a sharded tensor k (whose gradient is not the same across different shards) into an unsharded tensor k_pe. If you use TP across two GPUs, GPU0's k_pe will never see the gradients of GPU1's k_pe, and vice versa!
>
> The best way to verify correctness is to check whether the gradients on GPU0 are bit-identical to the gradients on GPU1 when doing TP.
>
> Pointing this out to hopefully spare you the same headache; this insidious bug was very hard to find.

Thank you so much for pointing this out; that helps a lot! Now I see the problem with expanding k_pe from 1 head to n_local_heads. We will fix this soon!

@tianyu-l (Contributor) commented:

@bloc97 We realized our code has the same issue, and we appreciate your warning a lot!

I think the "root cause" is that we convert DTensors to plain Tensors outside the nn.Linear modules; otherwise DTensor could have taken care of the partial gradients coming from different ranks.

The reason we are not using DTensors between linear modules is a notorious bug involving complex-number multiplication and PyTorch Tensor subclasses. cc @bdhirsh

In terms of a solution, we wanted the code to be in a certain style, as custom autograd functions would break torch.compile full-graph support. Let me think more about how we should solve this.
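For context, a generic plain-tensor sketch of the RoPE-style complex multiplication that sits between the linear modules and that historically hit the tensor-subclass + complex-number bug (pytorch/pytorch#130646); this is an illustration, not the PR's exact code:

```python
import torch

def apply_rope_complex(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension of x by the complex phases in freqs_cis.
    x: [..., seq, head_dim] with head_dim even; freqs_cis: complex, [seq, head_dim // 2].
    Running this pattern on DTensor inputs is what used to fail, forcing the
    conversion to plain tensors outside the linear modules."""
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_rotated = x_complex * freqs_cis  # broadcasts over the leading dims
    return torch.view_as_real(x_rotated).flatten(-2).type_as(x)
```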

@wanchaol (Collaborator) commented Jul 10, 2025

> I think the "root cause" is that we convert DTensors to plain Tensors outside the nn.Linear modules; otherwise DTensor could have taken care of the partial gradients coming from different ranks.
>
> The reason we are not using DTensors between linear modules is pytorch/pytorch#130646, a bug involving complex-number multiplication and PyTorch Tensor subclasses.

@tianyu-l @wwwjn I didn't look closely at where the issue pops up, but assuming this statement is correct, IMO we should really fix the mentioned tensor subclass + complex number bug directly. It has hurt us a couple of times already, so it is better to fix it in core (I think it would also benefit other tensor subclasses as a whole). cc @albanD @ezyang

@tianyu-l (Contributor) commented:

@wanchaol
I've mentioned this to @ezyang and @bdhirsh. I also think we really should fix this. It not only limits the way we write model code and parallelisms, but also stops us from supporting important downstream use cases.

E.g., with this issue we can't run Sequence Parallel on uneven sequence lengths, which means we can't do generation with Sequence Parallel unless users explicitly handle padding / unpadding themselves. See

Moreover, this is not about adding support for something new; this is about fixing a bug between Tensor subclasses and complex numbers (both are important components in their own right), without which users could hit silent numerical errors.

@tianyu-l (Contributor) commented Jul 10, 2025

I can confirm @ezyang's PR pytorch/pytorch#158030 fixed the DTensor + complex number bug. What a lifesaver!
I'll work with @wwwjn to come up with a fix for DeepSeek TP.

wwwjn and others added 11 commits July 10, 2025 13:05
## Contents
1. Attention module
2. MoE module (note: I only implemented the naive routing, not the "node-limited routing" strategy; see the sketch after this list)
3. Deepseek-V3 model
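For readers unfamiliar with the distinction, here is a minimal sketch of naive token-choice top-k routing, with no node-limited / group-limited constraint; it is a generic illustration, not the exact DeepSeek-V3 scoring function:

```python
import torch
import torch.nn.functional as F

def naive_topk_route(router_logits: torch.Tensor, top_k: int):
    """Every token independently picks its top_k experts by router score.
    router_logits: [num_tokens, num_experts]. Returns (top_scores, top_experts),
    each of shape [num_tokens, top_k]."""
    probs = F.softmax(router_logits, dim=-1)
    top_scores, top_experts = torch.topk(probs, k=top_k, dim=-1)
    # Renormalize so each token's selected-expert weights sum to 1
    # (one common convention; implementations differ here).
    top_scores = top_scores / top_scores.sum(dim=-1, keepdim=True)
    return top_scores, top_experts

# e.g. scores, experts = naive_topk_route(torch.randn(16, 256), top_k=8)
```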

Reference:
1. Deepseek-ai: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
2. Huggingface: https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py
3. torchtitan/experiment/deepseek-v3
4. torchtitan/experiment/llama4

## TODO
- [ ] Further clean up the DeepseekV3ModelArgs class, remove unused model args
- [ ] Test forward pass w/ torchtitan
Command to run: `NGPU=1 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh`

## Context
1. Added model args for 4 model settings, and a training config for the debug model
2. Debugged the forward pass; the backward pass works out of the box
3. Reused the c4-test dataset and the tiktokenizer from the llama3 model for current testing

![Screenshot 2025-06-20 at 11 52 49 AM](https://github.com/user-attachments/assets/81d938a2-9a85-4e8c-b8e1-7f9510d785c2)
…6B model (#1330)

## Context
1. Introduced a basic DSV3-16B model training config
2. Enabled FSDP/HSDP on DSV3-16B model training

## Performance
The current profile looks like this: the `to_copy` takes too long and needs to be optimized. The copy comes from the dtype conversion in the MoE class:

```python
routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(x.dtype)
```
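One possible way to trim the `to_copy` cost, as a hedged sketch rather than what this PR does: cast only the small `top_scores` tensor instead of round-tripping `routed_output` through float32, trading a little precision in the scaling for fewer copies.

```python
# Hypothetical alternative: scale in the activation dtype; only top_scores is cast.
routed_output = routed_output * top_scores.unsqueeze(-1).to(routed_output.dtype)
```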

With FSDP only:
<img width="1544" alt="Screenshot 2025-06-23 at 2 10 20 PM" src="https://github.com/user-attachments/assets/bcd698dc-3899-46e0-ae53-e7f8b0db13fc" />
Mostly adapted from llama4, with the TP plan changed based on the differences between deepseek-v3 and llama.

Thanks @tianyu-l for the detailed walkthrough of the deepseek-v3 attention model and TP plan! This diff is currently based on #1324, and we want to extract the MoE model shared by DSV3 and llama4 into a common place.

Now we have:
1. FSDP
2. Activation Checkpointing
3. TP
4. CP in progress (hangs for some reason)

1. Make CP work

There is a minor issue with the numerical verification: with a deterministic seed, the loss is not identical. I used the `AdamW` optimizer.

1. FSDP degree=4 (blue line)
2. FSDP degree=4, TP degree = 2 (orange line)

<img width="1368" alt="Screenshot 2025-07-01 at 5 38 50 PM" src="https://github.com/user-attachments/assets/38d96d75-6868-4482-a603-b9e10c692ed9" />

With the `Adam` optimizer, the loss is **exactly the same**:
<img width="1368" alt="Screenshot 2025-07-02 at 1 26 32 PM" src="https://github.com/user-attachments/assets/6b501d3c-4841-42b1-95fd-3971b16a5eeb" />

---------

Co-authored-by: Tianyu Liu <lty@fb.com>
The current deepseek-v3 branch has a broken CI job
Review comment on the model section of the debug config:

```toml
[model]
name = "deepseek_v3"
flavor = "debugmodel"
# test tokenizer.model, for debug purpose only
```
nit: we are not using the test tokenizer.model anymore (tokenizer.json and tokenizer_config.json are now in the folder)

@wwwjn (Contributor, Author) commented Jul 10, 2025

@bloc97 Thank you again for bringing up the TP issue. @tianyu-l and I solved this problem by changing k_pe into a DTensor, so in the backward pass DTensor will take care of the communication across TP ranks.

For example, here is the k_pe information and its gradient information with TP=2.

[rank0]:[titan] 2025-07-10 15:54:13,625 - root - INFO - Training starts at step 1.
[rank1]:k_pe forward before expand: type=<class 'torch.distributed.tensor.DTensor'>, shape=torch.Size([8, 2048, 64]), mean=-0.01695265993475914, k_pe DTensor spec: Spec(R on (8, 2048, 64))
[rank1]:k_pe forward after expand: type=<class 'torch.distributed.tensor.DTensor'>, shape=torch.Size([8, 2048, 16, 64]), mean=-0.008218418806791306, k_pe DTensor spec: Spec(R on (8, 2048, 16, 64))
[rank1]:k forward value: type=<class 'torch.distributed.tensor.DTensor'>, shape=torch.Size([8, 2048, 16, 192]), mean=-0.0011989236809313297, k DTensors spec: Spec(S(2) on (8, 2048, 16, 192)), localshape torch.Size([8, 2048, 8, 192])
[rank0]:k_pe forward before expand: type=<class 'torch.distributed.tensor.DTensor'>, shape=torch.Size([8, 2048, 64]), mean=-0.01695265993475914, k_pe DTensor spec: Spec(R on (8, 2048, 64))
[rank0]:k_pe forward after expand: type=<class 'torch.distributed.tensor.DTensor'>, shape=torch.Size([8, 2048, 16, 64]), mean=-0.008218418806791306, k_pe DTensor spec: Spec(R on (8, 2048, 16, 64))
[rank0]:k forward value: type=<class 'torch.distributed.tensor.DTensor'>, shape=torch.Size([8, 2048, 16, 192]), mean=-0.0017902988474816084, k DTensors spec: Spec(S(2) on (8, 2048, 16, 192)), localshape torch.Size([8, 2048, 8, 192])
[rank1]:k_pe backward gradient: type=<class 'torch.distributed.tensor.DTensor'>, shape=torch.Size([8, 2048, 64]), mean=-7.41435211837338e-12, spec=Spec(P on (8, 2048, 64))
[rank0]:k_pe backward gradient: type=<class 'torch.distributed.tensor.DTensor'>, shape=torch.Size([8, 2048, 64]), mean=1.2723578267370694e-11, spec=Spec(P on (8, 2048, 64))
[rank1]:[titan] 2025-07-10 15:54:14,243 - root - INFO - step:  1  loss:  7.9985  grad_norm:  1.7722  memory:  1.15GiB(1.22%)  tps: 9,917  tflops: 0.35  mfu: 0.04%
[rank1]:[titan] 2025-07-10 15:54:14,243 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank1]:[titan] 2025-07-10 15:54:14,243 - root - INFO - Training completed
[rank0]:[titan] 2025-07-10 15:54:14,243 - root - INFO - step:  1  loss:  7.9985  grad_norm:  1.7722  memory:  1.15GiB(1.22%)  tps: 10,606  tflops: 0.38  mfu: 0.04%

During the forward pass, k_pe becomes a Replicate DTensor with shape torch.Size([8, 2048, 16, 64]) (n_heads=16). It is then concatenated with k_nope (which is a Shard(2) DTensor), and the resulting k is a Shard(2) DTensor. During the backward pass, the gradient of k_pe is Spec(P on (8, 2048, 64)), i.e. a Partial DTensor.

Hopefully this change addresses the issue!
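A minimal sketch of the pattern described above, assuming a 1-D TP mesh of size 2 and illustrative shapes taken from the logs; this is not the PR's exact code:

```python
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard

# Sketch only; run under torchrun with 2 ranks.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

n_heads, n_local_heads = 16, 8
k_pe_local = torch.randn(8, 2048, 1, 64, device="cuda")                 # rotary part, identical on all ranks
k_nope_local = torch.randn(8, 2048, n_local_heads, 128, device="cuda")  # this rank's head shard

# Wrap k_pe as Replicate and k_nope as Shard(2) (sharded over the head dim).
k_pe = DTensor.from_local(k_pe_local, tp_mesh, [Replicate()])
k_nope = DTensor.from_local(k_nope_local, tp_mesh, [Shard(2)])

# The expand and concat now stay in DTensor land: k comes out as Shard(2), and in
# the backward pass the gradient w.r.t. k_pe arrives as Partial, so DTensor performs
# the cross-rank reduction that the plain-tensor version silently skipped.
k = torch.cat([k_nope, k_pe.expand(-1, -1, n_heads, -1)], dim=-1)
```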

@tianyu-l tianyu-l merged commit d54d05a into main Jul 11, 2025
7 checks passed
@tianyu-l tianyu-l deleted the deepseek-v3 branch July 11, 2025 07:30
@tianyu-l tianyu-l restored the deepseek-v3 branch July 11, 2025 07:47
@vwxyzjn commented Jul 14, 2025

> [rank0]:[titan] 2025-07-08 15:15:47,554 - root - INFO - step: 10 loss: 9.9183 grad_norm: 1.1012 memory: 65.54GiB(68.99%) tps: 78,712 tflops: 32.66 mfu: 3.30%

Hi @wwwjn, appreciate the nice PR. I noticed in the logs that the MFU seems a bit low compared to the 30-40% MFU that llamas typically achieve. Is this expected?
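For reference on how to read the mfu column, a quick sanity check against the quoted log line, assuming an H100-class bf16 peak of roughly 989 TFLOPS (the hardware is not stated in this PR, so the peak value is an assumption):

```python
# mfu = achieved TFLOPS / peak TFLOPS of the accelerator.
achieved_tflops = 32.66   # from the quoted step-10 log line
peak_tflops = 989.0       # assumed H100 bf16 dense peak
print(f"mfu ~ {achieved_tflops / peak_tflops:.2%}")  # -> ~3.30%, matching the log
```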
