Skip to content

Commit e119300

Browse files
committed
rebase onto #1324
1 parent f132f1c commit e119300

File tree

8 files changed: +139 −146 lines changed

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
 from torchtitan.components.loss import build_cross_entropy_loss
 from torchtitan.components.lr_scheduler import build_lr_schedulers
-from torchtitan.components.optimizer import build_optimizers
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
+from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers

 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

@@ -117,7 +117,7 @@
         config=deepseekv3_configs,
         parallelize_fn=parallelize_deepseekv3,
         pipelining_fn=None,
-        build_optimizers_fn=build_optimizers,
+        build_optimizers_fn=build_llama4_optimizers,  # use optimizer hooks to update expert weights
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,
         build_tokenizer_fn=build_tiktoken_tokenizer,

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
 from torchtitan.experiments.llama4.infra.expert_parallel import NoParallel
-from torchtitan.experiments.llama4.infra.parallelize import apply_moe_tp
+from torchtitan.experiments.llama4.infra.parallelize import apply_moe_ep_tp
 from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_fsdp
 from torchtitan.tools.logging import logger

@@ -59,7 +59,16 @@ def parallelize_deepseekv3(
             enable_async_tp=False,
         )

-        apply_moe_tp(model, world_mesh["tp"])
+        apply_moe_ep_tp(
+            model,
+            tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
+            ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
+            ep_tp_mesh=(
+                world_mesh["ep", "tp"]
+                if parallel_dims.tp_enabled and parallel_dims.ep_enabled
+                else None
+            ),
+        )

     if job_config.activation_checkpoint.mode != "none":
         apply_ac(model, job_config.activation_checkpoint)

torchtitan/models/deepseek_v3/model/args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
     n_limited_groups: int = 1
     score_func: Literal["softmax", "sigmoid"] = "softmax"
     route_scale: float = 1.0
-    use_grouped_mm: bool = False
+    use_grouped_mm: bool = True
     load_balance_coeff: float = 1e-3
     # Multi-Head Latent Attention (MLA)
     q_lora_rank: int = 0

torchtitan/models/deepseek_v3/model/model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
 from typing import Tuple

 import torch
-import torch.nn.functional as F
 from torch import nn
 from torchtitan.models.attention import build_attention
 from torchtitan.protocols.train_spec import ModelProtocol
@@ -369,10 +368,15 @@ def forward(self, tokens: torch.Tensor):
         Returns:
             torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
         """
+        print("Input tokens:", tokens)
         h = self.tok_embeddings(tokens)
+        print("After token embedding:", h)

         for layer in self.layers.values():
             h = layer(h, self.freqs_cis)
+            print(f"After layer {layer}: ", h)
         h = self.norm(h)
+        print("After normalization:", h)
         output = self.output(h)
+        print("Output logits:", output)
         return output

NOTE(review): the per-layer print's indentation (inside vs. after the loop) is inferred from its message and the hunk's line numbering — confirm against the actual commit. These added `print` calls look like leftover debug output that should not ship in `forward`.

0 commit comments

Comments
 (0)