
Commit d54d05a

wwwjn, tianyu-l, and H-Huang authored
[DSV3] Adding deepseek-v3 model into torchtitan (#1373)
## Supported Features
- FSDP, HSDP
- Activation checkpointing
- Tensor Parallel (TP) from @tianyu-l
- Expert Parallel (EP)

## To be added
- Modeling
  - Merge DeepSeek-V3 and Llama4 MoE common components
- Parallelism
  - Context Parallel support for DeepSeek-V3
  - PP support for DeepSeek-V3: @H-Huang is working on #1345
- torch.compile
- Quantization
- Testing
  - Performance and loss convergence tests
  - CI integration: @wwwjn will work on this after the PyTorch-side diffs (mentioned in #1324) land in PyTorch nightly

## Test

1. With FSDP=8, EP=2 (`['dp_shard_mod_ep', 'dp_shard_in_ep'], [4, 2]`)

```
[rank0]:[titan] 2025-07-08 15:15:43,068 - root - INFO - step: 1 loss: 12.2616 grad_norm: 0.3918 memory: 65.53GiB(68.98%) tps: 1,482 tflops: 0.61 mfu: 0.06%
[rank0]:[titan] 2025-07-08 15:15:43,068 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-08 15:15:43,543 - root - INFO - step: 2 loss: 12.0093 grad_norm: 0.5745 memory: 65.54GiB(68.99%) tps: 69,111 tflops: 28.68 mfu: 2.90%
[rank0]:[titan] 2025-07-08 15:15:43,981 - root - INFO - step: 3 loss: 11.1697 grad_norm: 1.2095 memory: 65.54GiB(68.99%) tps: 74,931 tflops: 31.09 mfu: 3.14%
[rank0]:[titan] 2025-07-08 15:15:44,015 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-07-08 15:15:44,409 - root - INFO - step: 4 loss: 10.7248 grad_norm: 1.2230 memory: 65.54GiB(68.99%) tps: 76,668 tflops: 31.81 mfu: 3.22%
[rank0]:[titan] 2025-07-08 15:15:44,838 - root - INFO - step: 5 loss: 10.5484 grad_norm: 1.1633 memory: 65.54GiB(68.99%) tps: 76,416 tflops: 31.71 mfu: 3.21%
[rank0]:[titan] 2025-07-08 15:15:45,339 - root - INFO - step: 6 loss: 10.3509 grad_norm: 1.1611 memory: 65.54GiB(68.99%) tps: 65,490 tflops: 27.18 mfu: 2.75%
[rank0]:[titan] 2025-07-08 15:15:45,401 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-07-08 15:15:46,121 - root - INFO - step: 7 loss: 10.2153 grad_norm: 1.1410 memory: 65.54GiB(68.99%) tps: 41,934 tflops: 17.40 mfu: 1.76%
[rank0]:[titan] 2025-07-08 15:15:46,733 - root - INFO - step: 8 loss: 10.0801 grad_norm: 1.1487 memory: 65.54GiB(68.99%) tps: 53,599 tflops: 22.24 mfu: 2.25%
[rank0]:[titan] 2025-07-08 15:15:47,137 - root - INFO - step: 9 loss: 9.9781 grad_norm: 1.1257 memory: 65.54GiB(68.99%) tps: 81,051 tflops: 33.63 mfu: 3.40%
[rank0]:[titan] 2025-07-08 15:15:47,554 - root - INFO - step: 10 loss: 9.9183 grad_norm: 1.1012 memory: 65.54GiB(68.99%) tps: 78,712 tflops: 32.66 mfu: 3.30%
```

2. With FSDP=4, TP=2

```
[rank0]:[titan] 2025-07-08 15:16:25,927 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-08 15:16:34,993 - root - INFO - step: 1 loss: 12.2768 grad_norm: 0.3836 memory: 41.14GiB(43.31%) tps: 1,750 tflops: 0.73 mfu: 0.07%
[rank0]:[titan] 2025-07-08 15:16:34,993 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-08 15:16:35,310 - root - INFO - step: 2 loss: 12.0284 grad_norm: 0.5423 memory: 41.29GiB(43.46%) tps: 51,796 tflops: 21.49 mfu: 2.17%
[rank0]:[titan] 2025-07-08 15:16:35,605 - root - INFO - step: 3 loss: 11.2398 grad_norm: 1.2037 memory: 41.29GiB(43.46%) tps: 55,575 tflops: 23.06 mfu: 2.33%
[rank0]:[titan] 2025-07-08 15:16:35,912 - root - INFO - step: 4 loss: 10.8246 grad_norm: 1.2360 memory: 41.29GiB(43.46%) tps: 53,553 tflops: 22.22 mfu: 2.25%
[rank0]:[titan] 2025-07-08 15:16:36,206 - root - INFO - step: 5 loss: 10.6295 grad_norm: 1.1951 memory: 41.29GiB(43.46%) tps: 55,732 tflops: 23.13 mfu: 2.34%
[rank0]:[titan] 2025-07-08 15:16:36,502 - root - INFO - step: 6 loss: 10.5240 grad_norm: 1.1296 memory: 41.29GiB(43.46%) tps: 55,564 tflops: 23.06 mfu: 2.33%
[rank0]:[titan] 2025-07-08 15:16:36,793 - root - INFO - step: 7 loss: 10.3426 grad_norm: 1.1630 memory: 41.29GiB(43.46%) tps: 56,295 tflops: 23.36 mfu: 2.36%
[rank0]:[titan] 2025-07-08 15:16:36,824 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-07-08 15:16:37,081 - root - INFO - step: 8 loss: 10.2127 grad_norm: 1.1499 memory: 41.29GiB(43.46%) tps: 57,052 tflops: 23.67 mfu: 2.39%
[rank0]:[titan] 2025-07-08 15:16:37,374 - root - INFO - step: 9 loss: 10.0537 grad_norm: 1.1814 memory: 41.29GiB(43.46%) tps: 56,019 tflops: 23.25 mfu: 2.35%
[rank0]:[titan] 2025-07-08 15:16:37,664 - root - INFO - step: 10 loss: 10.0311 grad_norm: 1.1082 memory: 41.29GiB(43.46%) tps: 56,504 tflops: 23.45 mfu: 2.37%
```

---------

Co-authored-by: Tianyu Liu <lty@fb.com>
Co-authored-by: Howard Huang <howardhuang96@gmail.com>
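The FSDP=8, EP=2 run above maps the 8 GPUs onto a 2-D device mesh named `('dp_shard_mod_ep', 'dp_shard_in_ep')` with sizes `(4, 2)`. The sketch below only illustrates that mesh layout using `torch.distributed.device_mesh.init_device_mesh`; it is not the code path torchtitan itself takes, and the launch command, script name, and device type are assumptions.

```python
# Illustrative sketch (not torchtitan's actual setup code): build the (4, 2) mesh
# used in the FSDP=8, EP=2 test run.
# Launch with: torchrun --nproc_per_node=8 mesh_sketch.py   (launcher/filename assumed)
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# 4 "dp_shard_mod_ep" groups x 2 "dp_shard_in_ep" ranks = 8 GPUs total; in this
# illustration the inner 2-way dimension plays the expert-parallel role.
mesh = init_device_mesh(
    "cuda",
    (4, 2),
    mesh_dim_names=("dp_shard_mod_ep", "dp_shard_in_ep"),
)

# Print each rank's coordinates in the two mesh dimensions.
print(
    f"rank {dist.get_rank()}: "
    f"dp_shard_mod_ep={mesh['dp_shard_mod_ep'].get_local_rank()}, "
    f"dp_shard_in_ep={mesh['dp_shard_in_ep'].get_local_rank()}"
)
```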
1 parent 681df50 commit d54d05a

10 files changed (+1527 lines, −0 lines)


torchtitan/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -7,4 +7,5 @@
 
 # Import the built-in models here so that the corresponding register_model_spec()
 # will be called.
+import torchtitan.models.deepseek_v3  # noqa: F401
 import torchtitan.models.llama3  # noqa: F401
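The comment in this diff describes registration by import side effect: merely importing the model package is what makes it visible to the trainer. Below is a generic, hypothetical illustration of that pattern; the registry dict and helper name are made up, and the real registration for this model is the `register_train_spec(...)` call in the package `__init__.py` shown further down.

```python
# Generic illustration of registration-by-import (names here are hypothetical).
_REGISTRY: dict[str, object] = {}

def register_model_spec(name: str, spec: object) -> None:
    """Record a spec under its name so it can be looked up by the trainer later."""
    _REGISTRY[name] = spec

# A model package calls its register function at module import time, so a bare
# `import torchtitan.models.deepseek_v3` (as added above) is enough to make the
# "deepseek_v3" spec discoverable without any further wiring.
```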
Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
# DeepSeek-V3 in TorchTitan

DeepSeek-V3 is a Mixture-of-Experts (MoE) transformer model with a Multi-head Latent Attention (MLA) architecture.

## Setup

### Download Tokenizer

```bash
# DeepSeek tokenizer (automatically downloads tokenizer.json and tokenizer_config.json)
python scripts/download_tokenizer.py --repo_id deepseek-ai/DeepSeek-V3
```

## Training

### Debug Training

```bash
# Quick debug run with small model
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh
```

### Full Model Training

```bash
# 16B parameter model, adapted from the older 16B model at https://huggingface.co/deepseek-ai/deepseek-moe-16b-base
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh
```

```bash
# 671B parameter model
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml" ./run_train.sh
```

## Supported Features
- FSDP, HSDP
- Activation checkpointing
- Tensor Parallel (TP)
- Expert Parallel (EP)

## To be added
- Modeling
  - Merge DeepSeek-V3 and Llama4 MoE common components
  - Attention layer: need to pass softmax_scale to sdpa() to support scaling
- Parallelism
  - Context Parallel support for DeepSeek-V3
  - PP support for DeepSeek-V3
- torch.compile
- Quantization
- Testing
  - Performance and loss convergence tests
  - CI integration
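The "pass softmax_scale to sdpa()" item in the list above concerns MLA's non-default attention scaling. The sketch below is only a hypothetical illustration of supplying a custom scale via the `scale` argument of PyTorch's `F.scaled_dot_product_attention`; the tensor shapes and the scale value are made up, and this is not the model's actual attention code.

```python
# Hypothetical sketch: supplying a custom softmax scale to SDPA.
import torch
import torch.nn.functional as F

batch, n_heads, seq_len = 2, 16, 128
qk_head_dim = 192   # e.g. qk_nope_head_dim (128) + qk_rope_head_dim (64), as in the MLA configs
v_head_dim = 128

q = torch.randn(batch, n_heads, seq_len, qk_head_dim)
k = torch.randn(batch, n_heads, seq_len, qk_head_dim)
v = torch.randn(batch, n_heads, seq_len, v_head_dim)

# By default SDPA scales logits by 1/sqrt(qk_head_dim); MLA needs its own
# softmax_scale (e.g. adjusted by mscale), so it must be passed in explicitly.
softmax_scale = qk_head_dim ** -0.5  # placeholder value for illustration
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=softmax_scale)
print(out.shape)  # torch.Size([2, 16, 128, 128])
```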
Lines changed: 126 additions & 0 deletions

@@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers

from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .infra.parallelize import parallelize_deepseekv3
from .model.args import DeepSeekV3ModelArgs
from .model.model import DeepSeekV3Model

__all__ = [
    "parallelize_deepseekv3",
    "DeepSeekV3ModelArgs",
    "DeepSeekV3Model",
    "deepseekv3_configs",
]


deepseekv3_configs = {
    "debugmodel": DeepSeekV3ModelArgs(
        vocab_size=102400,
        dim=256,
        inter_dim=1024,
        moe_inter_dim=256,
        n_layers=3,
        n_dense_layers=1,
        n_heads=16,
        n_routed_experts=8,
        n_shared_experts=2,
        n_activated_experts=3,
        route_scale=1.0,
        q_lora_rank=0,
        kv_lora_rank=512,
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        v_head_dim=128,
        mscale=0.70,
    ),
    "16B": DeepSeekV3ModelArgs(
        vocab_size=102400,
        dim=2048,
        inter_dim=10944,
        moe_inter_dim=1408,
        n_layers=27,
        n_dense_layers=1,
        n_heads=16,
        n_routed_experts=64,
        n_shared_experts=2,
        n_activated_experts=6,
        route_scale=1.0,
        q_lora_rank=0,
        kv_lora_rank=512,
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        v_head_dim=128,
        mscale=0.70,
    ),
    "236B": DeepSeekV3ModelArgs(
        vocab_size=102400,
        dim=5120,
        inter_dim=12288,
        moe_inter_dim=1536,
        n_layers=60,
        n_dense_layers=1,
        n_heads=128,
        n_routed_experts=160,
        n_shared_experts=2,
        n_activated_experts=6,
        n_expert_groups=8,
        n_limited_groups=3,
        route_scale=16.0,
        q_lora_rank=1536,
        kv_lora_rank=512,
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        v_head_dim=128,
    ),
    "671B": DeepSeekV3ModelArgs(
        vocab_size=129280,
        dim=7168,
        inter_dim=18432,
        moe_inter_dim=2048,
        n_layers=61,
        n_dense_layers=3,
        n_heads=128,
        n_routed_experts=256,
        n_shared_experts=1,
        n_activated_experts=8,
        n_expert_groups=8,
        n_limited_groups=4,
        route_scale=2.5,
        score_func="sigmoid",
        q_lora_rank=1536,
        kv_lora_rank=512,
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        v_head_dim=128,
        dtype="fp8",
    ),
}


register_train_spec(
    TrainSpec(
        name="deepseek_v3",
        cls=DeepSeekV3Model,
        config=deepseekv3_configs,
        parallelize_fn=parallelize_deepseekv3,
        pipelining_fn=None,
        build_optimizers_fn=build_llama4_optimizers,  # use optimizer hooks to update expert weights
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
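As a quick way to eyeball the flavors registered above, here is a small hypothetical snippet; it only uses names visible in this diff and assumes the package import path `torchtitan.models.deepseek_v3`, as the rest of the PR indicates.

```python
# Hypothetical usage sketch: print a few hyperparameters of each registered flavor.
from torchtitan.models.deepseek_v3 import deepseekv3_configs

for flavor in ("debugmodel", "16B", "236B", "671B"):
    args = deepseekv3_configs[flavor]
    print(
        f"{flavor:>10}: dim={args.dim}, layers={args.n_layers}, "
        f"routed_experts={args.n_routed_experts}, activated={args.n_activated_experts}"
    )
```

The TOML train configs referenced in the README then select the model by its registered name ("deepseek_v3") and one of these flavor keys.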
