
Commit df8288b

[DSV3] Add PP support for DSV3
1 parent 7aff172 commit df8288b

6 files changed (+336, -10 lines)

torchtitan/models/deepseek_v3/README.md

Lines changed: 1 addition & 1 deletion
@@ -33,14 +33,14 @@ CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml"
 - Activation checkpointing
 - Tensor Parallel (TP)
 - Expert Parallel (EP)
+- Pipeline Parallel (PP)


 ## To be added
 - Modeling
   - Merge DeepSeek-V3 and Llama4 MoE common components
 - Parallelism
   - Context Parallel support for DeepSeek-V3
-  - PP support for DeepSeek-V3
 - torch.compile
 - Quantization
 - Testing

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -15,6 +15,7 @@
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

 from .infra.parallelize import parallelize_deepseekv3
+from .infra.pipeline import pipeline_deepseekv3
 from .model.args import DeepSeekV3ModelArgs
 from .model.model import DeepSeekV3Model

@@ -116,7 +117,7 @@
         cls=DeepSeekV3Model,
         config=deepseekv3_configs,
         parallelize_fn=parallelize_deepseekv3,
-        pipelining_fn=None,
+        pipelining_fn=pipeline_deepseekv3,
         build_optimizers_fn=build_llama4_optimizers,  # use optimizer hooks to update expert weights
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,
torchtitan/models/deepseek_v3/infra/pipeline.py

Lines changed: 311 additions & 0 deletions
@@ -0,0 +1,311 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D pipeline parallelism to the DeepSeek-V3 model.

import copy

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
    _PipelineSchedule,
    get_schedule_class,
    PipelineScheduleSingle,
    ScheduleZBVZeroBubble,
)

from torchtitan.components.loss import LossFunction
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.pipeline import build_pipeline_schedule, stage_ids_this_rank
from torchtitan.protocols.train_spec import DeviceType, ParallelizeFunction
from torchtitan.tools.logging import logger

from ..model.args import DeepSeekV3ModelArgs


def generate_module_names_per_stage(
    num_stages: int,
    num_layers: int,
    input_weight: int = 1,
    output_weight: int = 1,
) -> list[list[str]]:
    """
    Programmatically generates module names per stage for pipeline parallelism with weighting.

    Args:
        num_stages: Number of pipeline stages
        num_layers: Total number of transformer layers in the model
        input_weight: Weight for input modules (tok_embeddings) in layer calculation
        output_weight: Weight for output modules (norm + output) in layer calculation

    Returns:
        List of lists containing module names for each stage

    Example:
        generate_module_names_per_stage(2, 3, input_weight=2, output_weight=2)
        treats embeddings as 2 layers and norm+output as 2 layers for distribution
    """
    if num_stages < 1:
        raise ValueError("Number of stages must be at least 1")

    if num_stages == 1:
        # Single stage gets everything
        layer_names = [f"layers.{i}" for i in range(num_layers)]
        return [["tok_embeddings"] + layer_names + ["norm", "output"]]

    # Calculate effective layers including weights
    num_effective_layers = num_layers + input_weight + output_weight

    if num_stages > num_effective_layers:
        raise ValueError(
            f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})"
        )

    # Calculate layers per stage (distribute evenly)
    layers_per_stage = num_effective_layers // num_stages
    extra_layers = num_effective_layers % num_stages

    # Ensure each stage gets at least the weight of input/output modules
    if layers_per_stage < max(input_weight, output_weight):
        raise ValueError(
            f"Layers per stage ({layers_per_stage}) must be >= max(input_weight={input_weight}, output_weight={output_weight})"
        )

    module_names_per_stage = []
    current_layer = 0

    for stage_idx in range(num_stages):
        stage_modules = []

        # Calculate effective layers for this stage
        effective_layers_for_stage = layers_per_stage
        if stage_idx < extra_layers:
            effective_layers_for_stage += 1

        # First stage: handle input modules with weighting
        if stage_idx == 0:
            stage_modules.append("tok_embeddings")
            # Account for input weight in layer distribution
            remaining_layers_for_stage = effective_layers_for_stage - input_weight

            # Add transformer layers
            for _ in range(remaining_layers_for_stage):
                if current_layer < num_layers:
                    stage_modules.append(f"layers.{current_layer}")
                    current_layer += 1

        # Last stage: handle output modules with weighting
        elif stage_idx == num_stages - 1:
            # Account for output weight in layer distribution
            remaining_layers_for_stage = effective_layers_for_stage - output_weight

            # Add transformer layers
            for _ in range(remaining_layers_for_stage):
                if current_layer < num_layers:
                    stage_modules.append(f"layers.{current_layer}")
                    current_layer += 1

            # Add output modules
            stage_modules.extend(["norm", "output"])

        # Middle stages: only transformer layers
        else:
            for _ in range(effective_layers_for_stage):
                if current_layer < num_layers:
                    stage_modules.append(f"layers.{current_layer}")
                    current_layer += 1

        module_names_per_stage.append(stage_modules)

    return module_names_per_stage

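A quick illustration of the split this produces, as a minimal sketch assuming `torchtitan` is importable; the printed output in the comments was worked out by hand from the algorithm above, not captured from a run:

```python
from torchtitan.models.deepseek_v3.infra.pipeline import (
    generate_module_names_per_stage,
)

# 4 virtual stages over 8 transformer layers; embeddings and norm+output
# each count as one "effective layer" under the default weights.
for i, names in enumerate(generate_module_names_per_stage(4, 8)):
    print(f"stage {i}: {names}")
# stage 0: ['tok_embeddings', 'layers.0', 'layers.1']
# stage 1: ['layers.2', 'layers.3', 'layers.4']
# stage 2: ['layers.5', 'layers.6']
# stage 3: ['layers.7', 'norm', 'output']
```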
def pipeline_deepseekv3(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: DeepSeekV3ModelArgs,
    parallelize_fn: ParallelizeFunction,
    loss_fn: LossFunction,
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
    pp_mesh = world_mesh["pp"]

    # Determine the number of virtual stages based on schedule type
    schedule_class = get_schedule_class(
        job_config.parallelism.pipeline_parallel_schedule
    )
    is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)

    # For multi-stage schedules, default is 2 virtual stages per rank
    # For single-stage schedules, default is 1 virtual stage per rank
    stages_per_rank = 1 if is_single_stage_schedule else 2
    num_virtual_stages = parallel_dims.pp * stages_per_rank

    # Generate module names per stage programmatically with weighting
    num_layers = model_config.n_layers

    # You can adjust these weights based on the computational cost of embeddings and output layers
    # Higher weights mean these modules are treated as "heavier" in the distribution
    input_weight = 1  # Weight for tok_embeddings
    output_weight = 1  # Weight for norm + output layers

    module_names_per_stage = generate_module_names_per_stage(
        num_virtual_stages, num_layers, input_weight, output_weight
    )
    for i, stage_ms in enumerate(module_names_per_stage):
        logger.info(f"Stage {i}: {stage_ms}")

    stages, model_parts = pipeline_module_split(
        model,
        pp_mesh,
        job_config.parallelism.pipeline_parallel_schedule,
        device,
        module_names_per_stage,
    )

    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
    # optimizer, and checkpointing
    for i, m in enumerate(model_parts):
        # apply SPMD-style PT-D techniques
        m = parallelize_fn(m, world_mesh, parallel_dims, job_config)
        model_parts[i] = m
        # NOTE: this is to update the model in the stage
        # in case the model is modified e.g. by torch.compile
        stages[i].submod = m

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # This is used in the train loop to determine whether to pass in the input_ids and labels
    has_first_stage = False
    has_last_stage = False
    for stage in stages:
        if stage.is_first:
            has_first_stage = True
        if stage.is_last:
            has_last_stage = True

    return pp_schedule, model_parts, has_first_stage, has_last_stage

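The virtual-stage arithmetic above can be made concrete with a small self-contained sketch. "1F1B" and "Interleaved1F1B" are schedule names shipped with `torch.distributed.pipelining`; the 1-vs-2 mapping mirrors the `stages_per_rank` logic in `pipeline_deepseekv3`:

```python
from torch.distributed.pipelining.schedules import (
    get_schedule_class,
    PipelineScheduleSingle,
)

pp_degree = 4  # hypothetical parallel_dims.pp
for name in ("1F1B", "Interleaved1F1B"):
    cls = get_schedule_class(name)
    # single-stage schedules get 1 virtual stage per rank, multi-stage get 2
    stages_per_rank = 1 if issubclass(cls, PipelineScheduleSingle) else 2
    print(f"{name}: {pp_degree * stages_per_rank} virtual stages")
# 1F1B: 4 virtual stages
# Interleaved1F1B: 8 virtual stages
```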
def pipeline_module_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    pp_schedule: str,
    device: DeviceType,
    module_names_per_stage: list[list[str]],
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """
    This API creates pipeline stages based on specified module names for each stage.

    Args:
        whole_model: The complete model to be split
        pp_mesh: Pipeline parallel device mesh
        pp_schedule: Name of pipeline parallelism schedule
        device: Device type
        module_names_per_stage: List of lists, where each inner list contains the module names
                                that should be included in that stage. Module names should be
                                dot-separated paths. Examples:
                                - "tok_embeddings" for token embeddings
                                - "layers.0", "layers.1" for specific transformer layers
                                - "norm" for the final normalization layer
                                - "output" for the output projection layer

    Returns:
        Tuple of (stages, models) where stages are PipelineStage objects and models are the
        corresponding model chunks

    Example usage:
        module_names_per_stage = [
            ["tok_embeddings", "layers.0"],  # Stage 0: embeddings + first layer
            ["layers.1", "layers.2"],        # Stage 1: middle layers
            ["norm", "output"]               # Stage 2: final norm + output
        ]
    """
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()

    def _build_stage_from_modules(
        stage_idx: int, module_names: list[str], num_stages: int
    ) -> tuple[PipelineStage, nn.Module]:
        model = copy.deepcopy(whole_model)

        # Create a set of modules to keep for faster lookup
        modules_to_keep = set(module_names)
        logger.info(f"Stage {stage_idx}: Modules to keep: {modules_to_keep}")
        for module_name, module_value in model.named_children():
            # Handle layer-like structures (e.g., "layers.0", "layers.1")
            if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
                layers_to_keep = {
                    name.split(".", 1)[1]
                    for name in modules_to_keep
                    if name.startswith(f"{module_name}.")
                }
                if layers_to_keep:
                    # Keep only specified layers
                    if isinstance(module_value, nn.ModuleDict):
                        for layer_name in list(module_value.keys()):
                            if layer_name not in layers_to_keep:
                                del module_value[layer_name]
                    elif isinstance(module_value, nn.ModuleList):
                        indices_to_keep = {
                            int(idx) for idx in layers_to_keep if idx.isdigit()
                        }
                        new_layers = nn.ModuleList(
                            [
                                layer
                                for i, layer in enumerate(module_value)
                                if i in indices_to_keep
                            ]
                        )
                        setattr(model, module_name, new_layers)
                else:
                    # No layers from this structure needed, set to empty structure
                    if isinstance(module_value, nn.ModuleDict):
                        setattr(model, module_name, nn.ModuleDict())
                    elif isinstance(module_value, nn.ModuleList):
                        setattr(model, module_name, nn.ModuleList())
            # Handle simple module attributes (e.g., "linear", "norm")
            elif module_name not in modules_to_keep:
                # Replace with identity module instead of None
                setattr(model, module_name, nn.Identity())

        stage = PipelineStage(
            model,
            stage_idx,
            num_stages,
            device,
            group=pp_mesh.get_group("pp"),
        )
        return stage, model

    num_stages = len(module_names_per_stage)
    stages = []
    models = []

    schedule_class = get_schedule_class(pp_schedule)
    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"

    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
        module_names = module_names_per_stage[stage_idx]
        stage, model_chunk = _build_stage_from_modules(
            stage_idx,
            module_names,
            num_stages,
        )
        logger.info(
            f"PP rank {pp_rank} is building stage_idx {stage_idx} "
            f"with modules {module_names}"
        )
        stages.append(stage)
        models.append(model_chunk)

    return stages, models
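To see which stage indices land on which rank, here is a small sketch of the `stage_ids_this_rank` helper imported above. The placements shown in the comments are what the default looped assignment and ZBV's V-shaped assignment are expected to produce; treat them as an assumption rather than verified output:

```python
from torchtitan.distributed.pipeline import stage_ids_this_rank

# 8 virtual stages on a pp_size=4 mesh; "v" is selected above only for
# ScheduleZBVZeroBubble, "loop" for every other multi-stage schedule.
for style in ("loop", "v"):
    placement = {rank: stage_ids_this_rank(rank, 4, 8, style=style) for rank in range(4)}
    print(style, placement)
# expected:
# loop {0: (0, 4), 1: (1, 5), 2: (2, 6), 3: (3, 7)}
# v    {0: (0, 7), 1: (1, 6), 2: (2, 5), 3: (3, 4)}
```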

torchtitan/models/deepseek_v3/model/args.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
     n_limited_groups: int = 1
     score_func: Literal["softmax", "sigmoid"] = "softmax"
     route_scale: float = 1.0
-    use_grouped_mm: bool = True
+    use_grouped_mm: bool = False
     load_balance_coeff: float = 1e-3
     # Multi-Head Latent Attention (MLA)
     q_lora_rank: int = 0
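For context on what `use_grouped_mm` toggles: with grouped GEMMs off, MoE expert matmuls fall back to a per-expert loop. A self-contained sketch of the two equivalent computations; the shapes and names are illustrative rather than torchtitan's internals, and `torch.bmm` stands in for the fused grouped-GEMM kernel:

```python
import torch

num_experts, tokens_per_expert, d_in, d_out = 4, 8, 16, 32
x = torch.randn(num_experts, tokens_per_expert, d_in)  # tokens routed per expert
w = torch.randn(num_experts, d_in, d_out)              # one weight matrix per expert

# use_grouped_mm=False: a Python loop of ordinary matmuls, one per expert
looped = torch.stack([x[e] @ w[e] for e in range(num_experts)])

# use_grouped_mm=True: conceptually a single batched kernel over all experts
grouped = torch.bmm(x, w)

torch.testing.assert_close(looped, grouped)
```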
