
Commit f132f1c

TP gemm works
1 parent 8adeaaf commit f132f1c

5 files changed, +26 -39 lines changed

torchtitan/experiments/kernels/moe/indices.py

Lines changed: 0 additions & 1 deletion

@@ -77,7 +77,6 @@ def fill_indices_wrapper(
     max_blocks: int = 1024,  # cap on total number of blocks to launch
 ):
     # preallocate output
-    print("max_len: ", max_len, "block_size: ", block_size, "max_blocks: ", max_blocks)
     permuted_indices = torch.full(
         (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
     )

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 5 additions & 5 deletions

@@ -32,12 +32,12 @@
         dim=256,
         inter_dim=10944,
         moe_inter_dim=1408,
-        n_layers=1,
-        n_dense_layers=0,  # no FFN layer, all MoE layers
+        n_layers=3,
+        n_dense_layers=1,
         n_heads=16,
-        n_routed_experts=2,  # hang only happens when n_routed_experts > n_activated_experts
-        n_shared_experts=1,
-        n_activated_experts=1,
+        n_routed_experts=8,
+        n_shared_experts=2,
+        n_activated_experts=3,
         route_scale=1.0,
         q_lora_rank=0,
         kv_lora_rank=512,

torchtitan/models/deepseek_v3/model/args.py

Lines changed: 1 addition & 1 deletion

@@ -75,7 +75,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
     n_limited_groups: int = 1
     score_func: Literal["softmax", "sigmoid"] = "softmax"
     route_scale: float = 1.0
-    use_grouped_mm: bool = True
+    use_grouped_mm: bool = False
     load_balance_coeff: float = 1e-3
     # Multi-Head Latent Attention (MLA)
     q_lora_rank: int = 0
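
With use_grouped_mm now defaulting to False, the expert computation takes the non-grouped fallback that loops over experts using the host-side token counts (the .tolist() path visible in the moe.py hunks below). A minimal sketch of that style of fallback, with made-up shapes and a made-up per-expert weight layout, purely for illustration and not the repo's GroupedExperts code:

import torch

num_experts, dim, hidden = 4, 8, 16
tokens_per_expert = [3, 0, 5, 2]              # host-side list, as after .tolist()
x = torch.randn(sum(tokens_per_expert), dim)  # routed tokens, already sorted by expert
w = torch.randn(num_experts, dim, hidden)     # one weight matrix per expert (made up)

outputs, start = [], 0
for expert_id, n_tokens in enumerate(tokens_per_expert):
    chunk = x[start : start + n_tokens]       # tokens assigned to this expert
    outputs.append(chunk @ w[expert_id])      # per-expert GEMM; empty chunks are fine
    start += n_tokens
y = torch.cat(outputs)                        # shape (total_routed_tokens, hidden)

The grouped-mm path replaces this Python loop with a single grouped GEMM over per-expert blocks, which is why the tokens must first be permuted into contiguous groups and padded, as the moe.py hunks below handle.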

torchtitan/models/deepseek_v3/model/moe.py

Lines changed: 16 additions & 29 deletions

@@ -201,17 +201,20 @@ def forward(
             min=0,
             max=self.num_experts,
         )
+
+        # Reorder the token indices to match the order of the experts
         # token_indices_experts_sorted shape (bs*slen*top_k,)
         token_indices_experts_sorted = torch.argsort(
             selected_experts_indices.view(-1), stable=True
         )
+
+        # reorder the scores to match the order of the token indices
         top_scores = top_scores.view(-1)[token_indices_experts_sorted]
         token_indices_experts_sorted = token_indices_experts_sorted // self.top_k

         top_scores = (
             top_scores * self.route_sclaing_factor
         )  # must multiply the scaling factor
-        print("In TokenChoiceTopKRouter, top_scores shape: ", top_scores)
         return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert

     def init_weights(self, init_std: float):

@@ -292,7 +295,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         Returns:
             out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
         """
-        print("In MoE input, x shape: ", x)
         bs, slen, dim = x.shape

         # top_scores and selected_indices shape (bs*slen*top_k,)

@@ -303,15 +305,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             num_local_tokens_per_expert,
         ) = self.router(x.reshape(bs * slen, dim), self.expert_bias)

-        # print(
-        #     "In MoE, top_scores shape: ",
-        #     top_scores.shape,
-        #     "token_indices: ",
-        #     token_indices.shape,
-        #     "num_local_tokens: ",
-        #     num_local_tokens_per_expert.shape,
-        # )
-
         # will be used to update the expert bias for load balancing
         self.tokens_per_expert += num_local_tokens_per_expert

@@ -324,12 +317,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             dim=0,
             index=token_indices,
         )
-        print("Routed input: ", routed_input)
-
-        # TODO: remove this line, this is a temporary test
-        routed_input = (routed_input.to(torch.float32) * top_scores.reshape(-1, 1)).to(
-            x.dtype
-        )

         if self.use_grouped_mm:
             # NOTE: In order to use torch._grouped_mm, we need to make sure

@@ -361,30 +348,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         else:
             # NOTE: this would incur a synchronization between device and host
             num_local_tokens_per_expert = num_local_tokens_per_expert.tolist()
-            input_shape, permuted_indices = None, None
+            permuted_indices, input_shape = None, None

         # shape (bs*slen*top_k, dim)
-        routed_output = self.experts(
-            routed_input, num_local_tokens_per_expert
-        )  # torch.Size([16384(bsz), 256])
+        routed_output = self.experts(routed_input, num_local_tokens_per_expert)

-        routed_output_unpermuted = routed_output.new_empty(input_shape)
-        routed_output_unpermuted[permuted_indices, :] = routed_output
-        routed_output = routed_output_unpermuted[:-1]
+        if self.use_grouped_mm:
+            # NOTE: Reverse the permutation to get the original order as inputs
+            routed_output_unpermuted = routed_output.new_empty(input_shape)
+            routed_output_unpermuted[permuted_indices, :] = routed_output
+            routed_output = routed_output_unpermuted[:-1]  # remove padding

-        # TODO: Use this line instead of routed_input*top_scores, need to pad top_scores to be multiple of 16
-        # routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(
-        #     x.dtype
-        # )
+        routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(
+            x.dtype
+        )

         # shared expert
         if self.shared_expert is not None:
             out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
                 bs * slen, dim
-            )  # torch.Size([16384, 256]) None
+            )
         else:
             out = torch.zeros_like(x.reshape(bs * slen, dim))

+        # Accumulate multiple expert results because each token can be routed to multiple experts
         out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
         out = out.reshape(bs, slen, dim)
         return out
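
The router and MoE hunks above all lean on the same token-permutation bookkeeping: flatten the (token, expert) assignments, stable-argsort by expert id so each expert sees a contiguous block of tokens, scale by the routing scores, run the experts, and scatter_add the results back into one output row per token. A small self-contained sketch of that pattern with toy shapes (illustrative only; an identity function stands in for the experts, and the names are made up):

import torch

num_tokens, top_k, dim = 3, 2, 4
x = torch.randn(num_tokens, dim)
selected_experts_indices = torch.tensor([[1, 0], [0, 1], [1, 1]])  # (num_tokens, top_k)
top_scores = torch.rand(num_tokens, top_k)

# stable sort of the flattened assignments groups routed copies by expert id
order = torch.argsort(selected_experts_indices.view(-1), stable=True)
scores_sorted = top_scores.view(-1)[order]
token_indices = order // top_k                 # source token of each routed copy

routed_input = x[token_indices]                # (num_tokens * top_k, dim), grouped by expert
routed_output = routed_input                   # stand-in for the per-expert computation
routed_output = routed_output * scores_sorted.unsqueeze(-1)

# accumulate per-expert results back into one row per token
out = torch.zeros_like(x)
out = out.scatter_add(0, token_indices.unsqueeze(-1).expand(-1, dim), routed_output)

Note that the commit also moves the score multiplication from routed_input to routed_output, so scaling now happens after the expert GEMMs rather than before them, and the un-permute plus padding removal only runs on the grouped-mm path.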

torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml

Lines changed: 4 additions & 3 deletions

@@ -15,7 +15,7 @@ save_memory_snapshot_folder = "memory_snapshot"
 [metrics]
 log_freq = 1
 disable_color_printing = false
-enable_tensorboard = false
+enable_tensorboard = true
 save_tb_folder = "tb"
 enable_wandb = false

@@ -41,15 +41,16 @@ lr_min = 0.0
 local_batch_size = 16
 seq_len = 2048
 max_norm = 1.0  # grad norm clipping
-steps = 2
+steps = 10
 compile = false
 dataset = "c4"  # supported datasets: c4_test (2K), c4 (177M)
+seed = 0

 [parallelism]
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default"  # default / never / always
-tensor_parallel_degree = 2
+tensor_parallel_degree = 1
 enable_async_tensor_parallel = false

 [checkpoint]
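
A quick way to confirm the new config values is to parse the TOML and check the keys this commit touches. A small sketch, assuming Python 3.11+ for tomllib and a path relative to the repo root; only the sections whose headers are visible in this diff are asserted:

import tomllib

path = "torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml"
with open(path, "rb") as f:
    cfg = tomllib.load(f)

assert cfg["metrics"]["enable_tensorboard"] is True
assert cfg["parallelism"]["tensor_parallel_degree"] == 1
# steps = 10 and seed = 0 sit in the hunk below lr_min; their section header
# is not shown in this diff, so they are not asserted here.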
