Commit 4244c16

add TP for norm
1 parent 0ca6ece commit 4244c16

File tree

1 file changed: +6, -4 lines

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 6 additions & 4 deletions
@@ -33,7 +33,7 @@ def parallelize_deepseekv3(
     if parallel_dims.tp_enabled:
         if job_config.parallelism.enable_async_tensor_parallel:
             raise NotImplementedError(
-                "Currently, async TP is not supported for deepseekv3"
+                "Currently, async TP is not tested for deepseekv3"
             )
 
         enable_float8_linear = "float8" in job_config.model.converters
@@ -45,12 +45,12 @@ def parallelize_deepseekv3(
         enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
         if enable_float8_tensorwise_tp:
             raise NotImplementedError(
-                "Currently, float8 tensorwise TP is not supported for deepseekv3"
+                "Currently, float8 tensorwise TP is not tested for deepseekv3"
             )
 
         if parallel_dims.loss_parallel_enabled:
             raise NotImplementedError(
-                "Currently, loss parallel is not supported for deepseekv3"
+                "Currently, loss parallel is not tested for deepseekv3"
             )
 
         apply_tp(
@@ -140,10 +140,12 @@ def apply_tp(
                 input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
             ),
-            "attention.wkv_a": NoParallel(), # Make ths a DTensor
+            "attention.wkv_a": NoParallel(),
             "attention.wkv_b": colwise_parallel(),
+            "attention.kv_norm": NoParallel(),
             "attention.wq_a": NoParallel(),
             "attention.wq_b": colwise_parallel(),
+            "attention.q_norm": NoParallel(),
             "attention.wq": colwise_parallel(), # This is only used when q_lora_rank==0
             "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
             "ffn_norm": SequenceParallel(),
