@@ -33,7 +33,7 @@ def parallelize_deepseekv3(
     if parallel_dims.tp_enabled:
         if job_config.parallelism.enable_async_tensor_parallel:
             raise NotImplementedError(
-                "Currently, async TP is not supported for deepseekv3"
+                "Currently, async TP is not tested for deepseekv3"
             )

         enable_float8_linear = "float8" in job_config.model.converters
@@ -45,12 +45,12 @@ def parallelize_deepseekv3(
         enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
         if enable_float8_tensorwise_tp:
             raise NotImplementedError(
-                "Currently, float8 tensorwise TP is not supported for deepseekv3"
+                "Currently, float8 tensorwise TP is not tested for deepseekv3"
             )

         if parallel_dims.loss_parallel_enabled:
             raise NotImplementedError(
-                "Currently, loss parallel is not supported for deepseekv3"
+                "Currently, loss parallel is not tested for deepseekv3"
            )

         apply_tp(
@@ -140,10 +140,12 @@ def apply_tp(
                 input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
             ),
-            "attention.wkv_a": NoParallel(),  # Make ths a DTensor
+            "attention.wkv_a": NoParallel(),
             "attention.wkv_b": colwise_parallel(),
+            "attention.kv_norm": NoParallel(),
             "attention.wq_a": NoParallel(),
             "attention.wq_b": colwise_parallel(),
+            "attention.q_norm": NoParallel(),
             "attention.wq": colwise_parallel(),  # This is only used when q_lora_rank==0
             "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
             "ffn_norm": SequenceParallel(),