
Commit 3997323

add 671B model config
1 parent d7bceb0 commit 3997323

6 files changed: +119 −24 lines

torchtitan/models/deepseek_v3/README.md

Lines changed: 7 additions & 0 deletions
@@ -27,6 +27,11 @@ CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./r
 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh
 ```

+```bash
+# 671B parameter model
+CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml" ./run_train.sh
+```
+

 ## Supported Features
 - FSDP, HSDP
@@ -36,6 +41,8 @@ CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml"


 ## To be added
+- TP
+  - TP has a known numerical issue with DeepSeek-V3 (https://github.com/pytorch/torchtitan/pull/1373#issuecomment-3050249520).
 - Modeling
   - Merge DeepSeek-V3 and Llama4 MoE common components
 - Parallelism

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -30,8 +30,8 @@
     "debugmodel": DeepSeekV3ModelArgs(
         vocab_size=102400,
         dim=256,
-        inter_dim=10944,
-        moe_inter_dim=1408,
+        inter_dim=1024,
+        moe_inter_dim=256,
         n_layers=3,
         n_dense_layers=1,
         n_heads=16,
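For a sense of scale, the smaller `inter_dim` and `moe_inter_dim` shrink the debug flavor's MLPs considerably. Below is a rough sizing sketch, assuming a three-projection SwiGLU-style MLP (w1/w2/w3, matching the `feed_forward.w1/w2/w3` entries in the TP plan further down); the arithmetic is illustrative only, not torchtitan's parameter accounting.

```python
# Rough, back-of-the-envelope sizing for the "debugmodel" change above.
# Assumes a SwiGLU-style MLP with three projections (w1/w2/w3), as the TP plan
# in parallelize.py suggests; real totals depend on the full DeepSeekV3ModelArgs.
def ffn_params(dim: int, inter_dim: int) -> int:
    # w1: dim -> inter_dim, w3: dim -> inter_dim, w2: inter_dim -> dim
    return 3 * dim * inter_dim

dim = 256
print(f"dense FFN:      {ffn_params(dim, 10944):,} -> {ffn_params(dim, 1024):,} params/layer")
print(f"MoE expert FFN: {ffn_params(dim, 1408):,} -> {ffn_params(dim, 256):,} params/expert")
```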

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 30 additions & 13 deletions
@@ -11,6 +11,7 @@
     ColwiseParallel,
     parallelize_module,
     PrepareModuleInput,
+    PrepareModuleInputOutput,
     RowwiseParallel,
     SequenceParallel,
 )
@@ -51,7 +52,7 @@ def parallelize_deepseekv3(
                 "Currently, float8 tensorwise TP is not tested for deepseekv3"
             )

-        apply_tp(
+        apply_non_moe_tp(
             model,
             world_mesh["tp"],
             loss_parallel=parallel_dims.loss_parallel_enabled,
@@ -133,7 +134,7 @@ def parallelize_deepseekv3(
     return model


-def apply_tp(
+def apply_non_moe_tp(
     model: nn.Module,
     tp_mesh: DeviceMesh,
     loss_parallel: bool,
@@ -145,6 +146,7 @@ def apply_tp(
     # transformer block's inputs)
     # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
+    logger.warning("There is a known issue with TP for deepseekv3. Please see details in the discussion: https://github.com/pytorch/torchtitan/pull/1373#issuecomment-3050249520.")
     parallelize_module(
         model,
         tp_mesh,
@@ -182,21 +184,36 @@ def apply_tp(
             "attention.wkv_a": NoParallel(),
             "attention.wkv_b": colwise_parallel(),
             "attention.kv_norm": NoParallel(),
-            "attention.wq_a": NoParallel(),
-            "attention.wq_b": colwise_parallel(),
-            "attention.q_norm": NoParallel(),
-            "attention.wq": colwise_parallel(), # This is only used when q_lora_rank==0
             "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
             "ffn_norm": SequenceParallel(),
-            "feed_forward": prepare_module_input(
-                input_layouts=(Shard(1),),
-                desired_input_layouts=(Replicate(),),
-            ),
-            "feed_forward.w1": colwise_parallel(),
-            "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
-            "feed_forward.w3": colwise_parallel(),
         }

+        if transformer_block.attention.q_lora_rank == 0:
+            layer_plan.update(
+                {
+                    "attention.wq": colwise_parallel(), # This is only used when q_lora_rank==0
+                }
+            )
+        else:
+            layer_plan.update(
+                {
+                    "attention.wq_a": NoParallel(),
+                    "attention.wq_b": colwise_parallel(),
+                    "attention.q_norm": NoParallel(),
+                }
+            )
+
+        if not transformer_block.moe_enabled:
+            layer_plan.update({
+                "feed_forward": prepare_module_input(
+                    input_layouts=(Shard(1),),
+                    desired_input_layouts=(Replicate(),),
+                ),
+                "feed_forward.w1": colwise_parallel(),
+                "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
+                "feed_forward.w3": colwise_parallel(),
+            })
+
         parallelize_module(
             module=transformer_block,
             device_mesh=tp_mesh,
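The key change in this hunk is that the TP plan is now built per block: `wq` vs. `wq_a`/`wq_b` depends on `q_lora_rank`, and the dense `feed_forward.*` entries are only added when the block is not an MoE layer (hence the rename to `apply_non_moe_tp`). Here is a standalone sketch of that dispatch, with plain strings in place of the real `ColwiseParallel`/`RowwiseParallel` styles and made-up `FakeAttention`/`FakeBlock` stand-ins so it runs without a device mesh.

```python
# Standalone sketch of the per-block TP plan dispatch added above. Plan values
# are plain strings instead of real ParallelStyle objects, and FakeAttention /
# FakeBlock are made-up stand-ins, so this runs without torch.distributed.
from dataclasses import dataclass


@dataclass
class FakeAttention:
    q_lora_rank: int


@dataclass
class FakeBlock:
    attention: FakeAttention
    moe_enabled: bool


def build_layer_plan(block: FakeBlock) -> dict:
    layer_plan = {
        "attention.wkv_b": "colwise",
        "attention.wo": "rowwise",
        "ffn_norm": "sequence_parallel",
    }
    # wq vs. wq_a/wq_b depends on whether query LoRA is in use.
    if block.attention.q_lora_rank == 0:
        layer_plan["attention.wq"] = "colwise"
    else:
        layer_plan.update({"attention.wq_a": "no_parallel", "attention.wq_b": "colwise"})
    # Dense FFN modules get a TP plan; MoE blocks are skipped here,
    # which is why the function is now called apply_non_moe_tp.
    if not block.moe_enabled:
        layer_plan.update({"feed_forward.w1": "colwise", "feed_forward.w2": "rowwise"})
    return layer_plan


print(build_layer_plan(FakeBlock(FakeAttention(q_lora_rank=0), moe_enabled=False)))
print(build_layer_plan(FakeBlock(FakeAttention(q_lora_rank=1536), moe_enabled=True)))
```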

torchtitan/models/deepseek_v3/train_configs/debug_model.toml

Lines changed: 4 additions & 3 deletions
@@ -24,7 +24,7 @@ enable_wandb = false
 name = "deepseek_v3"
 flavor = "debugmodel"
 # test tokenizer.model, for debug purpose only
-tokenizer_path = "./assets/tokenizer/DeepSeek-V3"
+tokenizer_path = "./tests/assets/tokenizer"
 # converters = ["float8"]

 [optimizer]
@@ -40,7 +40,7 @@ lr_min = 0.0

 [training]
 local_batch_size = 8
-seq_len = 4096
+seq_len = 2048
 max_norm = 1.0 # grad norm clipping
 steps = 10
 compile = false
@@ -69,4 +69,5 @@ selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac bas
 [float8]
 enable_fsdp_float8_all_gather = false
 precompute_float8_dynamic_scale_for_fsdp = false
-filter_fqns = ["output"]
+filter_fqns = ["output", "router.gate"]
+moe_fqns = ["experts"]

torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml

Lines changed: 6 additions & 6 deletions
@@ -22,26 +22,25 @@ enable_wandb = false
 [model]
 name = "deepseek_v3"
 flavor = "16B"
-# test tokenizer.model, for debug purpose only
 tokenizer_path = "./assets/tokenizer/DeepSeek-V3"
 # converters = ["float8"]

 [optimizer]
 name = "AdamW"
-lr = 8e-4
+lr = 2.2e-4
 eps = 1e-8

 [lr_scheduler]
-warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
+warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps
 decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
 decay_type = "linear"
-lr_min = 0.0
+lr_min = 2.2e-5

 [training]
 local_batch_size = 8
 seq_len = 4096
 max_norm = 1.0 # grad norm clipping
-steps = 100
+steps = 1000
 compile = false
 dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)

@@ -67,4 +66,5 @@ mode = "full" # ["none", "selective", "full"]
 [float8]
 enable_fsdp_float8_all_gather = false
 precompute_float8_dynamic_scale_for_fsdp = false
-filter_fqns = ["output"]
+filter_fqns = ["output", "router.gate"]
+moe_fqns = ["experts"]
torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+# torchtitan Config.toml
+
+[job]
+dump_folder = "./outputs"
+description = "DeepSeek-V3 671B model training"
+print_args = false
+
+[profiling]
+enable_profiling = false
+save_traces_folder = "profile_trace"
+profile_freq = 10
+enable_memory_snapshot = false
+save_memory_snapshot_folder = "memory_snapshot"
+
+[metrics]
+log_freq = 10
+disable_color_printing = false
+enable_tensorboard = false
+save_tb_folder = "tb"
+enable_wandb = false
+
+[model]
+name = "deepseek_v3"
+flavor = "671B"
+tokenizer_path = "./assets/tokenizer/DeepSeek-V3"
+# converters = ["float8"]
+
+[optimizer]
+name = "AdamW"
+lr = 2.2e-4
+eps = 1e-8
+
+[lr_scheduler]
+warmup_steps = 2_000 # lr scheduler warm up, normally 20% of the train steps
+decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
+decay_type = "linear"
+lr_min = 2.2e-5
+
+[training]
+local_batch_size = 4
+seq_len = 4096
+max_norm = 1.0 # grad norm clipping
+steps = 10_000
+compile = false
+dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)
+
+[parallelism]
+data_parallel_replicate_degree = 1
+data_parallel_shard_degree = -1
+fsdp_reshard_after_forward = "default" # default / never / always
+tensor_parallel_degree = 8
+enable_async_tensor_parallel = false
+expert_parallel_degree = 1
+
+[checkpoint]
+enable_checkpoint = false
+folder = "checkpoint"
+interval = 500
+last_save_model_weights_only = false
+export_dtype = "float32"
+async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
+
+[activation_checkpoint]
+mode = "full" # ["none", "selective", "full"]
+
+[float8]
+enable_fsdp_float8_all_gather = false
+precompute_float8_dynamic_scale_for_fsdp = false
+filter_fqns = ["output", "router.gate"]
+moe_fqns = ["experts"]
