Commit 4d2ad19

[DSV3] Add PP support for DSV3
1 parent 1760b9f commit 4d2ad19

File tree

5 files changed: +387, -9 lines changed


torchtitan/models/deepseek_v3/README.md

Lines changed: 29 additions & 0 deletions
````diff
@@ -4,3 +4,32 @@ Download tokenizer:
 # DeepSeek tokenizer (automatically downloads tokenizer.json and tokenizer_config.json)
 python scripts/download_tokenizer.py --repo_id deepseek-ai/DeepSeek-V3
 ```
+
+Run:
+
+Single GPU - debug_model:
+```
+NGPU=1 LOG_RANK=0 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh
+```
+
+FSDP:
+
+```
+NGPU=8 LOG_RANK=0 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --parallelism.data_parallel_shard_degree 8
+
+# OOM
+NGPU=8 LOG_RANK=0 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.data_parallel_shard_degree 8
+```
+
+PP:
+
+For additional logging, use: TORCH_LOGS=+pp
+
+```
+NGPU=2 LOG_RANK=0,1 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --parallelism.pipeline_parallel_degree 2
+
+NGPU=4 LOG_RANK=0,1,2,3 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --parallelism.pipeline_parallel_degree 4
+
+# works with AC=none, but why doesn't this work with AC=full?
+NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.pipeline_parallel_degree 8 --parallelism.pipeline_parallel_schedule Interleaved1F1B
+```
````

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -15,6 +15,7 @@
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

 from .infra.parallelize import parallelize_deepseekv3
+from .infra.pipeline import pipeline_deepseekv3
 from .model.args import DeepSeekV3ModelArgs
 from .model.model import DeepSeekV3Model

@@ -116,7 +117,7 @@
         cls=DeepSeekV3Model,
         config=deepseekv3_configs,
         parallelize_fn=parallelize_deepseekv3,
-        pipelining_fn=None,
+        pipelining_fn=pipeline_deepseekv3,
         build_optimizers_fn=build_llama4_optimizers,  # use optimizer hooks to update expert weights
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,
```
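The `pipelining_fn` registered here is what the training harness invokes when a pipeline degree is requested; the rest of the spec is unchanged. Below is a minimal, hypothetical call-site sketch, not the actual torchtitan trainer code: the surrounding variables (`train_spec`, `model`, `world_mesh`, `loss_fn`, ...) and the dispatch logic are assumptions, only the argument list comes from the `pipeline_deepseekv3` signature added in this commit.

```python
def build_model_and_schedule(train_spec, model, world_mesh, parallel_dims,
                             job_config, device, model_config, loss_fn):
    """Hypothetical dispatch sketch -- not the actual torchtitan trainer code."""
    if parallel_dims.pp > 1:
        # Argument order mirrors the pipeline_deepseekv3 signature added below.
        return train_spec.pipelining_fn(
            model, world_mesh, parallel_dims, job_config,
            device, model_config, train_spec.parallelize_fn, loss_fn,
        )
    # Without PP, just apply the SPMD parallelisms to the whole model;
    # the single model part acts as both first and last "stage".
    model = train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
    return None, [model], True, True
```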
torchtitan/models/deepseek_v3/infra/pipeline.py (new file)

Lines changed: 349 additions & 0 deletions

@@ -0,0 +1,349 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies PT-D pipeline parallelism to the DeepSeek-V3 model.

import copy

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
    _PipelineSchedule,
    get_schedule_class,
    PipelineScheduleSingle,
    ScheduleZBVZeroBubble,
)

from torchtitan.components.loss import LossFunction
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.pipeline import build_pipeline_schedule, stage_ids_this_rank
from torchtitan.protocols.train_spec import DeviceType, ParallelizeFunction
from torchtitan.tools.logging import logger

from ..model.args import DeepSeekV3ModelArgs


def _pipeline_friendly_forward(self, tokens: torch.Tensor):
    """
    Pipeline-friendly forward pass for the DeepSeekV3 model.
    This method is only used when pipeline parallelism is enabled.
    If model attributes are None, they are skipped in the forward pass.

    Args:
        tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).

    Returns:
        torch.Tensor: Logits tensor of shape (batch_size, seq_len, vocab_size) on the
            last stage, or intermediate hidden states on earlier stages.
    """
    h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens
    # h: (batch_size, seq_len, dim)
    for layer in self.layers.values():
        h = layer(h, self.freqs_cis)
    h = self.norm(h) if self.norm is not None else h
    output = self.output(h) if self.output is not None else h
    return output


def _patch_model_for_pipeline(model: nn.Module):
    """
    Patches the model's forward method to be pipeline-friendly.
    This only affects models used in pipeline parallelism.

    Args:
        model: The model to patch
    """
    # Store the original forward method
    if not hasattr(model, "_original_forward"):
        model._original_forward = model.forward
        # Replace with pipeline-friendly version
        model.forward = _pipeline_friendly_forward.__get__(model, model.__class__)
```
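The `__get__` call above is plain descriptor protocol: binding a free function to one instance so it becomes a bound method of that object only, leaving the class untouched. A standalone sketch of the same trick, using a toy `Counter` class that is purely illustrative:

```python
# Illustration of the forward-patching trick: function.__get__(obj, type(obj))
# returns a bound method, so only this instance is patched, not the class.
class Counter:
    def __init__(self) -> None:
        self.n = 0

    def bump(self) -> int:
        self.n += 1
        return self.n


def double_bump(self) -> int:
    # Free function that will be bound to one Counter instance.
    self.n += 2
    return self.n


c = Counter()
c._original_bump = c.bump                 # keep a handle to the original, as above
c.bump = double_bump.__get__(c, Counter)  # bind to this instance only

assert c.bump() == 2           # patched instance uses double_bump
assert Counter().bump() == 1   # other instances are unaffected
```

(The new pipeline.py file continues below.)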
```python
def generate_module_names_per_stage(
    num_stages: int,
    num_layers: int,
    input_weight: int = 1,
    output_weight: int = 1,
) -> list[list[str]]:
    """
    Programmatically generates module names per stage for pipeline parallelism with weighting.

    Args:
        num_stages: Number of pipeline stages
        num_layers: Total number of transformer layers in the model
        input_weight: Weight for input modules (tok_embeddings) in layer calculation
        output_weight: Weight for output modules (norm + output) in layer calculation

    Returns:
        List of lists containing module names for each stage

    Example:
        generate_module_names_per_stage(2, 3, input_weight=2, output_weight=2)
        treats embeddings as 2 layers and norm+output as 2 layers for distribution
    """
    if num_stages < 1:
        raise ValueError("Number of stages must be at least 1")

    if num_stages == 1:
        # Single stage gets everything
        layer_names = [f"layers.{i}" for i in range(num_layers)]
        return [["tok_embeddings"] + layer_names + ["norm", "output"]]

    # Calculate effective layers including weights
    num_effective_layers = num_layers + input_weight + output_weight

    if num_stages > num_effective_layers:
        raise ValueError(
            f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})"
        )

    # Calculate layers per stage (distribute evenly)
    layers_per_stage = num_effective_layers // num_stages
    extra_layers = num_effective_layers % num_stages

    # Ensure each stage gets at least the weight of input/output modules
    if layers_per_stage < max(input_weight, output_weight):
        raise ValueError(
            f"Layers per stage ({layers_per_stage}) must be >= max(input_weight={input_weight}, output_weight={output_weight})"
        )

    module_names_per_stage = []
    current_layer = 0

    for stage_idx in range(num_stages):
        stage_modules = []

        # Calculate effective layers for this stage
        effective_layers_for_stage = layers_per_stage
        if stage_idx < extra_layers:
            effective_layers_for_stage += 1

        # First stage: handle input modules with weighting
        if stage_idx == 0:
            stage_modules.append("tok_embeddings")
            # Account for input weight in layer distribution
            remaining_layers_for_stage = effective_layers_for_stage - input_weight

            # Add transformer layers
            for _ in range(remaining_layers_for_stage):
                if current_layer < num_layers:
                    stage_modules.append(f"layers.{current_layer}")
                    current_layer += 1

        # Last stage: handle output modules with weighting
        elif stage_idx == num_stages - 1:
            # Account for output weight in layer distribution
            remaining_layers_for_stage = effective_layers_for_stage - output_weight

            # Add transformer layers
            for _ in range(remaining_layers_for_stage):
                if current_layer < num_layers:
                    stage_modules.append(f"layers.{current_layer}")
                    current_layer += 1

            # Add output modules
            stage_modules.extend(["norm", "output"])

        # Middle stages: only transformer layers
        else:
            for _ in range(effective_layers_for_stage):
                if current_layer < num_layers:
                    stage_modules.append(f"layers.{current_layer}")
                    current_layer += 1

        module_names_per_stage.append(stage_modules)

    return module_names_per_stage
```
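To make the weighting concrete, here is the split this helper produces for 4 virtual stages and 6 transformer layers with the default weights, traced directly from the code above: 6 layers + input_weight 1 + output_weight 1 = 8 effective layers, distributed evenly as 2 per stage.

```python
# Worked example for generate_module_names_per_stage, traced from the
# implementation above (4 stages, 6 layers, default weights of 1).
splits = generate_module_names_per_stage(num_stages=4, num_layers=6)
assert splits == [
    ["tok_embeddings", "layers.0"],  # stage 0: embeddings count as one "layer"
    ["layers.1", "layers.2"],        # stage 1
    ["layers.3", "layers.4"],        # stage 2
    ["layers.5", "norm", "output"],  # stage 3: norm+output count as one "layer"
]
```

(The new pipeline.py file continues below.)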
```python
def pipeline_deepseekv3(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: DeepSeekV3ModelArgs,
    parallelize_fn: ParallelizeFunction,
    loss_fn: LossFunction,
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
    pp_mesh = world_mesh["pp"]

    # Determine the number of virtual stages based on schedule type
    schedule_class = get_schedule_class(
        job_config.parallelism.pipeline_parallel_schedule
    )
    is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)

    # For multi-stage schedules, default is 2 virtual stages per rank
    # For single-stage schedules, default is 1 virtual stage per rank
    stages_per_rank = 1 if is_single_stage_schedule else 2
    num_virtual_stages = parallel_dims.pp * stages_per_rank

    # Generate module names per stage programmatically with weighting
    num_layers = model_config.n_layers

    # You can adjust these weights based on the computational cost of embeddings
    # and output layers. Higher weights mean these modules are treated as
    # "heavier" in the distribution.
    input_weight = 1  # Weight for tok_embeddings
    output_weight = 1  # Weight for norm + output layers

    module_names_per_stage = generate_module_names_per_stage(
        num_virtual_stages, num_layers, input_weight, output_weight
    )
    for i, stage_ms in enumerate(module_names_per_stage):
        logger.info(f"Stage {i}: {stage_ms}")

    stages, model_parts = pipeline_module_split(
        model,
        pp_mesh,
        job_config.parallelism.pipeline_parallel_schedule,
        device,
        module_names_per_stage,
    )

    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
    # optimizer, and checkpointing.
    for i, m in enumerate(model_parts):
        # apply SPMD-style PT-D techniques
        m = parallelize_fn(m, world_mesh, parallel_dims, job_config)
        model_parts[i] = m
        # NOTE: this is to update the model in the stage
        # in case the model is modified e.g. by torch.compile
        stages[i].submod = m

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # This is used in the train loop to determine whether to pass in the input_ids and labels
    has_first_stage = False
    has_last_stage = False
    for stage in stages:
        if stage.is_first:
            has_first_stage = True
        if stage.is_last:
            has_last_stage = True

    return pp_schedule, model_parts, has_first_stage, has_last_stage
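```
The `has_first_stage` / `has_last_stage` flags returned here are consumed when driving the schedule: only the first stage feeds inputs, only the last stage receives targets and collects per-microbatch losses. A minimal sketch of one training step, assuming the standard `torch.distributed.pipelining` `step(*args, target=..., losses=...)` API; the surrounding wiring of `inputs`, `labels`, and the loss reduction is illustrative, not the actual torchtitan loop.

```python
import torch


def pp_train_step(pp_schedule, inputs, labels, has_first_stage, has_last_stage):
    """Hypothetical single training step driving the returned schedule."""
    # Only the last stage computes the loss, so only it gets targets/losses.
    targets, losses = (labels, []) if has_last_stage else (None, None)
    if has_first_stage:
        pp_schedule.step(inputs, target=targets, losses=losses)
    else:
        pp_schedule.step(target=targets, losses=losses)
    # Reduce per-microbatch losses on the last stage; other ranks report nothing.
    return torch.mean(torch.stack(losses)) if has_last_stage else None
```

(The new pipeline.py file continues below.)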
```python
def pipeline_module_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    pp_schedule: str,
    device: DeviceType,
    module_names_per_stage: list[list[str]],
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """
    This API creates pipeline stages based on specified module names for each stage.

    Args:
        whole_model: The complete model to be split
        pp_mesh: Pipeline parallel device mesh
        pp_schedule: Name of pipeline parallelism schedule
        device: Device type
        module_names_per_stage: List of lists, where each inner list contains the module
            names that should be included in that stage. Module names should be
            dot-separated paths. Examples:
            - "tok_embeddings" for token embeddings
            - "layers.0", "layers.1" for specific transformer layers
            - "norm" for the final normalization layer
            - "output" for the output projection layer

    Returns:
        Tuple of (stages, models) where stages are PipelineStage objects and models are the
        corresponding model chunks

    Example usage:
        module_names_per_stage = [
            ["tok_embeddings", "layers.0"],  # Stage 0: embeddings + first layer
            ["layers.1", "layers.2"],        # Stage 1: middle layers
            ["norm", "output"],              # Stage 2: final norm + output
        ]
    """
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()

    def _build_stage_from_modules(
        stage_idx: int, module_names: list[str], num_stages: int
    ) -> tuple[PipelineStage, nn.Module]:
        model = copy.deepcopy(whole_model)
        # Patch the model to use pipeline-friendly forward method
        _patch_model_for_pipeline(model)

        # Create a set of modules to keep for faster lookup
        modules_to_keep = set(module_names)
        print(f"Stage {stage_idx}: Modules to keep: {modules_to_keep}")
        for module_name, module_value in model.named_children():
            # Handle layer-like structures (e.g., "layers.0", "layers.1")
            if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
                layers_to_keep = {
                    name.split(".", 1)[1]
                    for name in modules_to_keep
                    if name.startswith(f"{module_name}.")
                }
                if layers_to_keep:
                    # Keep only specified layers
                    if isinstance(module_value, nn.ModuleDict):
                        for layer_name in list(module_value.keys()):
                            if layer_name not in layers_to_keep:
                                del module_value[layer_name]
                    elif isinstance(module_value, nn.ModuleList):
                        indices_to_keep = {
                            int(idx) for idx in layers_to_keep if idx.isdigit()
                        }
                        new_layers = nn.ModuleList(
                            [
                                layer
                                for i, layer in enumerate(module_value)
                                if i in indices_to_keep
                            ]
                        )
                        setattr(model, module_name, new_layers)
                else:
                    # No layers from this structure needed, set to empty structure
                    if isinstance(module_value, nn.ModuleDict):
                        setattr(model, module_name, nn.ModuleDict())
                    elif isinstance(module_value, nn.ModuleList):
                        setattr(model, module_name, nn.ModuleList())
            # Handle simple module attributes (e.g., "linear", "norm")
            elif module_name not in modules_to_keep:
                setattr(model, module_name, None)

        stage = PipelineStage(
            model,
            stage_idx,
            num_stages,
            device,
            group=pp_mesh.get_group("pp"),
        )
        return stage, model

    num_stages = len(module_names_per_stage)
    stages = []
    models = []

    schedule_class = get_schedule_class(pp_schedule)
    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"

    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
        module_names = module_names_per_stage[stage_idx]
        stage, model_chunk = _build_stage_from_modules(
            stage_idx,
            module_names,
            num_stages,
        )
        logger.info(
            f"PP rank {pp_rank} is building stage_idx {stage_idx} "
            f"with modules {module_names}"
        )
        stages.append(stage)
        models.append(model_chunk)

    return stages, models
```
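The `style` choice only affects which virtual stages land on which rank. A small sketch for 8 virtual stages on 4 PP ranks; the mappings in the comments assume the conventional looped vs. V-shaped assignment behind `stage_ids_this_rank` and are illustrative rather than taken from this commit.

```python
from torchtitan.distributed.pipeline import stage_ids_this_rank

# Assumed placement for pp_size = 4, num_stages = 8:
#   style="loop" (e.g. Interleaved1F1B):   style="v" (ScheduleZBVZeroBubble):
#     rank 0 -> stages (0, 4)                rank 0 -> stages (0, 7)
#     rank 1 -> stages (1, 5)                rank 1 -> stages (1, 6)
#     rank 2 -> stages (2, 6)                rank 2 -> stages (2, 5)
#     rank 3 -> stages (3, 7)                rank 3 -> stages (3, 4)
for pp_rank in range(4):
    print("loop:", stage_ids_this_rank(pp_rank, 4, 8, style="loop"))
    print("v:   ", stage_ids_this_rank(pp_rank, 4, 8, style="v"))
```

The V-shaped placement puts the first and last stages on the same rank, which is what the zero-bubble ZBV schedule relies on.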
