Skip to content

[DSV3] Add PP support for DSV3 #1345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchtitan/models/deepseek_v3/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml
- Activation checkpointing
- Tensor Parallel (TP)
- Expert Parallel (EP)
- Pipeline Parallel (PP)


## To be added
Expand All @@ -46,7 +47,6 @@ CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml
- Attention Layer: need to pass softmax_scale to sdpa() to support scaling
- Parallelism
- Context Parallel support for DeepSeek-V3
- PP support for DeepSeek-V3
- torch.compile
- Quantization
- Testing
Expand Down
3 changes: 2 additions & 1 deletion torchtitan/models/deepseek_v3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .infra.parallelize import parallelize_deepseekv3
from .infra.pipeline import pipeline_deepseekv3
from .model.args import DeepSeekV3ModelArgs
from .model.model import DeepSeekV3Model

Expand Down Expand Up @@ -116,7 +117,7 @@
cls=DeepSeekV3Model,
config=deepseekv3_configs,
parallelize_fn=parallelize_deepseekv3,
pipelining_fn=None,
pipelining_fn=pipeline_deepseekv3,
build_optimizers_fn=build_llama4_optimizers, # use optimizer hooks to update expert weights
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
Expand Down
310 changes: 310 additions & 0 deletions torchtitan/models/deepseek_v3/infra/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,310 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D pipeline parallelism to the Llama model.

import copy

import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
_PipelineSchedule,
get_schedule_class,
PipelineScheduleSingle,
ScheduleZBVZeroBubble,
)

from torchtitan.components.loss import LossFunction
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.pipeline import build_pipeline_schedule, stage_ids_this_rank
from torchtitan.protocols.train_spec import DeviceType, ParallelizeFunction
from torchtitan.tools.logging import logger

from ..model.args import DeepSeekV3ModelArgs


def generate_module_names_per_stage(
num_stages: int,
num_layers: int,
input_weight: int = 1,
output_weight: int = 1,
) -> list[list[str]]:
"""
Programmatically generates module names per stage for pipeline parallelism with weighting.
Args:
num_stages: Number of pipeline stages
num_layers: Total number of transformer layers in the model
input_weight: Weight for input modules (tok_embeddings) in layer calculation
output_weight: Weight for output modules (norm + output) in layer calculation
Returns:
List of lists containing module names for each stage
Example:
generate_module_names_per_stage(2, 3, input_weight=2, output_weight=2)
treats embeddings as 2 layers and norm+output as 2 layers for distribution
"""
if num_stages < 1:
raise ValueError("Number of stages must be at least 1")

if num_stages == 1:
# Single stage gets everything
layer_names = [f"layers.{i}" for i in range(num_layers)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did you take into consideration that the layers in dsv3 are not evenly distributed -- several dense layers, followed by MoE layers
https://github.com/pytorch/torchtitan/pull/1373/files#diff-ed005d894ae945a545c92c33136fba3bde35e70f1b7052f78242a1f69e862ab8R273

Copy link
Member Author

@H-Huang H-Huang Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't take this into consideration yet 🫤.

return [["tok_embeddings"] + layer_names + ["norm", "output"]]

# Calculate effective layers including weights
num_effective_layers = num_layers + input_weight + output_weight

if num_stages > num_effective_layers:
raise ValueError(
f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})"
)

# Calculate layers per stage (distribute evenly)
layers_per_stage = num_effective_layers // num_stages
extra_layers = num_effective_layers % num_stages

# Ensure each stage gets at least the weight of input/output modules
if layers_per_stage < max(input_weight, output_weight):
raise ValueError(
f"Layers per stage ({layers_per_stage}) must be >= max(input_weight={input_weight}, output_weight={output_weight})"
)

module_names_per_stage = []
current_layer = 0

for stage_idx in range(num_stages):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems the LoC is quite longer than https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/pipeline.py#L106-L121

as you are not generating the split points but generating actual module names per stage. (I'm assuming function-wise they are the same but I didn't check.)
But later on the complexity in pipeline_deepseekv3_module_split doesn't seem to be saved.

