import torch.nn as nn

from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor import Replicate, Shard
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+)

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
+from torchtitan.experiments.llama4.infra.expert_parallel import NoParallel
+from torchtitan.experiments.llama4.infra.parallelize import apply_moe_tp
from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_fsdp
from torchtitan.tools.logging import logger
@@ -19,6 +29,40 @@ def parallelize_deepseekv3(
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
+
+    if parallel_dims.tp_enabled:
+        if job_config.parallelism.enable_async_tensor_parallel:
+            raise NotImplementedError(
+                "Currently, async TP is not supported for deepseekv3"
+            )
+
+        enable_float8_linear = "float8" in job_config.model.converters
+        float8_is_rowwise = job_config.float8.recipe_name in (
+            "rowwise",
+            "rowwise_with_gw_hp",
+        )
+
+        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
+        if enable_float8_tensorwise_tp:
+            raise NotImplementedError(
+                "Currently, float8 tensorwise TP is not supported for deepseekv3"
+            )
+
+        if parallel_dims.loss_parallel_enabled:
+            raise NotImplementedError(
+                "Currently, loss parallel is not supported for deepseekv3"
+            )
+
+        apply_tp(
+            model,
+            world_mesh["tp"],
+            loss_parallel=parallel_dims.loss_parallel_enabled,
+            enable_float8_tensorwise_tp=False,
+            enable_async_tp=False,
+        )
+
+        apply_moe_tp(model, world_mesh["tp"])
+
    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)
@@ -48,3 +92,77 @@ def parallelize_deepseekv3(
        logger.info("Applied FSDP to the model")

    return model
+
+
+def apply_tp(
+    model: nn.Module,
+    tp_mesh: DeviceMesh,
+    loss_parallel: bool,
+    enable_float8_tensorwise_tp: bool,
+    enable_async_tp: bool,
+):
+    """Apply tensor parallelism."""
+    # 1. Parallelize the embedding and shard its outputs (which are the first
+    # transformer block's inputs)
+    # 2. Parallelize the root norm layer over the sequence dim
+    # 3. Parallelize the final linear output layer
+    parallelize_module(
+        model,
+        tp_mesh,
+        {
+            "tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            ),
+            "norm": SequenceParallel(),
+            "output": ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Shard(-1) if loss_parallel else Replicate(),
+                use_local_output=not loss_parallel,
+            ),
+        },
+    )
+
+    rowwise_parallel, colwise_parallel, prepare_module_input = (
+        RowwiseParallel,
+        ColwiseParallel,
+        PrepareModuleInput,
+    )
+
+    # Apply tensor + sequence parallelism to every transformer block
+    # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
+    #       by folding (and unfolding) the batch dimension and the sequence dimension.
+    #       Examples can be found at https://github.com/pytorch/torchtitan/pull/437
+    for transformer_block in model.layers.values():
+        layer_plan = {
+            "attention_norm": SequenceParallel(),
+            "attention": prepare_module_input(
+                input_layouts=(Shard(1), None),
+                desired_input_layouts=(Replicate(), None),
+            ),
+            "attention.wkv_a": NoParallel(),  # Make this a DTensor
+            "attention.wkv_b": colwise_parallel(),
+            "attention.wq_a": NoParallel(),
+            "attention.wq_b": colwise_parallel(),
+            "attention.wq": colwise_parallel(),  # This is only used when q_lora_rank==0
+            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
+            "ffn_norm": SequenceParallel(),
+            "feed_forward": prepare_module_input(
+                input_layouts=(Shard(1),),
+                desired_input_layouts=(Replicate(),),
+            ),
+            "feed_forward.w1": colwise_parallel(),
+            "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
+            "feed_forward.w3": colwise_parallel(),
+        }
+
+        parallelize_module(
+            module=transformer_block,
+            device_mesh=tp_mesh,
+            parallelize_plan=layer_plan,
+        )
+
+    logger.info(
+        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
+        "Tensor Parallelism to the model"
+    )
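
As a usage illustration outside this diff, the colwise/rowwise layer-plan pattern that apply_tp builds per transformer block can be tried standalone with the same torch.distributed.tensor.parallel APIs. The sketch below is not the torchtitan code path: the FeedForward module, its sizes, the tp_sketch.py filename, and the 2-GPU mesh are assumptions for illustration, and the script is meant to be launched with torchrun.

# Standalone sketch of the ColwiseParallel/RowwiseParallel layer-plan pattern.
# Assumptions: 2 GPUs, launched via `torchrun --nproc_per_node=2 tp_sketch.py`;
# the FeedForward module and its sizes are made up for illustration only.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)


class FeedForward(nn.Module):
    def __init__(self, dim: int = 256, hidden_dim: int = 1024):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


if __name__ == "__main__":
    # torchrun provides LOCAL_RANK; pin each rank to its own GPU.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # 1-D device mesh over all ranks; this also initializes the process group.
    tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

    model = FeedForward().cuda()
    # Same shape of plan as layer_plan above: w1 is sharded column-wise, w2
    # row-wise, so the hidden activation stays sharded across ranks and only
    # w2's output is reduced back to a replicated tensor.
    parallelize_module(
        model,
        tp_mesh,
        {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
    )

    out = model(torch.randn(8, 256, device="cuda"))
    print(out.shape)  # torch.Size([8, 256]) on every rank
    dist.destroy_process_group()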