
Commit 0ca6ece

add tp test
1 parent b74918a commit 0ca6ece

File tree

4 files changed: +141, -13 lines


torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 118 additions & 0 deletions
@@ -6,9 +6,19 @@

 import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor import Replicate, Shard
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+)

 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
+from torchtitan.experiments.llama4.infra.expert_parallel import NoParallel
+from torchtitan.experiments.llama4.infra.parallelize import apply_moe_tp
 from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_fsdp
 from torchtitan.tools.logging import logger

@@ -19,6 +29,40 @@ def parallelize_deepseekv3(
     parallel_dims: ParallelDims,
     job_config: JobConfig,
 ):
+
+    if parallel_dims.tp_enabled:
+        if job_config.parallelism.enable_async_tensor_parallel:
+            raise NotImplementedError(
+                "Currently, async TP is not supported for deepseekv3"
+            )
+
+        enable_float8_linear = "float8" in job_config.model.converters
+        float8_is_rowwise = job_config.float8.recipe_name in (
+            "rowwise",
+            "rowwise_with_gw_hp",
+        )
+
+        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
+        if enable_float8_tensorwise_tp:
+            raise NotImplementedError(
+                "Currently, float8 tensorwise TP is not supported for deepseekv3"
+            )
+
+        if parallel_dims.loss_parallel_enabled:
+            raise NotImplementedError(
+                "Currently, loss parallel is not supported for deepseekv3"
+            )
+
+        apply_tp(
+            model,
+            world_mesh["tp"],
+            loss_parallel=parallel_dims.loss_parallel_enabled,
+            enable_float8_tensorwise_tp=False,
+            enable_async_tp=False,
+        )
+
+        apply_moe_tp(model, world_mesh["tp"])
+
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(model, job_config.activation_checkpoint)

@@ -48,3 +92,77 @@ def parallelize_deepseekv3(
            logger.info("Applied FSDP to the model")

    return model
+
+
+def apply_tp(
+    model: nn.Module,
+    tp_mesh: DeviceMesh,
+    loss_parallel: bool,
+    enable_float8_tensorwise_tp: bool,
+    enable_async_tp: bool,
+):
+    """Apply tensor parallelism."""
+    # 1. Parallelize the embedding and shard its outputs (which are the first
+    # transformer block's inputs)
+    # 2. Parallelize the root norm layer over the sequence dim
+    # 3. Parallelize the final linear output layer
+    parallelize_module(
+        model,
+        tp_mesh,
+        {
+            "tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            ),
+            "norm": SequenceParallel(),
+            "output": ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Shard(-1) if loss_parallel else Replicate(),
+                use_local_output=not loss_parallel,
+            ),
+        },
+    )
+
+    rowwise_parallel, colwise_parallel, prepare_module_input = (
+        RowwiseParallel,
+        ColwiseParallel,
+        PrepareModuleInput,
+    )
+
+    # Apply tensor + sequence parallelism to every transformer block
+    # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
+    # by folding (and unfolding) the batch dimension and the sequence dimension.
+    # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
+    for transformer_block in model.layers.values():
+        layer_plan = {
+            "attention_norm": SequenceParallel(),
+            "attention": prepare_module_input(
+                input_layouts=(Shard(1), None),
+                desired_input_layouts=(Replicate(), None),
+            ),
+            "attention.wkv_a": NoParallel(),  # Make this a DTensor
+            "attention.wkv_b": colwise_parallel(),
+            "attention.wq_a": NoParallel(),
+            "attention.wq_b": colwise_parallel(),
+            "attention.wq": colwise_parallel(),  # This is only used when q_lora_rank==0
+            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
+            "ffn_norm": SequenceParallel(),
+            "feed_forward": prepare_module_input(
+                input_layouts=(Shard(1),),
+                desired_input_layouts=(Replicate(),),
+            ),
+            "feed_forward.w1": colwise_parallel(),
+            "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
+            "feed_forward.w3": colwise_parallel(),
+        }
+
+        parallelize_module(
+            module=transformer_block,
+            device_mesh=tp_mesh,
+            parallelize_plan=layer_plan,
+        )
+
+    logger.info(
+        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
+        "Tensor Parallelism to the model"
+    )
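
Editor's note: the plan above uses ColwiseParallel/RowwiseParallel entries keyed by submodule name, applied via parallelize_module. Below is a minimal standalone sketch of that mechanism on a toy two-layer MLP; it is not part of this commit, and the ToyFFN module, sizes, and launch command are assumptions for illustration only.

# TP sketch (assumption: 2 GPUs, launched via `torchrun --nproc_per_node=2 tp_sketch.py`).
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)


class ToyFFN(nn.Module):
    def __init__(self, dim: int = 16, hidden: int = 64):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))
model = ToyFFN().cuda()

# Same mechanism as apply_tp above: each entry names a submodule and how to shard it.
parallelize_module(
    model,
    tp_mesh,
    {
        "w1": ColwiseParallel(),  # shards w1's output features across the 2 ranks
        "w2": RowwiseParallel(),  # consumes the sharded hidden dim, all-reduces the output
    },
)

x = torch.randn(4, 16, device="cuda")
out = model(x)  # each rank ends up with the full, replicated output
print(out.shape)  # torch.Size([4, 16])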

torchtitan/models/deepseek_v3/model/model.py

Lines changed: 19 additions & 10 deletions
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import math
+from re import I
 from typing import Tuple

 import torch

@@ -194,7 +195,10 @@ def forward(
         else:
             q = self.wq_b(self.q_norm(self.wq_a(x)))

-        q = q.view(bsz, seqlen, self.n_heads, self.qk_head_dim)
+        # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual
+        # local heads from sizes of q and kv as TP may have sharded them after
+        # the above linear ops.
+        q = q.view(bsz, seqlen, -1, self.qk_head_dim)
         q_nope, q_pe = torch.split(
             q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
         )

@@ -211,10 +215,11 @@ def forward(
         kv = self.wkv_b(
             self.kv_norm(kv)
         )  # (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim))
-        kv = kv.view(bsz, seqlen, self.n_heads, self.qk_nope_head_dim + self.v_head_dim)
+        kv = kv.view(bsz, seqlen, -1, self.qk_nope_head_dim + self.v_head_dim)
         k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+        n_local_heads = k_nope.size(2)
         k = torch.cat(
-            [k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1
+            [k_nope, k_pe.expand(-1, -1, n_local_heads, -1)], dim=-1
         )  # (bsz, seqlen, n_heads, qk_head_dim)

         q = q.transpose(1, 2)  # (bsz, n_heads, seqlen, qk_head_dim)

@@ -278,12 +283,13 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs):
         super().__init__()
         self.attention = Attention(model_args)
         self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
-        self.moe_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
-        self.moe = (
-            FeedForward(model_args.dim, model_args.inter_dim)
-            if layer_id < model_args.n_dense_layers
-            else MoE(model_args)
-        )
+        self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.moe_enabled = layer_id < model_args.n_dense_layers
+
+        if self.moe_enabled:
+            self.moe = MoE(model_args)
+        else:
+            self.feed_forward = FeedForward(model_args.dim, model_args.inter_dim)

     def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
         """

@@ -297,7 +303,10 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
             torch.Tensor: Output tensor with the same shape as the input.
         """
         x = x + self.attention(self.attention_norm(x), freqs_cis)
-        x = x + self.moe(self.moe_norm(x))
+        if self.moe_enabled:
+            x = x + self.moe(self.ffn_norm(x))
+        else:
+            x = x + self.feed_forward(self.ffn_norm(x))
         return x
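
Editor's note on the reshape change: once wq_b / wkv_b are ColwiseParallel, each rank's local activation only carries n_heads / tp_degree heads, so hard-coding self.n_heads in view() would no longer match the tensor's size. A shape-only sketch (illustrative numbers, not from the commit):

import torch

bsz, seqlen, n_heads, qk_head_dim, tp_degree = 2, 8, 16, 192, 2
n_local_heads = n_heads // tp_degree  # 8 heads per rank after sharding

# One rank's wq_b output under TP: only n_local_heads * head_dim features.
q_local = torch.randn(bsz, seqlen, n_local_heads * qk_head_dim)

# q_local.view(bsz, seqlen, n_heads, qk_head_dim) would raise a size-mismatch error;
# -1 lets view() infer the local head count instead.
q = q_local.view(bsz, seqlen, -1, qk_head_dim)
assert q.shape == (bsz, seqlen, n_local_heads, qk_head_dim)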

torchtitan/models/deepseek_v3/model/moe.py

Lines changed: 2 additions & 3 deletions
@@ -116,8 +116,7 @@ def __init__(
         self.top_k = top_k
         self.use_sigmoid = use_sigmoid
         self.route_sclaing_factor = route_sclaing_factor
-
-        self.weight = nn.Parameter(torch.empty((self.num_experts, self.dim)))
+        self.gate = nn.Linear(self.dim, self.num_experts, bias=False)

     def forward(
         self, x: torch.Tensor, expert_bias: torch.Tensor = None

@@ -138,7 +137,7 @@ def forward(
             Number of tokens assigned to each expert with shape ``(num_experts,)``.
         """
         # scores shape (bs*slen, num_experts)
-        scores = F.linear(x, self.weight, bias=None)
+        scores = self.gate(x)

         # By default, sigmoid or softmax is performed in float32 to avoid loss explosion
         if self.use_sigmoid:
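
Editor's note: the router swaps a raw weight Parameter plus F.linear for an nn.Linear gate. Since nn.Linear stores its weight as (out_features, in_features) = (num_experts, dim), gate(x) computes the same scores as the removed call when the weights match. A quick equivalence sketch (illustrative sizes, not from the commit):

import torch
import torch.nn.functional as F
from torch import nn

dim, num_experts = 32, 8
gate = nn.Linear(dim, num_experts, bias=False)
x = torch.randn(4, dim)

scores_new = gate(x)                   # new code path: shape (4, num_experts)
scores_old = F.linear(x, gate.weight)  # old code path, reusing the same weight
torch.testing.assert_close(scores_new, scores_old)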

torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml

Lines changed: 2 additions & 0 deletions
@@ -49,6 +49,8 @@ dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default" # default / never / always
+tensor_parallel_degree = 2
+disable_loss_parallel = true

 [checkpoint]
 enable_checkpoint = false
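
Editor's note: with data_parallel_shard_degree = -1, the FSDP shard degree is inferred from whatever ranks remain after the other mesh dimensions are fixed, so adding tensor_parallel_degree = 2 halves it. A back-of-the-envelope sketch of that arithmetic (the 8-GPU world size and the exact inference rule are assumptions, not taken from this commit):

world_size = 8                     # hypothetical single node with 8 GPUs
tp = 2                             # tensor_parallel_degree from this config
dp_replicate = 1                   # data_parallel_replicate_degree from this config
dp_shard = world_size // (dp_replicate * tp)  # -1 would resolve to 4 here

assert dp_replicate * dp_shard * tp == world_size
print(dp_shard)  # 4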
