
Commit 8adeaaf

tp on grouped_mm finished
1 parent f2752b3 commit 8adeaaf

5 files changed: +33 -34 lines changed

torchtitan/experiments/kernels/moe/indices.py

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ def fill_indices_wrapper(
     max_blocks: int = 1024,  # cap on total number of blocks to launch
 ):
     # preallocate output
+    print("max_len: ", max_len, "block_size: ", block_size, "max_blocks: ", max_blocks)
     permuted_indices = torch.full(
         (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
     )
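For orientation, the preallocation above supports an aligned index layout: each expert's token indices land in a contiguous block rounded up to an alignment multiple, and the -1 fill survives as the padding marker. A CPU-only sketch of that layout with invented token counts and alignment (not the actual Triton kernel logic):

import torch

# Hypothetical inputs: 3 experts receiving 3, 1, and 2 tokens; blocks aligned to 4.
tokens_per_expert_group = torch.tensor([3, 1, 2])
align = 4
# Upper bound on output length: total tokens plus per-expert alignment slack.
max_len = int(tokens_per_expert_group.sum()) + len(tokens_per_expert_group) * align

# Preallocate as in fill_indices_wrapper: -1 marks unused (padding) slots.
permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)

write_pos, read_pos = 0, 0
for n in tokens_per_expert_group.tolist():
    # Each expert's token indices land in a contiguous block...
    permuted_indices[write_pos : write_pos + n] = torch.arange(
        read_pos, read_pos + n, dtype=torch.int32
    )
    read_pos += n
    # ...and the next block starts at the next multiple of `align`.
    write_pos += ((n + align - 1) // align) * align

print(permuted_indices[:write_pos])  # [0, 1, 2, -1, 3, -1, -1, -1, 4, 5, -1, -1]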

torchtitan/models/deepseek_v3/model/model.py

Lines changed: 0 additions & 3 deletions
@@ -295,10 +295,8 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
         """
         x = x + self.attention(self.attention_norm(x), freqs_cis)
         if self.moe_enabled:
-            print(f"In TransformerBlock {self.layer_id}: MoE is enabled")
             x = x + self.moe(self.ffn_norm(x))
         else:
-            print(f"In TransformerBlock {self.layer_id}: FFN is enabled")
             x = x + self.feed_forward(self.ffn_norm(x))
         return x
 
@@ -327,7 +325,6 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
 
         self.layers = torch.nn.ModuleDict()
         for layer_id in range(model_args.n_layers):
-            print(f"Create layer: {layer_id}")
             self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
 
         self.norm = nn.RMSNorm(model_args.dim)

torchtitan/models/deepseek_v3/model/moe.py

Lines changed: 29 additions & 29 deletions
@@ -211,7 +211,7 @@ def forward(
         top_scores = (
             top_scores * self.route_sclaing_factor
         )  # must multiply the scaling factor
-
+        print("In TokenChoiceTopKRouter, top_scores shape: ", top_scores)
         return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert
 
     def init_weights(self, init_std: float):
@@ -253,12 +253,6 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
             )
             if model_args.n_shared_experts > 0
             else None
-            # FeedForward(
-            #     dim=dim,
-            #     hidden_dim=hidden_dim * model_args.n_shared_experts,
-            # )
-            # if model_args.n_shared_experts > 0
-            # else None
         )
 
         # auxiliary-loss-free load balancing
@@ -298,6 +292,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         Returns:
             out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
         """
+        print("In MoE input, x shape: ", x)
         bs, slen, dim = x.shape
 
         # top_scores and selected_indices shape (bs*slen*top_k,)
@@ -308,14 +303,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             num_local_tokens_per_expert,
         ) = self.router(x.reshape(bs * slen, dim), self.expert_bias)
 
-        print(
-            "In MoE, top_scores shape: ",
-            top_scores.shape,
-            "token_indices: ",
-            token_indices.shape,
-            "num_local_tokens: ",
-            num_local_tokens_per_expert.shape,
-        )
+        # print(
+        #     "In MoE, top_scores shape: ",
+        #     top_scores.shape,
+        #     "token_indices: ",
+        #     token_indices.shape,
+        #     "num_local_tokens: ",
+        #     num_local_tokens_per_expert.shape,
+        # )
 
         # will be used to update the expert bias for load balancing
         self.tokens_per_expert += num_local_tokens_per_expert
@@ -329,6 +324,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             dim=0,
             index=token_indices,
         )
+        print("Routed input: ", routed_input)
+
+        # TODO: remove this line, this is a temporary test
+        routed_input = (routed_input.to(torch.float32) * top_scores.reshape(-1, 1)).to(
+            x.dtype
+        )
 
         if self.use_grouped_mm:
            # NOTE: In order to use torch._grouped_mm, we need to make sure
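The temporary test above folds the per-token router weights into routed_input before the experts run; the TODO in the next hunk notes that scaling the output instead would require top_scores to be padded to the grouped-mm alignment. The broadcast pattern in isolation, with toy shapes and an assumed bfloat16 activation dtype:

import torch

top_scores = torch.tensor([0.7, 0.2, 0.1])             # one router weight per routed token
routed_input = torch.ones(3, 4, dtype=torch.bfloat16)  # (num routed tokens, dim)

# Upcast for the multiply, broadcast each token's weight across dim, downcast back.
routed_input = (routed_input.to(torch.float32) * top_scores.reshape(-1, 1)).to(
    torch.bfloat16
)
print(routed_input)  # rows scaled by 0.7, 0.2, 0.1 respectively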
@@ -350,28 +351,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 num_local_tokens_per_expert,
                 self.experts.num_experts,
                 1,
-                token_indices[0] + self.experts.num_experts * ALIGN_SIZE_M,
+                token_indices.shape[0] + self.experts.num_experts * ALIGN_SIZE_M,
                 ALIGN_SIZE_M,
             )
-            token_indices = torch.vstack(
-                (token_indices, token_indices.new_zeros((dim)))
-            )
-            token_indices = token_indices[permuted_indices, :]
+
             routed_input = torch.vstack((routed_input, routed_input.new_zeros((dim))))
+            input_shape = routed_input.shape
             routed_input = routed_input[permuted_indices, :]
         else:
             # NOTE: this would incur a synchronization between device and host
             num_local_tokens_per_expert = num_local_tokens_per_expert.tolist()
+            input_shape, permuted_indices = None, None
 
-        print("Num local tokens per expert: ", num_local_tokens_per_expert)
         # shape (bs*slen*top_k, dim)
         routed_output = self.experts(
             routed_input, num_local_tokens_per_expert
         )  # torch.Size([16384(bsz), 256])
-        print("Routed output shape: ", routed_output.shape)
-        routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(
-            x.dtype
-        )
+
+        routed_output_unpermuted = routed_output.new_empty(input_shape)
+        routed_output_unpermuted[permuted_indices, :] = routed_output
+        routed_output = routed_output_unpermuted[:-1]
+
+        # TODO: use this line instead of routed_input * top_scores; top_scores would need padding to a multiple of 16
+        # routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(
+        #     x.dtype
+        # )
 
         # shared expert
         if self.shared_expert is not None:
@@ -381,10 +385,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         else:
             out = torch.zeros_like(x.reshape(bs * slen, dim))
 
-        print(
-            "Out shape: ", out.shape, out.grad.shape if out.grad is not None else None
-        )
-
         out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
         out = out.reshape(bs, slen, dim)
         return out
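Taken together, the grouped-mm path now does a permute → expert-compute → unpermute round trip: routed_input gets one zero row appended, the -1 padding slots in permuted_indices resolve to that zero row via negative indexing, and after the experts run, rows are scattered back to their pre-permutation positions before the padding row is dropped. A self-contained sketch of the pattern with invented shapes (the real permuted_indices comes from the kernel in indices.py):

import torch

routed_input = torch.randn(5, 4)  # 5 routed tokens, dim 4 (hypothetical)
# Expert-aligned order; -1 slots are padding and select the appended zero row.
permuted_indices = torch.tensor([0, 2, -1, -1, 1, 3, 4, -1])

routed_input = torch.vstack((routed_input, routed_input.new_zeros((4,))))
input_shape = routed_input.shape              # (6, 4), remembered for the unpermute
permuted = routed_input[permuted_indices, :]  # (8, 4) aligned batch fed to the experts

routed_output = permuted * 2  # stand-in for the grouped expert computation

# Scatter each row back to its pre-permutation slot, then drop the zero pad row;
# duplicate -1 writes only ever touch the pad row, so real rows are written once.
unpermuted = routed_output.new_empty(input_shape)
unpermuted[permuted_indices, :] = routed_output
routed_output = unpermuted[:-1]               # (5, 4), aligned with token_indices again

assert routed_output.shape == (5, 4)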

torchtitan/models/deepseek_v3/train_configs/debug_model.toml

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ max_norm = 1.0  # grad norm clipping
 steps = 10
 compile = false
 dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)
+seed = 0
 
 [parallelism]
 data_parallel_replicate_degree = 1

torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml

Lines changed: 2 additions & 2 deletions
@@ -41,15 +41,15 @@ lr_min = 0.0
 local_batch_size = 16
 seq_len = 2048
 max_norm = 1.0  # grad norm clipping
-steps = 10
+steps = 2
 compile = false
 dataset = "c4"  # supported datasets: c4_test (2K), c4 (177M)
 
 [parallelism]
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default"  # default / never / always
-tensor_parallel_degree = 1
+tensor_parallel_degree = 2
 enable_async_tensor_parallel = false
 
 [checkpoint]
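With these settings, data_parallel_shard_degree = -1 is left to resolve to whatever fraction of the world size remains after the other parallel degrees are taken out, so on an assumed 8-GPU node this config runs FSDP 4-way under TP 2-way. A sketch of that arithmetic (not torchtitan's actual resolution code):

world_size = 8    # assumed single 8-GPU node
dp_replicate = 1  # data_parallel_replicate_degree
tp = 2            # tensor_parallel_degree
dp_shard = world_size // (dp_replicate * tp)  # -1 in the config resolves to 4
assert dp_replicate * dp_shard * tp == world_size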
