
Commit cf0f0c7

wwwjn and tianyu-l committed
[DSV3] Apply TP on DSV3 (#1341)
Mostly adapted from llama4; the TP plan is changed based on the differences between DeepSeek-V3 and Llama. Thanks @tianyu-l for the detailed walkthrough of the DeepSeek-V3 attention model and TP plan!

This diff is currently based on #1324, and we want to extract the MoE model shared by DSV3 and llama4 into a common place.

Now we have:
1. FSDP
2. Activation Checkpointing
3. TP
4. CP in progress (hangs for some reason)

Next step:
1. Make CP work

There is a minor issue with the numerical verification: with a deterministic seed and the `AdamW` optimizer, the loss is not identical between
1. FSDP degree=4 (blue line)
2. FSDP degree=4, TP degree=2 (orange line)

![Screenshot 2025-07-01 at 5.38.50 PM](https://github.com/user-attachments/assets/38d96d75-6868-4482-a603-b9e10c692ed9)

With the `Adam` optimizer, the loss is **exactly the same**:

![Screenshot 2025-07-02 at 1.26.32 PM](https://github.com/user-attachments/assets/6b501d3c-4841-42b1-95fd-3971b16a5eeb)

---------

Co-authored-by: Tianyu Liu <lty@fb.com>
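For reference, the two verification runs above combine FSDP degree 4 with TP degree 2, i.e. a 2D device mesh over (presumably) 8 ranks. Below is a minimal sketch of building such a mesh with PyTorch's `init_device_mesh`; the mesh dim names and the standalone script are illustrative assumptions, not torchtitan's actual setup code.

```python
# Hedged sketch: a 2D mesh matching the verification runs (dp_shard=4 for FSDP, tp=2 for TP).
# Launch with e.g.:  torchrun --nproc_per_node=8 mesh_sketch.py
from torch.distributed.device_mesh import init_device_mesh

mesh_2d = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp_shard", "tp"))
dp_mesh = mesh_2d["dp_shard"]  # would be handed to FSDP / fully_shard
tp_mesh = mesh_2d["tp"]        # would be handed to parallelize_module / apply_tp
print(mesh_2d)
```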
1 parent 5d40bc7 commit cf0f0c7

File tree

11 files changed: +409 -208 lines


torchtitan/components/tokenizer.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 from typing import Any, Optional

 from tokenizers import AddedToken, Tokenizer as HfTokenizer
+
 from typing_extensions import override

torchtitan/experiments/llama4/train_configs/debug_model.toml

Lines changed: 2 additions & 1 deletion
@@ -26,7 +26,8 @@ tokenizer_path = "./tests/assets/test_tiktoken.model"
 # converters = ["float8"]

 [optimizer]
-name = "AdamW"
+# TODO: AdamW has numerical issues when TP is used, need to fix it
+name = "Adam"
 lr = 4e-3
 eps = 1e-15
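As a side note on this change: `Adam` folds weight decay into the gradient as an L2 term, while `AdamW` decouples it, and with `weight_decay=0` the two updates coincide; the switch here works around the TP numerical issue flagged in the TODO rather than changing the optimization itself. A hedged sketch with the config's `lr`/`eps` values (the single parameter is a stand-in for real model parameters):

```python
import torch

params = [torch.nn.Parameter(torch.randn(4, 4))]
# Explicit weight_decay=0.0 so both optimizers perform the same update.
adam = torch.optim.Adam(params, lr=4e-3, eps=1e-15, weight_decay=0.0)
adamw = torch.optim.AdamW(params, lr=4e-3, eps=1e-15, weight_decay=0.0)
```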

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+Download tokenizer:
+
+```
+# DeepSeek tokenizer (automatically downloads tokenizer.json and tokenizer_config.json)
+python scripts/download_tokenizer.py --repo_id deepseek-ai/DeepSeek-V3
+```
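Once downloaded, the tokenizer assets can be inspected directly with the `tokenizers` library that the HF-backed tokenizer component wraps (see the `HfTokenizer` import above). A minimal sketch; the on-disk path is an assumption about where `download_tokenizer.py` places files, not something stated in this diff:

```python
from tokenizers import Tokenizer

# Hypothetical path; adjust to wherever download_tokenizer.py saved tokenizer.json.
tok = Tokenizer.from_file("assets/tokenizer/DeepSeek-V3/tokenizer.json")

ids = tok.encode("Hello DeepSeek-V3").ids
print(len(ids), tok.get_vocab_size())
```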

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -8,9 +8,9 @@

 from torchtitan.components.loss import build_cross_entropy_loss
 from torchtitan.components.lr_scheduler import build_lr_schedulers
-from torchtitan.components.optimizer import build_optimizers
+from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
-from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
+from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers

 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

@@ -117,10 +117,10 @@
         config=deepseekv3_configs,
         parallelize_fn=parallelize_deepseekv3,
         pipelining_fn=None,
-        build_optimizers_fn=build_optimizers,
+        build_optimizers_fn=build_llama4_optimizers,  # use optimizer hooks to update expert weights
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,
-        build_tokenizer_fn=build_tiktoken_tokenizer,
+        build_tokenizer_fn=build_hf_tokenizer,
         build_loss_fn=build_cross_entropy_loss,
     )
 )
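The comment on `build_llama4_optimizers` ("use optimizer hooks to update expert weights") refers to work that runs right after each optimizer step. A toy sketch of that hook mechanism using PyTorch's `register_step_post_hook`; the hook body is a placeholder, not torchtitan's actual expert-weight update:

```python
import torch

# Toy model/optimizer standing in for the real training setup.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=4e-3, eps=1e-15)

def post_step_expert_update(opt, args, kwargs):
    # Placeholder: a real hook would e.g. sync or rebalance MoE expert weights here.
    for group in opt.param_groups:
        for p in group["params"]:
            pass

optimizer.register_step_post_hook(post_step_expert_update)

loss = model(torch.randn(2, 8)).sum()
loss.backward()
optimizer.step()  # the hook fires after this step completes
```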

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 127 additions & 0 deletions
@@ -6,9 +6,19 @@

 import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor import Replicate, Shard
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+)

 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
+from torchtitan.experiments.llama4.infra.expert_parallel import NoParallel
+from torchtitan.experiments.llama4.infra.parallelize import apply_moe_ep_tp
 from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_fsdp
 from torchtitan.tools.logging import logger

@@ -19,6 +29,47 @@ def parallelize_deepseekv3(
     parallel_dims: ParallelDims,
     job_config: JobConfig,
 ):
+
+    if parallel_dims.tp_enabled:
+        if job_config.parallelism.enable_async_tensor_parallel:
+            # TODO(jianiw): This branch needs to be tested and enabled
+            raise NotImplementedError(
+                "Currently, async TP is not tested for deepseekv3. \
+                torch.compile is not supported yet, which is required for async TP."
+            )
+
+        enable_float8_linear = "float8" in job_config.model.converters
+        float8_is_rowwise = job_config.float8.recipe_name in (
+            "rowwise",
+            "rowwise_with_gw_hp",
+        )
+
+        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
+        if enable_float8_tensorwise_tp:
+            # TODO(jianiw): This branch needs to be tested and enabled
+            raise NotImplementedError(
+                "Currently, float8 tensorwise TP is not tested for deepseekv3"
+            )
+
+        apply_tp(
+            model,
+            world_mesh["tp"],
+            loss_parallel=parallel_dims.loss_parallel_enabled,
+            enable_float8_tensorwise_tp=False,
+            enable_async_tp=False,
+        )
+
+    apply_moe_ep_tp(
+        model,
+        tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
+        ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
+        ep_tp_mesh=(
+            world_mesh["ep", "tp"]
+            if parallel_dims.tp_enabled and parallel_dims.ep_enabled
+            else None
+        ),
+    )
+
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(model, job_config.activation_checkpoint)

@@ -48,3 +99,79 @@ def parallelize_deepseekv3(
         logger.info("Applied FSDP to the model")

     return model
+
+
+def apply_tp(
+    model: nn.Module,
+    tp_mesh: DeviceMesh,
+    loss_parallel: bool,
+    enable_float8_tensorwise_tp: bool,
+    enable_async_tp: bool,
+):
+    """Apply tensor parallelism."""
+    # 1. Parallelize the embedding and shard its outputs (which are the first
+    # transformer block's inputs)
+    # 2. Parallelize the root norm layer over the sequence dim
+    # 3. Parallelize the final linear output layer
+    parallelize_module(
+        model,
+        tp_mesh,
+        {
+            "tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            ),
+            "norm": SequenceParallel(),
+            "output": ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Shard(-1) if loss_parallel else Replicate(),
+                use_local_output=not loss_parallel,
+            ),
+        },
+    )
+
+    rowwise_parallel, colwise_parallel, prepare_module_input = (
+        RowwiseParallel,
+        ColwiseParallel,
+        PrepareModuleInput,
+    )
+
+    # Apply tensor + sequence parallelism to every transformer block
+    # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
+    # by folding (and unfolding) the batch dimension and the sequence dimension.
+    # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
+    for transformer_block in model.layers.values():
+        layer_plan = {
+            "attention_norm": SequenceParallel(),
+            "attention": prepare_module_input(
+                input_layouts=(Shard(1), None),
+                desired_input_layouts=(Replicate(), None),
+            ),
+            "attention.wkv_a": NoParallel(),
+            "attention.wkv_b": colwise_parallel(),
+            "attention.kv_norm": NoParallel(),
+            "attention.wq_a": NoParallel(),
+            "attention.wq_b": colwise_parallel(),
+            "attention.q_norm": NoParallel(),
+            "attention.wq": colwise_parallel(),  # This is only used when q_lora_rank==0
+            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
+            "ffn_norm": SequenceParallel(),
+            "feed_forward": prepare_module_input(
+                input_layouts=(Shard(1),),
+                desired_input_layouts=(Replicate(),),
+            ),
+            "feed_forward.w1": colwise_parallel(),
+            "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
+            "feed_forward.w3": colwise_parallel(),
+        }
+
+        parallelize_module(
+            module=transformer_block,
+            device_mesh=tp_mesh,
+            parallelize_plan=layer_plan,
+        )
+
+    logger.info(
+        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
+        "Tensor Parallelism to the model"
+    )
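For readers unfamiliar with the plan objects used above, here is a self-contained toy example (not from the PR) of the same colwise-then-rowwise pattern that `layer_plan` applies to the attention and feed-forward projections. The module, file name, and launch command are illustrative assumptions:

```python
# Toy TP sketch mirroring the w1/w3 colwise + w2 rowwise pattern in apply_tp.
# Launch with e.g.:  torchrun --nproc_per_node=2 tp_toy.py
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyFFN(nn.Module):
    def __init__(self, dim=16, hidden=32):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


if __name__ == "__main__":
    # 1D mesh over all ranks; torchtitan instead slices a "tp" dim out of a larger mesh.
    mesh = init_device_mesh("cpu", (2,), mesh_dim_names=("tp",))
    ffn = ToyFFN()
    # w1 sharded column-wise, w2 row-wise: the hidden activation stays sharded and only
    # w2's output needs a reduction -- the same idea as the feed_forward plan above.
    parallelize_module(ffn, mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
    out = ffn(torch.randn(4, 16))
    print(out.shape)
```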

torchtitan/models/deepseek_v3/model/args.py

Lines changed: 3 additions & 3 deletions
@@ -75,8 +75,8 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
     n_limited_groups: int = 1
     score_func: Literal["softmax", "sigmoid"] = "softmax"
     route_scale: float = 1.0
-    use_grouped_mm: bool = False
-    load_balance_coeff: float | None = 1e-3
+    use_grouped_mm: bool = True
+    load_balance_coeff: float = 1e-3
     # Multi-Head Latent Attention (MLA)
     q_lora_rank: int = 0
     kv_lora_rank: int = 512
@@ -97,7 +97,7 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> Non
         """
         Update the model_config config from the given job config.
         """
-        self.vocab_size = tokenizer.n_words
+        self.vocab_size = tokenizer.vocab_size
         self.max_seq_len = job_config.training.seq_len

     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