I'd like to learn more about the reasonings behind the change of UI.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Module names are more flexible since they can be applied for all models. generate_module_names_per_stage is specific to deepseek-v3; however, the new helper method i refactored pipeline_module_split is model agnostic. So I am thinking of upstreaming this as a utility in pytorch core. With that, we can reduce the LoC needed in pipeline.py

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense, but can user still manually specify the split via https://github.com/pytorch/torchtitan/blob/main/torchtitan/config_manager.py#L313 or is it not encouraged any more?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also maybe we should try to upstream more functions in https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/pipeline.py

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense, but can user still manually specify the split via https://github.com/pytorch/torchtitan/blob/main/torchtitan/config_manager.py#L313 or is it not encouraged any more?

I think we should move away from this

also maybe we should try to upstream more functions in https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/pipeline.py

Yeah I agree, i think this can be cleaned up, let me think of a way

stage_modules = []

# Calculate effective layers for this stage
effective_layers_for_stage = layers_per_stage
if stage_idx < extra_layers:
effective_layers_for_stage += 1

# First stage: handle input modules with weighting
if stage_idx == 0:
stage_modules.append("tok_embeddings")
# Account for input weight in layer distribution
remaining_layers_for_stage = effective_layers_for_stage - input_weight

# Add transformer layers
for _ in range(remaining_layers_for_stage):
if current_layer < num_layers:
stage_modules.append(f"layers.{current_layer}")
current_layer += 1

# Last stage: handle output modules with weighting
elif stage_idx == num_stages - 1:
# Account for output weight in layer distribution
remaining_layers_for_stage = effective_layers_for_stage - output_weight

# Add transformer layers
for _ in range(remaining_layers_for_stage):
if current_layer < num_layers:
stage_modules.append(f"layers.{current_layer}")
current_layer += 1

# Add output modules
stage_modules.extend(["norm", "output"])

# Middle stages: only transformer layers
else:
for _ in range(effective_layers_for_stage):
if current_layer < num_layers:
stage_modules.append(f"layers.{current_layer}")
current_layer += 1

module_names_per_stage.append(stage_modules)

return module_names_per_stage


def pipeline_deepseekv3(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: ParallelDims,
job_config: JobConfig,
device: DeviceType,
model_config: DeepSeekV3ModelArgs,
parallelize_fn: ParallelizeFunction,
loss_fn: LossFunction,
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
pp_mesh = world_mesh["pp"]

# Determine the number of virtual stages based on schedule type
schedule_class = get_schedule_class(
job_config.parallelism.pipeline_parallel_schedule
)
is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)

# For multi-stage schedules, default is 2 virtual stages per rank
# For single-stage schedules, default is 1 virtual stage per rank
stages_per_rank = 1 if is_single_stage_schedule else 2
num_virtual_stages = parallel_dims.pp * stages_per_rank

# Generate module names per stage programmatically with weighting
num_layers = model_config.n_layers

# You can adjust these weights based on the computational cost of embeddings and output layers
# Higher weights mean these modules are treated as "heavier" in the distribution
input_weight = 1 # Weight for tok_embeddings
output_weight = 1 # Weight for norm + output layers

module_names_per_stage = generate_module_names_per_stage(
num_virtual_stages, num_layers, input_weight, output_weight
)
for i, stage_ms in enumerate(module_names_per_stage):
logger.info(f"Stage {i}: {stage_ms}")

stages, model_parts = pipeline_module_split(
model,
pp_mesh,
job_config.parallelism.pipeline_parallel_schedule,
device,
module_names_per_stage,
)

# For PP with looped schedules, each item in model_parts is one stage-model-chunk.
# We need to iterate through model_parts to apply SPMD parallelisms, compilation,
# optimizer, and checkpointing
for i, m in enumerate(model_parts):
# apply SPMD-style PT-D techniques
m = parallelize_fn(m, world_mesh, parallel_dims, job_config)
model_parts[i] = m
# NOTE: this is to update the model in the stage
# in case the model is modified e.g. by torch.compile
stages[i].submod = m

pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

# This is used in the train loop to determine whether to pass in the input_ids and labels
has_first_stage = False
has_last_stage = False
for stage in stages:
if stage.is_first:
has_first_stage = True
if stage.is_last:
has_last_stage = True

return pp_schedule, model_parts, has_first_stage, has_last_stage


def pipeline_module_split(
whole_model: nn.Module,
pp_mesh: DeviceMesh,
pp_schedule: str,
device: DeviceType,
module_names_per_stage: list[list[str]],
) -> tuple[list[PipelineStage], list[nn.Module]]:
"""
This API creates pipeline stages based on specified module names for each stage.
Args:
whole_model: The complete model to be split
pp_mesh: Pipeline parallel device mesh
pp_schedule: Name of pipeline parallelism schedule
device: Device type
module_names_per_stage: List of lists, where each inner list contains the module names
that should be included in that stage. Module names should be
dot-separated paths. Examples:
- "tok_embeddings" for token embeddings
- "layers.0", "layers.1" for specific transformer layers
- "norm" for the final normalization layer
- "output" for the output projection layer
Returns:
Tuple of (stages, models) where stages are PipelineStage objects and models are the
corresponding model chunks
Example usage:
module_names_per_stage = [
["tok_embeddings", "layers.0"], # Stage 0: embeddings + first layer
["layers.1", "layers.2"], # Stage 1: middle layers
["norm", "output"] # Stage 2: final norm + output
]
"""
pp_rank = pp_mesh.get_local_rank()
pp_size = pp_mesh.size()

def _build_stage_from_modules(
stage_idx: int, module_names: list[str], num_stages: int
) -> tuple[PipelineStage, nn.Module]:
model = copy.deepcopy(whole_model)

# Create a set of modules to keep for faster lookup
modules_to_keep = set(module_names)
print(f"Stage {stage_idx}: Modules to keep: {modules_to_keep}")
for module_name, module_value in model.named_children():
# Handle layer-like structures (e.g., "layers.0", "layers.1")
if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
layers_to_keep = {
name.split(".", 1)[1]
for name in modules_to_keep
if name.startswith(f"{module_name}.")
}
if layers_to_keep:
# Keep only specified layers
if isinstance(module_value, nn.ModuleDict):
for layer_name in list(module_value.keys()):
if layer_name not in layers_to_keep:
del module_value[layer_name]
elif isinstance(module_value, nn.ModuleList):
indices_to_keep = {
int(idx) for idx in layers_to_keep if idx.isdigit()
}
new_layers = nn.ModuleList(
[
layer
for i, layer in enumerate(module_value)
if i in indices_to_keep
]
)
setattr(model, module_name, new_layers)
else:
# No layers from this structure needed, set to empty structure
if isinstance(module_value, nn.ModuleDict):
setattr(model, module_name, nn.ModuleDict())
elif isinstance(module_value, nn.ModuleList):
setattr(model, module_name, nn.ModuleList())
# Handle simple module attributes (e.g., "linear", "norm")
elif module_name not in modules_to_keep:
# Replace with identity module instead of None
setattr(model, module_name, nn.Identity())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the advantage of doing this instead of setting to None?
I'm worried about DCP loading / resharding, as multiple PP ranks will have the same fqns.


stage = PipelineStage(
model,
stage_idx,
num_stages,
device,
group=pp_mesh.get_group("pp"),
)
return stage, model

num_stages = len(module_names_per_stage)
stages = []
models = []

schedule_class = get_schedule_class(pp_schedule)
style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"

for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
module_names = module_names_per_stage[stage_idx]
stage, model_chunk = _build_stage_from_modules(
stage_idx,
module_names,
num_stages,
)
logger.info(
f"PP rank {pp_rank} is building stage_idx {stage_idx} "
f"with modules {module_names}"
)
stages.append(stage)
models.append(model_chunk)

return stages, models
Loading
Loading