diff --git a/torchtitan/models/deepseek_v3/README.md b/torchtitan/models/deepseek_v3/README.md
index 107bd0481..367e4e941 100644
--- a/torchtitan/models/deepseek_v3/README.md
+++ b/torchtitan/models/deepseek_v3/README.md
@@ -38,6 +38,7 @@ CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml
 - Activation checkpointing
 - Tensor Parallel (TP)
 - Expert Parallel (EP)
+- Pipeline Parallel (PP)
 
 ## To be added
 
@@ -46,7 +47,6 @@ CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml
 - Attention Layer: need to pass softmax_scale to sdpa() to support scaling
 - Parallelism
   - Context Parallel support for DeepSeek-V3
-  - PP support for DeepSeek-V3
 - torch.compile
 - Quantization
 - Testing
diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index 141b740ce..de2d26b8a 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -15,6 +15,7 @@ from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
 
 from .infra.parallelize import parallelize_deepseekv3
+from .infra.pipeline import pipeline_deepseekv3
 from .model.args import DeepSeekV3ModelArgs
 from .model.model import DeepSeekV3Model
 
@@ -116,7 +117,7 @@
         model_cls=DeepSeekV3Model,
         model_args=deepseekv3_configs,
         parallelize_fn=parallelize_deepseekv3,
-        pipelining_fn=None,
+        pipelining_fn=pipeline_deepseekv3,
         build_optimizers_fn=build_llama4_optimizers,  # use optimizer hooks to update expert weights
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,
diff --git a/torchtitan/models/deepseek_v3/infra/pipeline.py b/torchtitan/models/deepseek_v3/infra/pipeline.py
new file mode 100644
index 000000000..7caf3ad81
--- /dev/null
+++ b/torchtitan/models/deepseek_v3/infra/pipeline.py
@@ -0,0 +1,310 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This file applies the PT-D pipeline parallelism to the DeepSeek-V3 model.
+
+import copy
+
+import torch
+import torch.nn as nn
+from torch.distributed import DeviceMesh
+from torch.distributed.pipelining import PipelineStage
+from torch.distributed.pipelining.schedules import (
+    _PipelineSchedule,
+    get_schedule_class,
+    PipelineScheduleSingle,
+    ScheduleZBVZeroBubble,
+)
+
+from torchtitan.components.loss import LossFunction
+from torchtitan.config_manager import JobConfig
+from torchtitan.distributed import ParallelDims
+from torchtitan.distributed.pipeline import build_pipeline_schedule, stage_ids_this_rank
+from torchtitan.protocols.train_spec import ParallelizeFunction
+from torchtitan.tools.logging import logger
+
+from ..model.args import DeepSeekV3ModelArgs
+
+
+def generate_module_names_per_stage(
+    num_stages: int,
+    num_layers: int,
+    input_weight: int = 1,
+    output_weight: int = 1,
+) -> list[list[str]]:
+    """
+    Programmatically generates module names per stage for pipeline parallelism with weighting.
+
+    Args:
+        num_stages: Number of pipeline stages
+        num_layers: Total number of transformer layers in the model
+        input_weight: Weight for input modules (tok_embeddings) in layer calculation
+        output_weight: Weight for output modules (norm + output) in layer calculation
+
+    Returns:
+        List of lists containing module names for each stage
+
+    Example:
+        generate_module_names_per_stage(2, 3, input_weight=2, output_weight=2)
+        treats embeddings as 2 layers and norm+output as 2 layers for distribution
+    """
+    if num_stages < 1:
+        raise ValueError("Number of stages must be at least 1")
+
+    if num_stages == 1:
+        # Single stage gets everything
+        layer_names = [f"layers.{i}" for i in range(num_layers)]
+        return [["tok_embeddings"] + layer_names + ["norm", "output"]]
+
+    # Calculate effective layers including weights
+    num_effective_layers = num_layers + input_weight + output_weight
+
+    if num_stages > num_effective_layers:
+        raise ValueError(
+            f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})"
+        )
+
+    # Calculate layers per stage (distribute evenly)
+    layers_per_stage = num_effective_layers // num_stages
+    extra_layers = num_effective_layers % num_stages
+
+    # Ensure each stage gets at least the weight of input/output modules
+    if layers_per_stage < max(input_weight, output_weight):
+        raise ValueError(
+            f"Layers per stage ({layers_per_stage}) must be >= max(input_weight={input_weight}, output_weight={output_weight})"
+        )
+
+    module_names_per_stage = []
+    current_layer = 0
+
+    for stage_idx in range(num_stages):
+        stage_modules = []
+
+        # Calculate effective layers for this stage
+        effective_layers_for_stage = layers_per_stage
+        if stage_idx < extra_layers:
+            effective_layers_for_stage += 1
+
+        # First stage: handle input modules with weighting
+        if stage_idx == 0:
+            stage_modules.append("tok_embeddings")
+            # Account for input weight in layer distribution
+            remaining_layers_for_stage = effective_layers_for_stage - input_weight
+
+            # Add transformer layers
+            for _ in range(remaining_layers_for_stage):
+                if current_layer < num_layers:
+                    stage_modules.append(f"layers.{current_layer}")
+                    current_layer += 1
+
+        # Last stage: handle output modules with weighting
+        elif stage_idx == num_stages - 1:
+            # Account for output weight in layer distribution
+            remaining_layers_for_stage = effective_layers_for_stage - output_weight
+
+            # Add transformer layers
+            for _ in range(remaining_layers_for_stage):
+                if current_layer < num_layers:
+                    stage_modules.append(f"layers.{current_layer}")
+                    current_layer += 1
+
+            # Add output modules
+            stage_modules.extend(["norm", "output"])
+
+        # Middle stages: only transformer layers
+        else:
+            for _ in range(effective_layers_for_stage):
+                if current_layer < num_layers:
+                    stage_modules.append(f"layers.{current_layer}")
+                    current_layer += 1
+
+        module_names_per_stage.append(stage_modules)
+
+    return module_names_per_stage
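+
+
+# Illustrative example (hand-traced from the logic above, not part of the original API
+# contract): the docstring call
+#     generate_module_names_per_stage(2, 3, input_weight=2, output_weight=2)
+# sees 3 + 2 + 2 = 7 effective layers across 2 stages and is expected to return
+#     [["tok_embeddings", "layers.0", "layers.1"],
+#      ["layers.2", "norm", "output"]]
+# i.e. stage 0 is "charged" 2 effective layers for the embedding and stage 1 is
+# charged 2 for norm + output.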
+
+
+def pipeline_deepseekv3(
+    model: nn.Module,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+    device: torch.device,
+    model_config: DeepSeekV3ModelArgs,
+    parallelize_fn: ParallelizeFunction,
+    loss_fn: LossFunction,
+) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
+    pp_mesh = parallel_dims.world_mesh["pp"]
+
+    # Determine the number of virtual stages based on schedule type
+    schedule_class = get_schedule_class(
+        job_config.parallelism.pipeline_parallel_schedule
+    )
+    is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
+
+    # For multi-stage schedules, default is 2 virtual stages per rank
+    # For single-stage schedules, default is 1 virtual stage per rank
+    stages_per_rank = 1 if is_single_stage_schedule else 2
+    num_virtual_stages = parallel_dims.pp * stages_per_rank
+
+    # Generate module names per stage programmatically with weighting
+    num_layers = model_config.n_layers
+
+    # You can adjust these weights based on the computational cost of embeddings and output layers
+    # Higher weights mean these modules are treated as "heavier" in the distribution
+    input_weight = 1  # Weight for tok_embeddings
+    output_weight = 1  # Weight for norm + output layers
+
+    module_names_per_stage = generate_module_names_per_stage(
+        num_virtual_stages, num_layers, input_weight, output_weight
+    )
+    for i, stage_ms in enumerate(module_names_per_stage):
+        logger.info(f"Stage {i}: {stage_ms}")
+
+    stages, model_parts = pipeline_module_split(
+        model,
+        pp_mesh,
+        job_config.parallelism.pipeline_parallel_schedule,
+        device,
+        module_names_per_stage,
+    )
+
+    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
+    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
+    # optimizer, and checkpointing
+    for i, m in enumerate(model_parts):
+        # apply SPMD-style PT-D techniques
+        m = parallelize_fn(m, parallel_dims, job_config)
+        model_parts[i] = m
+        # NOTE: this is to update the model in the stage
+        # in case the model is modified e.g. by torch.compile
+        stages[i].submod = m
+
+    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
+
+    # This is used in the train loop to determine whether to pass in the input_ids and labels
+    has_first_stage = False
+    has_last_stage = False
+    for stage in stages:
+        if stage.is_first:
+            has_first_stage = True
+        if stage.is_last:
+            has_last_stage = True
+
+    return pp_schedule, model_parts, has_first_stage, has_last_stage
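+
+
+# Worked example of the virtual-stage math above (assumed configuration, not taken from
+# this PR): with pipeline_parallel_degree = 4 and the multi-stage "Interleaved1F1B"
+# schedule, stages_per_rank defaults to 2, so the model is split into 4 * 2 = 8 virtual
+# stages; with the single-stage "1F1B" schedule it stays at 4 * 1 = 4.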
+
+
+def pipeline_module_split(
+    whole_model: nn.Module,
+    pp_mesh: DeviceMesh,
+    pp_schedule: str,
+    device: torch.device,
+    module_names_per_stage: list[list[str]],
+) -> tuple[list[PipelineStage], list[nn.Module]]:
+    """
+    This API creates pipeline stages based on specified module names for each stage.
+
+    Args:
+        whole_model: The complete model to be split
+        pp_mesh: Pipeline parallel device mesh
+        pp_schedule: Name of pipeline parallelism schedule
+        device: Device type
+        module_names_per_stage: List of lists, where each inner list contains the module names
+            that should be included in that stage. Module names should be
+            dot-separated paths. Examples:
+            - "tok_embeddings" for token embeddings
+            - "layers.0", "layers.1" for specific transformer layers
+            - "norm" for the final normalization layer
+            - "output" for the output projection layer
+
+    Returns:
+        Tuple of (stages, models) where stages are PipelineStage objects and models are the
+        corresponding model chunks
+
+    Example usage:
+        module_names_per_stage = [
+            ["tok_embeddings", "layers.0"],  # Stage 0: embeddings + first layer
+            ["layers.1", "layers.2"],        # Stage 1: middle layers
+            ["norm", "output"],              # Stage 2: final norm + output
+        ]
+    """
+    pp_rank = pp_mesh.get_local_rank()
+    pp_size = pp_mesh.size()
+
+    def _build_stage_from_modules(
+        stage_idx: int, module_names: list[str], num_stages: int
+    ) -> tuple[PipelineStage, nn.Module]:
+        model = copy.deepcopy(whole_model)
+
+        # Create a set of modules to keep for faster lookup
+        modules_to_keep = set(module_names)
+        logger.debug(f"Stage {stage_idx}: Modules to keep: {modules_to_keep}")
+        for module_name, module_value in model.named_children():
+            # Handle layer-like structures (e.g., "layers.0", "layers.1")
+            if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
+                layers_to_keep = {
+                    name.split(".", 1)[1]
+                    for name in modules_to_keep
+                    if name.startswith(f"{module_name}.")
+                }
+                if layers_to_keep:
+                    # Keep only specified layers
+                    if isinstance(module_value, nn.ModuleDict):
+                        for layer_name in list(module_value.keys()):
+                            if layer_name not in layers_to_keep:
+                                del module_value[layer_name]
+                    elif isinstance(module_value, nn.ModuleList):
+                        indices_to_keep = {
+                            int(idx) for idx in layers_to_keep if idx.isdigit()
+                        }
+                        new_layers = nn.ModuleList(
+                            [
+                                layer
+                                for i, layer in enumerate(module_value)
+                                if i in indices_to_keep
+                            ]
+                        )
+                        setattr(model, module_name, new_layers)
+                else:
+                    # No layers from this structure needed, set to empty structure
+                    if isinstance(module_value, nn.ModuleDict):
+                        setattr(model, module_name, nn.ModuleDict())
+                    elif isinstance(module_value, nn.ModuleList):
+                        setattr(model, module_name, nn.ModuleList())
+            # Handle simple module attributes (e.g., "linear", "norm")
+            elif module_name not in modules_to_keep:
+                # Replace with None
+                setattr(model, module_name, None)
+
+        stage = PipelineStage(
+            model,
+            stage_idx,
+            num_stages,
+            device,
+            group=pp_mesh.get_group("pp"),
+        )
+        return stage, model
+
+    num_stages = len(module_names_per_stage)
+    stages = []
+    models = []
+
+    schedule_class = get_schedule_class(pp_schedule)
+    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
+
+    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
+        module_names = module_names_per_stage[stage_idx]
+        stage, model_chunk = _build_stage_from_modules(
+            stage_idx,
+            module_names,
+            num_stages,
+        )
+        logger.info(
+            f"PP rank {pp_rank} is building stage_idx {stage_idx} "
+            f"with modules {module_names}"
+        )
+        stages.append(stage)
+        models.append(model_chunk)
+
+    return stages, models
diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py
index 61034e4c7..9d7c336e6 100644
--- a/torchtitan/models/deepseek_v3/model/model.py
+++ b/torchtitan/models/deepseek_v3/model/model.py
@@ -9,7 +9,7 @@
 import torch
 from torch import nn
 
-from torchtitan.models.attention import build_attention
+from torchtitan.models.attention import build_attention, init_attention_mask
 from torchtitan.protocols.train_spec import ModelProtocol
 
 from .args import DeepSeekV3ModelArgs
@@ -357,20 +357,32 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:
                 b=cutoff_factor * final_out_std,
             )
 
-    def forward(self, tokens: torch.Tensor):
+    def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None):
         """
         Forward pass for the Transformer model.
 
         Args:
-            tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
+            tokens (torch.Tensor): Input token indices if pipeline parallelism is not enabled.
+                If pipeline parallelism is enabled, this will be the input token indices
+                for the ranks on the first pipeline stage. This will be the activation of the
+                previous pipeline stage if the current rank is not on the first stage.
+            input_batch (torch.Tensor): The input batch read from the dataloader.
+                This will always be the raw input batch regardless of the pipeline stage.
+                It is required on non-first PP stages so that document-masked attention
+                can still locate document boundaries in the token stream.
 
         Returns:
             torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
         """
-        h = self.tok_embeddings(tokens)
+        if self.model_args.use_flex_attn:
+            init_attention_mask(
+                input_batch if input_batch is not None else tokens, eos_id=self.eos_id
+            )
+
+        h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens
 
         for layer in self.layers.values():
             h = layer(h, self.freqs_cis)
-        h = self.norm(h)
-        output = self.output(h)
+        h = self.norm(h) if self.norm is not None else h
+        output = self.output(h) if self.output is not None else h
         return output
diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
index 905aa0067..5f66ff4c3 100644
--- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
@@ -50,9 +50,11 @@ dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default"  # default / never / always
-tensor_parallel_degree = 2
+tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
 expert_parallel_degree = 1
+pipeline_parallel_degree = 1
+pipeline_parallel_schedule = "1F1B"
 
 [checkpoint]
 enable_checkpoint = false
diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
index 84c6b5f6b..a9316b548 100644
--- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
@@ -51,6 +51,8 @@ fsdp_reshard_after_forward = "default"  # default / never / always
 tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
 expert_parallel_degree = 1
+pipeline_parallel_degree = 1
+pipeline_parallel_schedule = "Interleaved1F1B"
 
 [checkpoint]
 enable_checkpoint = false
diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml
index 26cb64fb7..b3722c08b 100644
--- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml
@@ -51,6 +51,8 @@ fsdp_reshard_after_forward = "default"  # default / never / always
 tensor_parallel_degree = 8
 enable_async_tensor_parallel = false
 expert_parallel_degree = 1
+pipeline_parallel_degree = 1
+pipeline_parallel_schedule = "Interleaved1F1B"
 
 [checkpoint]
 enable_checkpoint = false
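
As a quick sanity check of the stage layout produced by the new `generate_module_names_per_stage` helper, a minimal sketch like the one below can be run on a machine where this branch is importable. The 4-stage / 12-layer numbers are made up for illustration and do not correspond to any config in this PR.

```python
from torchtitan.models.deepseek_v3.infra.pipeline import generate_module_names_per_stage

# Hypothetical split: 4 virtual stages over a 12-layer model with the default weights.
for stage_idx, names in enumerate(
    generate_module_names_per_stage(num_stages=4, num_layers=12)
):
    print(f"stage {stage_idx}: {names}")

# Hand-traced expectation: stage 0 owns tok_embeddings plus layers 0-2, stage 1 owns
# layers 3-6, stage 2 owns layers 7-9, and stage 3 owns layers 10-11 plus norm and output.
```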