From c4e1798e946ef543569aa9e862317460d2b9f103 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 25 Jun 2025 22:51:48 -0700 Subject: [PATCH 01/12] dp2ep Expert Parallel --- docs/checkpoint.md | 2 +- scripts/estimate/estimation.py | 8 +- scripts/generate/test_generate.py | 1 + tests/unit_tests/test_model_converter.py | 1 + torchtitan/components/ft.py | 37 --- torchtitan/components/optimizer.py | 123 ++++++++- torchtitan/config_manager.py | 8 + torchtitan/distributed/parallel_dims.py | 104 +++++++- torchtitan/distributed/utils.py | 70 ++++- torchtitan/experiments/llama4/README.md | 2 +- torchtitan/experiments/llama4/__init__.py | 4 +- .../llama4/infra/expert_parallel.py | 249 ++++++++++++++---- .../experiments/llama4/infra/parallelize.py | 202 ++++++++++---- torchtitan/experiments/llama4/model/moe.py | 249 +++++++++--------- torchtitan/experiments/llama4/optimizer.py | 68 +++++ .../llama4/train_configs/debug_model.toml | 4 +- torchtitan/models/llama3/infra/parallelize.py | 5 +- torchtitan/protocols/train_spec.py | 5 +- torchtitan/train.py | 4 +- 19 files changed, 842 insertions(+), 304 deletions(-) create mode 100644 torchtitan/experiments/llama4/optimizer.py diff --git a/docs/checkpoint.md b/docs/checkpoint.md index 0dad44e67..5275db1a2 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -83,5 +83,5 @@ A seed checkpoint does initialization of the model on a single CPU, and can be l To create a seed checkpoint, use the same model config as you use for training. e.g. ```bash -NGPU=1 CONFIG= ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 +NGPU=1 CONFIG= ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1 ``` diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 9dac8f17b..33d4dc17f 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -46,6 +46,7 @@ def estimate_memory(job_config: JobConfig): cp=parallelism_config.context_parallel_degree, tp=parallelism_config.tensor_parallel_degree, pp=parallelism_config.pipeline_parallel_degree, + ep=parallelism_config.expert_parallel_degree, world_size=world_size, enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) @@ -56,8 +57,9 @@ def estimate_memory(job_config: JobConfig): or parallel_dims.tp_enabled or parallel_dims.pp_enabled or parallel_dims.cp_enabled + or parallel_dims.ep_enabled ): - logger.warning("DDP, TP, PP, CP are not supported yet.") + logger.warning("DDP, TP, PP, CP, EP are not supported yet.") return if not parallel_dims.dp_shard_enabled: logger.warning("FSDP or HSDP is not enabled. 
Skipping memory estimation.") @@ -115,7 +117,9 @@ def estimate_memory(job_config: JobConfig): # build optimizer after applying parallelisms to the model ft_manager = init_ft_manager(job_config) - optimizers = build_optimizers([model], job_config, ft_manager) + optimizers = build_optimizers( + [model], job_config, parallel_dims, world_mesh, ft_manager + ) lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config) # Post optimizer step model converters hook. # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2 diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index 636d10a51..112aacd40 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -125,6 +125,7 @@ def test_generate( cp=1, tp=world_size, pp=1, + ep=1, world_size=world_size, enable_loss_parallel=False, ) diff --git a/tests/unit_tests/test_model_converter.py b/tests/unit_tests/test_model_converter.py index bb5a1ecc4..704e81a91 100644 --- a/tests/unit_tests/test_model_converter.py +++ b/tests/unit_tests/test_model_converter.py @@ -21,6 +21,7 @@ def build_parallel_dims(job_config, world_size): cp=parallelism_config.context_parallel_degree, tp=parallelism_config.tensor_parallel_degree, pp=parallelism_config.pipeline_parallel_degree, + ep=parallelism_config.expert_parallel_degree, world_size=world_size, enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) diff --git a/torchtitan/components/ft.py b/torchtitan/components/ft.py index 11f13df11..946fc4638 100644 --- a/torchtitan/components/ft.py +++ b/torchtitan/components/ft.py @@ -7,7 +7,6 @@ import copy import importlib from contextlib import nullcontext -from dataclasses import dataclass from typing import ContextManager, Optional, TYPE_CHECKING, Union import torch @@ -18,7 +17,6 @@ from torch.distributed.distributed_c10d import ReduceOp from torch.distributed.tensor import DTensor from torchtitan.config_manager import JobConfig -from torchtitan.distributed import ParallelDims if importlib.util.find_spec("torchft") is not None: import torchft as ft @@ -106,41 +104,6 @@ def init_ft_manager(job: JobConfig) -> FTManager: ) -@dataclass -class FTParallelDims(ParallelDims): - ft_manager: FTManager - - def build_mesh(self, device_type: str) -> DeviceMesh: - def func( - device_type: str, mesh_shape: list[int], mesh_dim_names: list[str] - ) -> DeviceMesh: - from torchft.process_group import ft_init_device_mesh - - return ft_init_device_mesh( - device_type=device_type, - mesh_shape=mesh_shape, - mesh_dim_names=mesh_dim_names, - replicate_dim=mesh_dim_names.index("dp_replicate"), - manager=self.ft_manager.manager, - ) - - dims = [] - names = [] - for d, name in zip( - [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], - ["pp", "dp_replicate", "dp_shard", "cp", "tp"], - ): - if d > 1 or name == "dp_replicate": - dims.append(d) - names.append(name) - - return self._build_mesh(device_type, dims, names, func) - - @property - def dp_replicate_enabled(self): - return True - - def ft_dist_reduce( x: torch.Tensor, reduceOp: str, mesh: DeviceMesh ) -> tuple[torch.Tensor, str, DeviceMesh]: diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index a88f33fa1..171101361 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import functools +from itertools import chain from typing import Any, Generic, Iterator, TypeVar import torch @@ -15,10 +16,13 @@ StateDictOptions, ) from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor from torch.optim import Optimizer from torchtitan.components.ft import FTManager, has_torchft from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims __all__ = [ "OptimizersContainer", @@ -238,9 +242,85 @@ def zero_grad(self, *args, **kwargs) -> None: super().zero_grad(*args, **kwargs) +class ExpertParallelOptimizersContainer(OptimizersContainer): + """ + This class is created to support fused optimizer implementation for Expert Parallel. + Since in EP, not all the parameters are sharded on the same DeviceMesh, the base + OptimizersContainer cannot perform fused optimizer steps on all DTensor parameters. + In this class, we create two optimizers for each model part, one for ep params and the + other for non-ep params. Parameters in the same optimizer are always on the same DeviceMesh, + so that fused optimizer can be performed. + """ + + def __init__( + self, + model_parts: list[nn.Module], + optimizer_cls: type[T], + optimizer_kwargs: dict[str, Any], + dense_params_mesh_ndim: int, + ) -> None: + ep_params, non_ep_params = [], [] + self.ep_optimizers = [] + self.non_ep_optimizers = [] + + self.model_parts = model_parts + # This is still needed to + # 1. reuse other OptimizersContainer's methods other than state dict save / load + # 2. define LR schedulers + self.optimizers = [] + + for model in self.model_parts: + for p in model.parameters(): + if not p.requires_grad: + continue + assert isinstance(p, DTensor) + if p.device_mesh.ndim == dense_params_mesh_ndim: + non_ep_params.append(p) + else: + ep_params.append(p) + + ep_optimizer = optimizer_cls(ep_params, **optimizer_kwargs) + non_ep_optimizers = optimizer_cls(non_ep_params, **optimizer_kwargs) + self.ep_optimizers.append(ep_optimizer) + self.non_ep_optimizers.append(non_ep_optimizers) + self.optimizers.append(ep_optimizer) + self.optimizers.append(non_ep_optimizers) + + # NOTE: each model part has two optimizers, one for ep params + # and the other for non-ep params + self._validate_length(len(self.model_parts) * 2) + self._post_init(ep_params, optimizer_kwargs) + self._post_init(non_ep_params, optimizer_kwargs) + + def state_dict(self) -> dict[str, Any]: + func = functools.partial( + get_optimizer_state_dict, + options=StateDictOptions(flatten_optimizer_state_dict=True), + ) + return { + k: v + for sd in chain( + map(func, self.model_parts, self.ep_optimizers), + map(func, self.model_parts, self.non_ep_optimizers), + ) + for k, v in sd.items() + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + func = functools.partial( + set_optimizer_state_dict, + optim_state_dict=state_dict, + options=StateDictOptions(flatten_optimizer_state_dict=True), + ) + list(map(func, self.model_parts, self.ep_optimizers)) + list(map(func, self.model_parts, self.non_ep_optimizers)) + + def build_optimizers( model_parts: list[nn.Module], job_config: JobConfig, + parallel_dims: ParallelDims, + world_mesh: DeviceMesh, ft_manager: FTManager, ) -> OptimizersContainer: """Create a OptimizersContainer for the given model parts and job config. @@ -259,12 +339,23 @@ def build_optimizers( Args: model_parts (List[nn.Module]): List of model parts to be optimized. 
job_config (JobConfig): Job config containing the optimizer name and parameters. + parallel_dims (ParallelDims): Parallel dimensions for the model. """ optim_in_bwd = job_config.optimizer.early_step_in_backward - if optim_in_bwd and job_config.parallelism.pipeline_parallel_degree > 1: - raise NotImplementedError( - "Optimizers in backward is not supported with pipeline parallelism." - ) + if optim_in_bwd: + if parallel_dims.ep_enabled: + raise NotImplementedError( + "Optimizers in backward is not supported with Expert Parallel." + ) + if parallel_dims.pp_enabled: + raise NotImplementedError( + "Optimizers in backward is not supported with Pipeline Parallel." + ) + if ft_manager.enabled: + raise NotImplementedError( + "TorchFT is not supported with optimizers in backward." + ) + name = job_config.optimizer.name lr = job_config.optimizer.lr beta1 = job_config.optimizer.beta1 @@ -295,13 +386,12 @@ def build_optimizers( raise NotImplementedError(f"Optimizer {name} not added.") optimizer_cls = optimizer_classes[name] - if optim_in_bwd and ft_manager.enabled: - raise ValueError("TorchFT is not supported with optimizers in backward.") - elif optim_in_bwd: + if optim_in_bwd: return OptimizersInBackwardContainer( model_parts, optimizer_cls, optimizer_kwargs ) - elif ft_manager.enabled: + + if ft_manager.enabled: return FTOptimizersContainer( model_parts, optimizer_cls, @@ -309,5 +399,18 @@ def build_optimizers( ft_manager.manager, use_ft_optimizer=job_config.fault_tolerance.semi_sync_method is None, ) - else: - return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) + + if parallel_dims.ep_enabled and fused: + if ft_manager.enabled: + raise NotImplementedError( + "Expert Parallel with fused optimizer implementation " + "is not supported with TorchFT yet." + ) + return ExpertParallelOptimizersContainer( + model_parts, + optimizer_cls, + optimizer_kwargs, + parallel_dims.dense_params_mesh_ndim, + ) + + return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 3f8d25688..d40a5982f 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -363,6 +363,14 @@ class Parallelism: The default value is 'allgather'. """ + expert_parallel_degree: int = 1 + """ + Expert parallelism degree. 1 means disabled. + Currently, only "dp2ep" is supported, with the following constraints: + context_parallel_degree <= expert_parallel_degree <= data_parallel_shard_degree * context_parallel_degree + Note that this is still an experimental feature. + """ + @dataclass class Checkpoint: diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 07a056fc4..08986b220 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
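A quick standalone sketch of the degree bookkeeping behind the new `expert_parallel_degree` constraint above and the `dp_shard_mod_ep` / `dp_shard_in_ep` submeshes introduced below (illustrative only, not part of the patch; the helper name and example degrees are made up):

```python
# EP borrows all of cp and part of dp_shard:
#   dp_shard * cp == dp_shard_mod_ep * (dp_shard_in_ep * cp) == dp_shard_mod_ep * ep
def ep_submesh_sizes(dp_shard: int, cp: int, ep: int) -> tuple[int, int]:
    # same validation as ParallelDims._validate
    assert ep % cp == 0 and (dp_shard * cp) % ep == 0
    dp_shard_mod_ep = dp_shard * cp // ep
    dp_shard_in_ep = ep // cp
    return dp_shard_mod_ep, dp_shard_in_ep

# e.g. dp_shard=8, cp=2, ep=4 -> (4, 2): experts are sharded over the
# dp_shard_in_ep * cp = 4 ranks of the flattened "ep" mesh, while the
# dp_shard_mod_ep = 4 dimension still applies FSDP to the MoE params
print(ep_submesh_sizes(8, 2, 4))  # (4, 2)
print(ep_submesh_sizes(8, 1, 8))  # (1, 8): EP takes the whole dp_shard dimension
```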
-from collections.abc import Callable from dataclasses import dataclass from functools import cached_property @@ -23,6 +22,7 @@ class ParallelDims: cp: int tp: int pp: int + ep: int world_size: int enable_loss_parallel: bool @@ -30,14 +30,15 @@ def __post_init__(self): self._validate() def _validate(self): - dp_replicate, dp_shard, cp, tp, pp = ( + dp_replicate, dp_shard, cp, tp, pp, ep = ( self.dp_replicate, self.dp_shard, self.cp, self.tp, self.pp, + self.ep, ) - for d in (dp_replicate, cp, tp, pp): + for d in (dp_replicate, cp, tp, pp, ep): assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." @@ -50,7 +51,84 @@ def _validate(self): f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" ) + if ep > 1: + # EP would borrow all cp and some dp_shard degree + assert ep % cp == 0 and (dp_shard * cp) % ep == 0 + def build_mesh(self, device_type: str) -> DeviceMesh: + # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel + # is not very clean, due to the limited support from DeviceMesh + # for creating two staggered meshes. Will improve. + if self.ep > 1: + return self._build_mesh_with_ep(device_type) + else: + return self._build_mesh_without_ep(device_type) + + def _build_mesh_with_ep(self, device_type: str) -> DeviceMesh: + # With ep, dp_shard and ep are derived submeshes: + # dp_shard = dp_shard_mod_ep * dp_shard_in_ep + # ep = dp_shard_in_ep * cp + dp_shard_mod_ep = self.dp_shard * self.cp // self.ep + dp_shard_in_ep = self.ep // self.cp + + dims = [] + names = [] + for d, name in zip( + [ + self.pp, + self.dp_replicate, + dp_shard_mod_ep, + dp_shard_in_ep, + self.cp, + self.tp, + ], + ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"], + ): + # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping + # helps the MoE layers do mixed precision training + if d > 1 or name == "dp_shard_mod_ep": + dims.append(d) + names.append(name) + + logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") + mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + + # Create all the submesh here to ensure all required process groups are + # initialized: + # Mesh for data loading (no communication on this mesh) + dp_mesh_dim_names = [] + # Mesh for param sharding + dp_shard_cp_mesh_dim_names = [] + # Mesh for loss all-reduce + dp_cp_mesh_dim_names = [] + # Mesh for ep + ep_mesh_dim_names = [] + + if self.dp_replicate_enabled: + dp_mesh_dim_names.append("dp_replicate") + dp_cp_mesh_dim_names.append("dp_replicate") + # dp_shard_mod_ep is always needed, even if it's 1 + dp_mesh_dim_names.append("dp_shard_mod_ep") + dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep") + dp_cp_mesh_dim_names.append("dp_shard_mod_ep") + if "dp_shard_in_ep" in names: + dp_mesh_dim_names.append("dp_shard_in_ep") + dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep") + dp_cp_mesh_dim_names.append("dp_shard_in_ep") + ep_mesh_dim_names.append("dp_shard_in_ep") + if self.cp_enabled: + dp_shard_cp_mesh_dim_names.append("cp") + dp_cp_mesh_dim_names.append("cp") + ep_mesh_dim_names.append("cp") + + mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") + mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp") + mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") + mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep") + + return mesh + + def _build_mesh_without_ep(self, device_type: str) -> DeviceMesh: dims = [] names = [] for d, 
name in zip( @@ -61,17 +139,8 @@ def build_mesh(self, device_type: str) -> DeviceMesh: dims.append(d) names.append(name) - return self._build_mesh(device_type, dims, names, init_device_mesh) - - def _build_mesh( - self, - device_type: str, - dims: list[int], - names: list[str], - init_device_mesh_fn: Callable, - ) -> DeviceMesh: logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh_fn(device_type, dims, mesh_dim_names=names) + mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) # Create all the submesh here to ensure all required process groups are # initialized: @@ -143,3 +212,12 @@ def loss_parallel_enabled(self): @cached_property def non_data_parallel_size(self): return self.cp * self.tp * self.pp + + @property + def ep_enabled(self): + return self.ep > 1 + + @property + def dense_params_mesh_ndim(self): + # Note: EP params mesh ndim is 1 more due to the 'ep' mesh + return self.dp_replicate_enabled + self.fsdp_enabled + self.tp_enabled diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 9c86fba1b..3f824d5fe 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -307,6 +307,7 @@ def clip_grad_norm_( error_if_nonfinite: bool = False, foreach: bool | None = None, pp_mesh: DeviceMesh | None = None, + parallel_dims: ParallelDims | None = None, ) -> torch.Tensor: """ Clip the gradient norm of an iterable of parameters. @@ -329,11 +330,23 @@ def clip_grad_norm_( fall back to the slow implementation for other device types. Default: ``None`` pp_mesh: pipeline parallel device mesh. If not None, will reduce gradient norm across PP stages. + parallel_dims: ParallelDims object which contains Expert Parallel related info. Returns: Total norm of the parameter gradients (viewed as a single vector). """ + if parallel_dims and parallel_dims.ep_enabled: + return _clip_grad_norm_with_ep( + parameters, + max_norm, + norm_type, + error_if_nonfinite, + foreach, + pp_mesh, + parallel_dims, + ) + if isinstance(parameters, torch.Tensor): parameters = [parameters] else: @@ -353,7 +366,6 @@ def clip_grad_norm_( if isinstance(total_norm, DTensor): # Will reach here if any non-PP parallelism is used. # If only using PP, total_norm will be a local tensor. 
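A small self-contained check (not part of the patch) of the norm combination used by the new `_clip_grad_norm_with_ep` helper below: for a finite `norm_type` p, combining the two per-group norms as `(a**p + b**p) ** (1/p)` reproduces the norm over all gradients (for `inf` the combination is a plain max):

```python
import torch

grads = [torch.randn(4, 3), torch.randn(5), torch.randn(2, 2)]
ep_grads, dense_grads = grads[:1], grads[1:]  # stand-ins for the EP / dense split
p = 2.0

def group_norm(tensors, p):
    return torch.linalg.vector_norm(torch.cat([t.flatten() for t in tensors]), ord=p)

combined = (group_norm(ep_grads, p) ** p + group_norm(dense_grads, p) ** p) ** (1.0 / p)
torch.testing.assert_close(combined, group_norm(grads, p))
```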
- total_norm = total_norm.full_tensor() if pp_mesh is not None: @@ -366,3 +378,59 @@ def clip_grad_norm_( torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) return total_norm + + +@torch.no_grad() +def _clip_grad_norm_with_ep( + parameters: torch.Tensor | Iterable[torch.Tensor], + max_norm: float, + norm_type: float, + error_if_nonfinite: bool, + foreach: bool | None, + pp_mesh: DeviceMesh | None, + parallel_dims: ParallelDims, +) -> torch.Tensor: + assert parallel_dims.ep_enabled + + ep_params = [] + non_ep_params = [] + ep_grads = [] + non_ep_grads = [] + + for p in parameters: + if p.grad is None: + continue + assert isinstance(p, DTensor) and isinstance(p.grad, DTensor) + if p.device_mesh.ndim == parallel_dims.dense_params_mesh_ndim: + non_ep_params.append(p) + non_ep_grads.append(p.grad) + else: + ep_params.append(p) + ep_grads.append(p.grad) + ep_grads_total_norm = torch.nn.utils.get_total_norm( + ep_grads, norm_type, error_if_nonfinite, foreach + ).full_tensor() + non_ep_grads_total_norm = torch.nn.utils.get_total_norm( + non_ep_grads, norm_type, error_if_nonfinite, foreach + ).full_tensor() + + if math.isinf(norm_type): + total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm) + else: + total_norm = ( + ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type + ) + total_norm **= 1.0 / norm_type + + if pp_mesh is not None: + if math.isinf(norm_type): + dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) + else: + total_norm **= norm_type + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) + total_norm **= 1.0 / norm_type + + torch.nn.utils.clip_grads_with_norm_(ep_params, max_norm, total_norm, foreach) + torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach) + + return total_norm diff --git a/torchtitan/experiments/llama4/README.md b/torchtitan/experiments/llama4/README.md index eb9adbb31..4b42f7c3f 100644 --- a/torchtitan/experiments/llama4/README.md +++ b/torchtitan/experiments/llama4/README.md @@ -6,6 +6,7 @@ https://github.com/pytorch/torchtitan/issues/1118 #### Available features - Llama 4 model (text-only), including a token-choice MoE architecture with efficient bfloat16 Grouped MM kernels and auxiliary-loss-free load balancing - FSDP, TP, PP, CP support +- Expert Parallel support - DCP checkpoint conversion scripts #### Download Llama 4 tokenizer @@ -20,7 +21,6 @@ python scripts/download_tokenizer.py --repo_id meta-llama/Llama-4-Scout-17B-16E - multimodal support - Parallelism - Context Parallel support for FlexAttention and multimodal inputs - - Expert Parallel support - torch.compile - for MoE layers - Quantization diff --git a/torchtitan/experiments/llama4/__init__.py b/torchtitan/experiments/llama4/__init__.py index a3fb29375..329c4e9d7 100644 --- a/torchtitan/experiments/llama4/__init__.py +++ b/torchtitan/experiments/llama4/__init__.py @@ -6,7 +6,6 @@ from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers -from torchtitan.components.optimizer import build_optimizers from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer from torchtitan.models.llama3 import pipeline_llama @@ -15,6 +14,7 @@ from .infra.parallelize import parallelize_llama from .model.args import TransformerModelArgs from .model.model import Transformer +from .optimizer import build_llama4_optimizers __all__ = [ 
"TransformerModelArgs", @@ -98,7 +98,7 @@ config=llama4_configs, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, - build_optimizers_fn=build_optimizers, + build_optimizers_fn=build_llama4_optimizers, build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_hf_dataloader, build_tokenizer_fn=build_tiktoken_tokenizer, diff --git a/torchtitan/experiments/llama4/infra/expert_parallel.py b/torchtitan/experiments/llama4/infra/expert_parallel.py index 68f2e7a75..0e8aef8ee 100644 --- a/torchtitan/experiments/llama4/infra/expert_parallel.py +++ b/torchtitan/experiments/llama4/infra/expert_parallel.py @@ -6,10 +6,12 @@ from functools import partial -from typing import Optional, Tuple +from typing import Callable import torch +import torch.distributed as dist import torch.nn as nn +from torch.distributed._functional_collectives import all_to_all_single_autograd from torch.distributed.tensor import ( DeviceMesh, distribute_module, @@ -24,40 +26,6 @@ # implementation of Tensor Parallel for the GroupedExperts in MoE class TensorParallel(ParallelStyle): - def __init__( - self, - *, - input_layouts: Optional[Tuple[Optional[Placement]]] = None, - output_layout: Optional[Placement] = None, - use_local_output: bool = True, - ): - super().__init__() - self.input_layouts = input_layouts or (Replicate(), Replicate()) - self.output_layout = output_layout or Replicate() - self.desired_input_layouts = (Replicate(), Replicate()) - self.use_local_output = use_local_output - - @staticmethod - def _prepare_input_fn( - input_layouts, desired_input_layouts, mod, inputs, device_mesh - ): - prepared_inputs = [] - # annotate module input placements/sharding with input_layouts - for inp, input_layout, desired_input_layout in zip( - inputs, input_layouts, desired_input_layouts - ): - if isinstance(inp, torch.Tensor): - if not isinstance(inp, DTensor): - inp = DTensor.from_local( - inp, device_mesh, (input_layout,), run_check=False - ) - if input_layout != desired_input_layout: - inp = inp.redistribute( - placements=(desired_input_layout,), async_op=True - ) - prepared_inputs.append(inp) - return tuple(prepared_inputs) - def _partition_fn(self, name, module, device_mesh): module.register_parameter( "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)])) @@ -71,36 +39,25 @@ def _partition_fn(self, name, module, device_mesh): nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(2)])), ) # Column-wise sharding - @staticmethod - def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh): - if outputs.placements != (output_layout,): - outputs = outputs.redistribute(placements=(output_layout,), async_op=True) - # back to local tensor - return outputs.to_local() if use_local_output else outputs - def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: return distribute_module( module, device_mesh, self._partition_fn, - partial( - self._prepare_input_fn, self.input_layouts, self.desired_input_layouts - ), - partial(self._prepare_output_fn, self.output_layout, self.use_local_output), ) # NOTE: This is to achieve replicate computation on the gate module in the MoE router. # It does nothing other than (1) setting the module parameters as DTensors on the given mesh # and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back. 
-# TODO: The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh, +# The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh, # which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation. class NoParallel(ParallelStyle): def __init__( self, *, - input_layout: Optional[Placement] = None, - output_layout: Optional[Placement] = None, + input_layout: Placement | None = None, + output_layout: Placement | None = None, use_local_output: bool = True, ): super().__init__() @@ -141,3 +98,197 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ), partial(self._prepare_output_fn, self.output_layout, self.use_local_output), ) + + +class ExpertParallel(ParallelStyle): + def __init__(self): + super().__init__() + self.input_splits = None + self.output_splits = None + + # performing all-to-all dispatch on the input + def _token_dispatch(self, mod, inputs, device_mesh): + # annotate module input placements/sharding with input_layouts + routed_input, num_tokens_per_expert = inputs + + # generate the input splits and output splits for all-to-all + with torch.no_grad(): + num_tokens_per_expert_group = num_tokens_per_expert.new_empty( + num_tokens_per_expert.shape[0] + ) + dist.all_to_all_single( + num_tokens_per_expert_group, + num_tokens_per_expert, + group=device_mesh.get_group(), + ) + # NOTE: this would incur a device-to-host sync + self.input_splits = ( + num_tokens_per_expert.view(device_mesh.shape[0], -1).sum(dim=1).tolist() + ) + self.output_splits = ( + num_tokens_per_expert_group.view(device_mesh.shape[0], -1) + .sum(dim=1) + .tolist() + ) + + # perform all-to-all + routed_input = all_to_all_single_autograd( + routed_input, + self.output_splits, + self.input_splits, + device_mesh.get_group(), + ) + + # NOTE: After this all-to-all, the routed input is put on proper EP rank. + # However, the num_tokens_per_expert_group is not of the final target format + # [#tokens for local expert 0, #tokens for local expert 1, ...] + # Rather, it is of the format + # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ..., + # #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...] + # We need to perform another shuffle to get the correct format -- this is done via the function + # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens + # each expert gets locally is a multiple of ALIGN_SIZE_M. 
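Illustrative numbers (not from the patch) for the split bookkeeping in `_token_dispatch` above, assuming 2 EP ranks with 2 local experts each, as seen from EP rank 0:

```python
import torch

ep_degree = 2
# tokens this rank routes to global experts 0..3 (experts 0-1 live on rank 0, 2-3 on rank 1)
num_tokens_per_expert = torch.tensor([5, 3, 2, 4])
input_splits = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1).tolist()
assert input_splits == [8, 6]  # rows sent to EP rank 0 and EP rank 1

# after the size all-to-all, rank 0 might hold
# [#for local expert 0 from rank 0, #for local expert 1 from rank 0,
#  #for local expert 0 from rank 1, #for local expert 1 from rank 1]
num_tokens_per_expert_group = torch.tensor([5, 3, 7, 1])
output_splits = num_tokens_per_expert_group.view(ep_degree, -1).sum(dim=1).tolist()
assert output_splits == [8, 8]  # rows received from EP rank 0 and EP rank 1
```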
+ + return routed_input, num_tokens_per_expert_group + + @staticmethod + def _partition_fn(name, mod, device_mesh): + # shard on the expert dimension + for name, param in mod.named_parameters(recurse=False): + dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) + mod.register_parameter(name, dist_param) + + # performing all-to-all combine on the output + def _token_combine(self, mod, routed_output, device_mesh): + routed_output = all_to_all_single_autograd( + routed_output, + self.input_splits, + self.output_splits, + device_mesh.get_group(), + ) + return routed_output + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=ExpertParallel._partition_fn, + input_fn=self._token_dispatch, + output_fn=self._token_combine, + ) + + +# This class is for dp2ep with TP (without TP we can just use ExpertParallel) +class ExpertTensorParallel(ExpertParallel): + def __init__( + self, + tp_mesh: DeviceMesh, + ep_mesh: DeviceMesh, + ): + super().__init__() + # TODO: has to pass in the meshes in addition to the [ep, tp] device_mesh, + # as DeviceMesh doesn't support slicing from a submesh. + self.tp_mesh = tp_mesh + self.ep_mesh = ep_mesh + + def _token_dispatch(self, mod, inputs, device_mesh): + # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh + return super()._token_dispatch(mod, inputs, self.ep_mesh) + + def _partition_fn_2d(self, name, mod, ep_tp_mesh): + mod.register_parameter( + "w1", + nn.Parameter(distribute_tensor(mod.w1, ep_tp_mesh, [Shard(0), Shard(2)])), + ) # Column-wise sharding + mod.register_parameter( + "w2", + nn.Parameter(distribute_tensor(mod.w2, ep_tp_mesh, [Shard(0), Shard(1)])), + ) # Row-wise sharding + mod.register_parameter( + "w3", + nn.Parameter(distribute_tensor(mod.w3, ep_tp_mesh, [Shard(0), Shard(2)])), + ) # Column-wise sharding + + def _token_combine(self, mod, routed_output, device_mesh): + # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh + return super()._token_combine(mod, routed_output, self.ep_mesh) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=self._partition_fn_2d, + input_fn=self._token_dispatch, + output_fn=self._token_combine, + ) + + +def expert_parallel(func: Callable) -> Callable: + """ + This is a wrapper applied to the GroupedExperts computation, serving + the following three purposes: + 1. Convert parameters from DTensors to plain Tensors, to work with + dynamic-shape inputs which cannot be easily expressed as DTensors. + 2. In Expert Parallel, apply the generate_permute_indices kernel to + permute the inputs to be ordered by local experts (see the _token_dispatch + function in ExpertParallel) and permute the outputs back. + 3. In order to use torch._grouped_mm, we need to make sure the number of + tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices + kernel also helps achieve this via padding, without incurring synchronization + between device and host. Note that this will create side effects when wrapping + the for-loop implementation of GroupedExperts, as it does not need padding. + + Among the above: + 1 and 2 are needed only when expert_parallel_degree > 1. + 3 is needed even for single-device computation. + 2 can be moved to ExpertParallel _token_dispatch if not coupled with 3. 
+ """ + + def wrapper( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if isinstance(w1, DTensor): + w1 = w1.to_local() + w2 = w2.to_local() + w3 = w3.to_local() + + if num_tokens_per_expert is not None: + from torchtitan.experiments.kernels.moe.indices import ( + generate_permute_indices, + ) + + experts_per_ep_rank = w1.shape[0] + num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank + + ALIGN_SIZE_M = 16 + with torch.no_grad(): + ( + permuted_indices, + num_tokens_per_expert, + _, # offsets, + ) = generate_permute_indices( + num_tokens_per_expert, + experts_per_ep_rank, + num_ep_ranks, + x.shape[0] + experts_per_ep_rank * ALIGN_SIZE_M, + ALIGN_SIZE_M, + ) + + x = torch.vstack((x, x.new_zeros((x.shape[-1])))) + input_shape = x.shape + x = x[permuted_indices, :] + + out = func(w1, w2, w3, x, num_tokens_per_expert) + + if num_tokens_per_expert is not None: + out_unpermuted = out.new_empty(input_shape) + out_unpermuted[permuted_indices, :] = out + out = out_unpermuted[:-1] + + return out + + return wrapper diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index c8a968b1c..9ed3d112d 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -8,6 +8,12 @@ import torch import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy +from torch.distributed.tensor import Partial, Replicate, Shard +from torch.distributed.tensor.parallel import ( + parallelize_module, + PrepareModuleInputOutput, +) from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims @@ -16,12 +22,16 @@ apply_ac, apply_compile, apply_ddp, - apply_fsdp, apply_tp, ) from torchtitan.tools.logging import logger -from ..model.moe import MoE +from .expert_parallel import ( + ExpertParallel, + ExpertTensorParallel, + NoParallel, + TensorParallel, +) def parallelize_llama( @@ -64,7 +74,18 @@ def parallelize_llama( enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, ) - apply_moe_tp(model, world_mesh["tp"]) + # TODO: support float8 TP for MoE layers + if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + apply_moe_ep_tp( + model, + tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, + ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, + ep_tp_mesh=( + world_mesh["ep", "tp"] + if parallel_dims.tp_enabled and parallel_dims.ep_enabled + else None + ), + ) if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) @@ -77,15 +98,21 @@ def parallelize_llama( torch._dynamo.config.capture_scalar_outputs = True dp_mesh: DeviceMesh | None = None - if ( - parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled - ): # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: + # apply FSDP or HSDP, potentially with Context Parallel if parallel_dims.dp_replicate_enabled: dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") else: dp_mesh_dim_names = ("dp_shard_cp",) dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP + dp_mod_ep_mesh_dim_names = [] + if parallel_dims.ep_enabled: + if parallel_dims.dp_replicate_enabled: + 
dp_mod_ep_mesh_dim_names.append("dp_replicate") + dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + apply_fsdp( model, dp_mesh, @@ -94,6 +121,11 @@ def parallelize_llama( pp_enabled=parallel_dims.pp_enabled, cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + dp_mod_ep_mesh=( + world_mesh[tuple(dp_mod_ep_mesh_dim_names)] + if dp_mod_ep_mesh_dim_names + else None + ), ) if parallel_dims.dp_replicate_enabled: @@ -117,64 +149,126 @@ def parallelize_llama( enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd, ) - # for MoE auxiliary-loss-free load balancing - if parallel_dims.dp_cp_enabled: - # NOTE: Currently this sync is blocking (thus exposed) and happens on the - # default compute stream. Need to assess if this is OK performance-wise. - dp_cp_mesh = world_mesh["dp_cp"] + return model + + +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + pp_enabled: bool, + cpu_offload: bool = False, + reshard_after_forward_policy: str = "default", + dp_mod_ep_mesh: DeviceMesh | None = None, +): + """ + Apply data parallelism (via FSDP2) to the model. - def _sync_tokens_per_expert(module, *_): - assert isinstance(module, MoE) - torch.distributed.all_reduce( - module.tokens_per_expert, group=dp_cp_mesh.get_group() + Args: + model (nn.Module): The model to apply data parallelism to. + dp_mesh (DeviceMesh): The device mesh to use for data parallelism. + param_dtype (torch.dtype): The data type to use for model parameters. + reduce_dtype (torch.dtype): The data type to use for reduction operations. + pp_enabled (bool): Whether pipeline parallelism is enabled. + cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. + reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". + Other options: "never", "always". + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. + - "always" will enable `reshard_after_forward` for all forward passes. + - "never" will disable `reshard_after_forward` for all forward passes. + + """ + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy() + + for layer_id, transformer_block in model.layers.items(): + if reshard_after_forward_policy == "always": + reshard_after_forward = True + elif reshard_after_forward_policy == "never": + reshard_after_forward = False + elif reshard_after_forward_policy == "default": + if pp_enabled: + # For PP, do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = False + else: + # As an optimization, do not reshard after forward for the last + # transformer block since FSDP would prefetch it immediately + reshard_after_forward = int(layer_id) < len(model.layers) - 1 + else: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." 
) - for transformer_block in model.layers.values(): - if transformer_block.moe_enabled: - load_balance_coeff = transformer_block.moe.load_balance_coeff - if load_balance_coeff is not None and load_balance_coeff > 0: - # prepend=True so that the sync runs before - # the _update_expert_bias hook in MoE - transformer_block.moe.register_full_backward_hook( - _sync_tokens_per_expert, prepend=True - ) - else: - break + # NOTE: in an MoE layer, the router and the shared experts + # are sharded together with the TransformerBlock + if transformer_block.moe_enabled and dp_mod_ep_mesh: + fsdp_mod_ep_config = fsdp_config.copy() + fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh + fully_shard( + transformer_block.moe.experts, + **fsdp_mod_ep_config, + reshard_after_forward=reshard_after_forward, + ) - return model + fully_shard( + transformer_block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) -def apply_moe_tp( +def apply_moe_ep_tp( model: nn.Module, - tp_mesh: DeviceMesh, + tp_mesh: DeviceMesh | None, + ep_mesh: DeviceMesh | None, + ep_tp_mesh: DeviceMesh | None, ): - from torch.distributed.tensor import Partial, Replicate, Shard - from torch.distributed.tensor.parallel import ( - parallelize_module, - PrepareModuleInputOutput, - ) + for transformer_block in model.layers.values(): + if not transformer_block.moe_enabled: + continue - from .expert_parallel import NoParallel, TensorParallel + if tp_mesh is not None: + moe_layer_plan = { + # input / output sharding on the seqlen dim + # all-gather for input, reduce-scatter for output + "moe": PrepareModuleInputOutput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + use_local_input=True, + output_layouts=(Partial(),), + desired_output_layouts=(Shard(1),), + ), + # replicate computation for the router + "moe.router.gate": NoParallel(), + # input Replicate, output Partial + "moe.shared_expert": TensorParallel(), + } + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=moe_layer_plan, + ) - for transformer_block in model.layers.values(): - moe_layer_plan = { - # input / output sharding on the seqlen dim - # all-gather for input, reduce-scatter for output - "moe": PrepareModuleInputOutput( - input_layouts=(Shard(1),), - desired_input_layouts=(Replicate(),), - use_local_input=True, - output_layouts=(Partial(),), - desired_output_layouts=(Shard(1),), - ), - # replicate computation for the router - "moe.router.gate": NoParallel(), + # if ep_mesh is not None: + experts_mesh, experts_plan = None, None + if ep_mesh is None: + experts_mesh = tp_mesh # input Replicate, output Partial - "moe.experts": TensorParallel(output_layout=Partial()), - "moe.shared_expert": TensorParallel(output_layout=Partial()), - } + experts_plan = TensorParallel() + elif tp_mesh is None: + experts_mesh = ep_mesh + # input / output sharding on the batch / tokens dim + experts_plan = ExpertParallel() + else: + experts_mesh = ep_tp_mesh + experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh) parallelize_module( - module=transformer_block, - device_mesh=tp_mesh, - parallelize_plan=moe_layer_plan, + module=transformer_block.moe.experts, + device_mesh=experts_mesh, + parallelize_plan=experts_plan, ) diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py index a07bf0f7b..2195617e6 100644 --- a/torchtitan/experiments/llama4/model/moe.py +++ b/torchtitan/experiments/llama4/model/moe.py 
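A tiny sketch (not part of the patch) of how the cumulative offsets consumed by `torch._grouped_mm` in `_run_experts_grouped_mm` below delimit each expert's rows in the stacked 2D input:

```python
import torch

# per-expert token counts, already padded to multiples of ALIGN_SIZE_M = 16
num_tokens_per_expert = torch.tensor([16, 32, 16], dtype=torch.int32)
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
# [16, 48, 64]: expert 0 owns rows [0, 16), expert 1 [16, 48), expert 2 [48, 64)
print(offsets.tolist())
```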
@@ -8,6 +8,8 @@ import torch.nn.functional as F from torch import nn +from ..infra.expert_parallel import expert_parallel + from .args import TransformerModelArgs @@ -29,49 +31,73 @@ def __init__( def forward( self, x: torch.Tensor, - num_local_tokens_per_expert: torch.Tensor | list[int] | None = None, + num_tokens_per_expert: torch.Tensor | None = None, ) -> torch.Tensor: - # TODO: keeping this for loop implementation for comparison - # and readability, will remove later - if not self.use_grouped_mm: - if num_local_tokens_per_expert is not None: - # a tuple of tensors indexed by experts - # each with shape (tokens_per_expert(varying), dim) - x = torch.split( - x, - split_size_or_sections=num_local_tokens_per_expert, - dim=0, - ) - out_experts_splits = [] - for expert_idx, x_expert in enumerate(x): - w1, w2, w3 = ( - self.w1[expert_idx], - self.w2[expert_idx], - self.w3[expert_idx], - ) - h = F.silu(torch.matmul(x_expert, w1)) - h = h * torch.matmul(x_expert, w3) - h = torch.matmul(h, w2) - # h shape (tokens_per_expert(varying), dim) - out_experts_splits.append(h) - out = torch.cat(out_experts_splits, dim=0) - else: - # x shape (num_experts, tokens_per_expert, dim) - h = F.silu(torch.bmm(x, self.w1)) - h = h * torch.bmm(x, self.w3) - # out shape (num_experts, tokens_per_expert, dim) - out = torch.bmm(h, self.w2) - - return out - - # grouped mm implementation - if num_local_tokens_per_expert is not None: - # https://github.com/pytorch/pytorch/pull/150374 - # NOTE: torch._gouped_mm requires bf16 dtypes - # and shapes to be multiple of 8 - offsets = torch.cumsum( - num_local_tokens_per_expert, dim=0, dtype=torch.int32 + if self.use_grouped_mm: + return GroupedExperts._run_experts_grouped_mm( + self.w1, self.w2, self.w3, x, num_tokens_per_expert + ) + else: + return GroupedExperts._run_experts_for_loop( + self.w1, self.w2, self.w3, x, num_tokens_per_expert + ) + + # TODO: keeping this for-loop implementation for comparison + # and readability, may remove later + @expert_parallel + @staticmethod + def _run_experts_for_loop( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_tokens_per_expert is not None: + # NOTE: this would incur a synchronization between device and host + num_tokens_per_expert = num_tokens_per_expert.tolist() + + # side-effect code due to the usage of generate_permute_indices + num_padding = x.shape[0] - sum(num_tokens_per_expert) + + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x[: sum(num_tokens_per_expert)], + split_size_or_sections=num_tokens_per_expert, + dim=0, ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + h = F.silu(torch.matmul(x_expert, w1[expert_idx])) + h = h * torch.matmul(x_expert, w3[expert_idx]) + h = torch.matmul(h, w2[expert_idx]) + # h shape (tokens_per_expert(varying), dim) + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + # side-effect code due to the usage of generate_permute_indices + out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) + else: + # x shape (num_experts, tokens_per_expert, dim) + h = F.silu(torch.bmm(x, w1)) + h = h * torch.bmm(x, w3) + # out shape (num_experts, tokens_per_expert, dim) + out = torch.bmm(h, w2) + + return out + + @expert_parallel + @staticmethod + def _run_experts_grouped_mm( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + 
num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_tokens_per_expert is not None: + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) # grouped mm between a 2D tensor and a 3D tensor assert x.dim() == 2 else: @@ -80,11 +106,12 @@ def forward( assert x.dim() == 3 assert ( - x.dtype == self.w1.dtype == self.w2.dtype == self.w3.dtype == torch.bfloat16 + x.dtype == w1.dtype == w2.dtype == w3.dtype == torch.bfloat16 ), "torch._grouped_mm only supports bf16 dtypes" - h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets)) - h = h * torch._grouped_mm(x, self.w3, offs=offsets) - out = torch._grouped_mm(h, self.w2, offs=offsets) + + h = F.silu(torch._grouped_mm(x, w1, offs=offsets)) + h = h * torch._grouped_mm(x, w3, offs=offsets) + out = torch._grouped_mm(h, w2, offs=offsets) return out @@ -120,7 +147,7 @@ def __init__( self.use_sigmoid = use_sigmoid def forward( - self, x: torch.Tensor, expert_bias: torch.Tensor = None + self, x: torch.Tensor, expert_bias: torch.Tensor | None = None ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: @@ -131,7 +158,7 @@ def forward( Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``. token_indices (torch.Tensor): Token indices for routed_input with shape ``(bs*slen*top_k,)``. - num_local_tokens_per_expert (torch.Tensor): + num_tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert with shape ``(num_experts,)``. """ # scores shape (bs*slen, num_experts) @@ -146,13 +173,18 @@ def forward( # top scores shape (bs*slen, top_k) # NOTE: The expert_bias is only used for routing. The gating value # top_scores is still derived from the original scores. - _, selected_experts_indices = torch.topk( - scores + expert_bias, k=self.top_k, dim=1 - ) - top_scores = scores.gather(dim=1, index=selected_experts_indices) + if expert_bias is not None: + _, selected_experts_indices = torch.topk( + scores + expert_bias, k=self.top_k, dim=1 + ) + top_scores = scores.gather(dim=1, index=selected_experts_indices) + else: + top_scores, selected_experts_indices = torch.topk( + scores, k=self.top_k, dim=1 + ) # group tokens together by expert indices from 0 to num_experts and pass that to experts forward - num_local_tokens_per_expert = torch.histc( + num_tokens_per_expert = torch.histc( selected_experts_indices.view(-1), bins=self.num_experts, min=0, @@ -165,7 +197,7 @@ def forward( top_scores = top_scores.view(-1)[token_indices_experts_sorted] token_indices_experts_sorted = token_indices_experts_sorted // self.top_k - return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert + return top_scores, token_indices_experts_sorted, num_tokens_per_expert def init_weights(self, init_std: float): nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) @@ -191,12 +223,11 @@ def __init__(self, model_args: TransformerModelArgs): hidden_dim = int(hidden_dim / hidden_dim_denom) hidden_dim += -hidden_dim % model_args.multiple_of - self.use_grouped_mm = model_args.use_grouped_mm self.experts = GroupedExperts( dim=dim, hidden_dim=hidden_dim, num_experts=num_experts, - use_grouped_mm=self.use_grouped_mm, + use_grouped_mm=model_args.use_grouped_mm, ) self.router = TokenChoiceTopKRouter( dim=dim, num_experts=num_experts, top_k=model_args.top_k @@ -206,40 +237,31 @@ def __init__(self, model_args: TransformerModelArgs): dim=dim, hidden_dim=hidden_dim, num_experts=1, - use_grouped_mm=self.use_grouped_mm, + use_grouped_mm=model_args.use_grouped_mm, ) if model_args.use_shared_expert 
else None ) - # auxiliary-loss-free load balancing + # define fields for auxiliary-loss-free load balancing (https://arxiv.org/abs/2408.15664) + # NOTE: tokens_per_expert is accumulated in the model forward pass. + # expert_bias is updated outside the model in an optimzer step pre hook + # to work with gradient accumulation. self.load_balance_coeff = model_args.load_balance_coeff - # the fields below are defined even when load_balance_coeff is None - # to make initialization and checkpointing code simpler - self.register_buffer( - "expert_bias", - torch.zeros(num_experts, dtype=torch.float32), - persistent=True, - ) - self.register_buffer( - "tokens_per_expert", - torch.zeros(num_experts, dtype=torch.float32), - persistent=True, - ) - - # NOTE: forward hook, forward pre hook, or backward pre hook - # would conflict with activation checkpointing - if self.load_balance_coeff is not None and self.load_balance_coeff > 0: - self.register_full_backward_hook(self._update_expert_bias) - - def _update_expert_bias(self, *_): - expert_bias_delta = self.load_balance_coeff * torch.sign( - self.tokens_per_expert.mean() - self.tokens_per_expert - ) - expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() - self.expert_bias.add_(expert_bias_delta) - - self.tokens_per_expert.zero_() + if self.load_balance_coeff is not None: + assert self.load_balance_coeff > 0.0 + self.register_buffer( + "expert_bias", + torch.zeros(num_experts, dtype=torch.float32), + persistent=True, + ) + self.register_buffer( + "tokens_per_expert", + torch.zeros(num_experts, dtype=torch.float32), + persistent=True, + ) + else: + self.expert_bias = None def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -252,15 +274,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: bs, slen, dim = x.shape # top_scores and selected_indices shape (bs*slen*top_k,) - # num_local_tokens_per_expert shape (num_experts,) + # num_tokens_per_expert shape (num_experts,) ( top_scores, token_indices, - num_local_tokens_per_expert, + num_tokens_per_expert, ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) - # will be used to update the expert bias for load balancing - self.tokens_per_expert += num_local_tokens_per_expert + # tokens_per_expert will be used to update the expert bias for load balancing. + # Prevent extra local tokens accumulation on evaluation or activation recomputation. + if self.load_balance_coeff is not None and torch.is_grad_enabled(): + with torch.no_grad(): + self.tokens_per_expert.add_(num_tokens_per_expert) # shape (bs*slen*top_k, dim) token_indices = token_indices.reshape(-1, 1).expand(-1, dim) @@ -275,41 +300,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x.dtype ) - if self.use_grouped_mm: - # NOTE: In order to use torch._grouped_mm, we need to make sure - # the number of tokens each expert gets is a multiple of 16. - # The following kernel helps achieve this via padding, without - # incurring synchronization between device and host. 
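Illustrative numbers (not from the patch) for the bias-adjusted routing in `TokenChoiceTopKRouter.forward` above: `expert_bias` changes which experts are selected, but the gating weights are still taken from the unbiased scores, as the NOTE in the patch describes:

```python
import torch

scores = torch.tensor([[0.50, 0.45, 0.05]])
expert_bias = torch.tensor([0.0, 0.0, 0.6])
_, selected = torch.topk(scores + expert_bias, k=2, dim=1)
top_scores = scores.gather(dim=1, index=selected)
print(selected)    # tensor([[2, 0]]) -- the biased expert 2 wins a slot
print(top_scores)  # tensor([[0.0500, 0.5000]]) -- but its gating weight stays 0.05
```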
- from torchtitan.experiments.kernels.moe.indices import ( - generate_permute_indices, - ) - - ALIGN_SIZE_M = 16 - - with torch.no_grad(): - ( - permuted_indices, - num_local_tokens_per_expert, - _, - ) = generate_permute_indices( - num_local_tokens_per_expert, - self.experts.num_experts, - 1, - token_indices.shape[0] + self.experts.num_experts * ALIGN_SIZE_M, - ALIGN_SIZE_M, - ) - token_indices = torch.vstack( - (token_indices, token_indices.new_zeros((dim))) - ) - token_indices = token_indices[permuted_indices, :] - routed_input = torch.vstack((routed_input, routed_input.new_zeros((dim)))) - routed_input = routed_input[permuted_indices, :] - else: - # NOTE: this would incur a synchronization between device and host - num_local_tokens_per_expert = num_local_tokens_per_expert.tolist() - # shape (bs*slen*top_k, dim) - routed_output = self.experts(routed_input, num_local_tokens_per_expert) + routed_output = self.experts(routed_input, num_tokens_per_expert) # shared expert if self.shared_expert is not None: @@ -333,10 +325,11 @@ def init_weights( if self.shared_expert is not None: self.shared_expert.init_weights(init_std) - with torch.device(buffer_device): - self.expert_bias = torch.zeros( - self.experts.num_experts, dtype=torch.float32 - ) - self.tokens_per_expert = torch.zeros( - self.experts.num_experts, dtype=torch.float32 - ) + if self.load_balance_coeff is not None: + with torch.device(buffer_device): + self.expert_bias = torch.zeros( + self.experts.num_experts, dtype=torch.float32 + ) + self.tokens_per_expert = torch.zeros( + self.experts.num_experts, dtype=torch.float32 + ) diff --git a/torchtitan/experiments/llama4/optimizer.py b/torchtitan/experiments/llama4/optimizer.py new file mode 100644 index 000000000..d4829de88 --- /dev/null +++ b/torchtitan/experiments/llama4/optimizer.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh + +from torchtitan.components.ft import FTManager +from torchtitan.components.optimizer import build_optimizers, OptimizersContainer +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims + + +# for MoE auxiliary-loss-free load balancing +def _update_expert_bias( + model_parts: list[nn.Module], + world_mesh: dict[str, DeviceMesh], + parallel_dims: ParallelDims, +): + dp_cp_mesh = world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None + # TODO: Currently this sync is blocking (thus exposed) and happens on the + # default compute stream. Need to assess if this is OK performance-wise. 
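Illustrative numbers (not from the patch) for the auxiliary-loss-free update performed in `_update_expert_bias`: under-loaded experts get a positive bias nudge, over-loaded experts a negative one, and the delta is mean-centered so the bias does not drift:

```python
import torch

load_balance_coeff = 1e-3
tokens_per_expert = torch.tensor([10.0, 30.0, 20.0, 20.0])
delta = load_balance_coeff * torch.sign(tokens_per_expert.mean() - tokens_per_expert)
delta = delta - delta.mean()
print(delta)  # tensor([ 0.0010, -0.0010,  0.0000,  0.0000])
```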
+ for model_part in model_parts: + for transformer_block in model_part.layers.values(): + if transformer_block.moe_enabled: + moe = transformer_block.moe + if moe.load_balance_coeff is None: + return + + if dp_cp_mesh is not None: + torch.distributed.all_reduce( + moe.tokens_per_expert, group=dp_cp_mesh.get_group() + ) + + with torch.no_grad(): + expert_bias_delta = moe.load_balance_coeff * torch.sign( + moe.tokens_per_expert.mean() - moe.tokens_per_expert + ) + expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() + moe.expert_bias.add_(expert_bias_delta) + moe.tokens_per_expert.zero_() + + +def build_llama4_optimizers( + model_parts: list[nn.Module], + job_config: JobConfig, + parallel_dims: ParallelDims, + world_mesh: DeviceMesh, + ft_manager: FTManager, +) -> OptimizersContainer: + optimizers = build_optimizers( + model_parts=model_parts, + job_config=job_config, + parallel_dims=parallel_dims, + world_mesh=world_mesh, + ft_manager=ft_manager, + ) + + optimizers.register_step_pre_hook( + lambda *args, **kwargs: _update_expert_bias( + model_parts, world_mesh=world_mesh, parallel_dims=parallel_dims + ) + ) + + return optimizers diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index 41a8459ea..6dce7d5b7 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -26,7 +26,8 @@ tokenizer_path = "./tests/assets/test_tiktoken.model" # converters = ["float8"] [optimizer] -name = "AdamW" +# TODO: AdamW has numerical issues when TP is used, need to fix it +name = "Adam" lr = 4e-3 eps = 1e-15 @@ -52,6 +53,7 @@ tensor_parallel_degree = 1 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 context_parallel_degree = 1 +expert_parallel_degree = 1 [checkpoint] enable_checkpoint = false diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 90a98f9b0..df395adcb 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -90,9 +90,8 @@ def parallelize_llama( if job_config.training.compile: apply_compile(model) - if ( - parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled - ): # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.fsdp_enabled: + # apply FSDP or HSDP, potentially with Context Parallel if parallel_dims.dp_replicate_enabled: dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") else: diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index ddb6961e5..aeff047b5 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -13,6 +13,7 @@ import torch import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh from torch.distributed.pipelining.schedules import _PipelineSchedule from torchtitan.components.dataloader import BaseDataLoader @@ -23,6 +24,7 @@ from torchtitan.components.optimizer import OptimizersContainer from torchtitan.components.tokenizer import Tokenizer from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims DeviceType = int | str | torch.device @@ -71,7 +73,8 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: TokenizerBuilder: TypeAlias = Callable[..., Tokenizer] MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor] OptimizersBuilder: TypeAlias = Callable[ - [list[nn.Module], JobConfig, 
FTManager], OptimizersContainer + [list[nn.Module], JobConfig, ParallelDims, DeviceMesh, FTManager], + OptimizersContainer, ] LRSchedulersBuilder: TypeAlias = Callable[ [OptimizersContainer, JobConfig], LRSchedulersContainer diff --git a/torchtitan/train.py b/torchtitan/train.py index ca1480f2e..e6a1ffa7d 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -89,6 +89,7 @@ def __init__(self, job_config: JobConfig): cp=parallelism_config.context_parallel_degree, tp=parallelism_config.tensor_parallel_degree, pp=parallelism_config.pipeline_parallel_degree, + ep=parallelism_config.expert_parallel_degree, world_size=world_size, enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) @@ -280,7 +281,7 @@ def __init__(self, job_config: JobConfig): # build optimizer after applying parallelisms to the model self.optimizers = self.train_spec.build_optimizers_fn( - self.model_parts, job_config, self.ft_manager + self.model_parts, job_config, parallel_dims, world_mesh, self.ft_manager ) self.lr_schedulers = self.train_spec.build_lr_schedulers_fn( self.optimizers, job_config @@ -436,6 +437,7 @@ def train_step( self.job_config.training.max_norm, foreach=True, pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None, + parallel_dims=parallel_dims, ) self.checkpointer.maybe_wait_for_staging() self.optimizers.step() From 70ad37b9c9e7234efd06ec4df707da468e3c9570 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Wed, 25 Jun 2025 15:08:35 -0700 Subject: [PATCH 02/12] add tp test --- .../models/deepseek_v3/infra/parallelize.py | 118 ++++++++++++++++++ torchtitan/models/deepseek_v3/model/model.py | 29 +++-- torchtitan/models/deepseek_v3/model/moe.py | 5 +- .../train_configs/deepseek_v3_16b.toml | 2 + 4 files changed, 141 insertions(+), 13 deletions(-) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 99338663f..40b449311 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -6,9 +6,19 @@ import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims +from torchtitan.experiments.llama4.infra.expert_parallel import NoParallel +from torchtitan.experiments.llama4.infra.parallelize import apply_moe_tp from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_fsdp from torchtitan.tools.logging import logger @@ -19,6 +29,40 @@ def parallelize_deepseekv3( parallel_dims: ParallelDims, job_config: JobConfig, ): + + if parallel_dims.tp_enabled: + if job_config.parallelism.enable_async_tensor_parallel: + raise NotImplementedError( + "Currently, async TP is not supported for deepseekv3" + ) + + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + if enable_float8_tensorwise_tp: + raise NotImplementedError( + "Currently, float8 tensorwise TP is not supported for deepseekv3" + ) + + if parallel_dims.loss_parallel_enabled: + raise NotImplementedError( + "Currently, loss parallel is not supported for deepseekv3" + ) + + apply_tp( + 
model, + world_mesh["tp"], + loss_parallel=parallel_dims.loss_parallel_enabled, + enable_float8_tensorwise_tp=False, + enable_async_tp=False, + ) + + apply_moe_tp(model, world_mesh["tp"]) + if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) @@ -48,3 +92,77 @@ def parallelize_deepseekv3( logger.info("Applied FSDP to the model") return model + + +def apply_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8_tensorwise_tp: bool, + enable_async_tp: bool, +): + """Apply tensor parallelism.""" + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. Parallelize the final linear output layer + parallelize_module( + model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "norm": SequenceParallel(), + "output": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ), + }, + ) + + rowwise_parallel, colwise_parallel, prepare_module_input = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + ) + + # Apply tensor + sequence parallelism to every transformer block + # NOTE: At the cost of model code change, we can accelerate Sequence Parallel + # by folding (and unfolding) the batch dimension and the sequence dimension. + # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + for transformer_block in model.layers.values(): + layer_plan = { + "attention_norm": SequenceParallel(), + "attention": prepare_module_input( + input_layouts=(Shard(1), None), + desired_input_layouts=(Replicate(), None), + ), + "attention.wkv_a": NoParallel(), # Make ths a DTensor + "attention.wkv_b": colwise_parallel(), + "attention.wq_a": NoParallel(), + "attention.wq_b": colwise_parallel(), + "attention.wq": colwise_parallel(), # This is only used when q_lora_rank==0 + "attention.wo": rowwise_parallel(output_layouts=Shard(1)), + "ffn_norm": SequenceParallel(), + "feed_forward": prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "feed_forward.w1": colwise_parallel(), + "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)), + "feed_forward.w3": colwise_parallel(), + } + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + logger.info( + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" + "Tensor Parallelism to the model" + ) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 3eb0f2fbc..fbfe5a9d7 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import math +from re import I from typing import Tuple import torch @@ -194,7 +195,10 @@ def forward( else: q = self.wq_b(self.q_norm(self.wq_a(x))) - q = q.view(bsz, seqlen, self.n_heads, self.qk_head_dim) + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of q and kv as TP may have sharded them after + # the above linear ops. 
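+        # (e.g. with tensor_parallel_degree=2, wq / wq_b are ColwiseParallel under
+        # the TP plan added in this patch, so q carries n_heads // 2 local heads
+        # here; hard-coding n_heads would no longer match the sharded sizes.)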
+ q = q.view(bsz, seqlen, -1, self.qk_head_dim) q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) @@ -211,10 +215,11 @@ def forward( kv = self.wkv_b( self.kv_norm(kv) ) # (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim)) - kv = kv.view(bsz, seqlen, self.n_heads, self.qk_nope_head_dim + self.v_head_dim) + kv = kv.view(bsz, seqlen, -1, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + n_local_heads = k_nope.size(2) k = torch.cat( - [k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1 + [k_nope, k_pe.expand(-1, -1, n_local_heads, -1)], dim=-1 ) # (bsz, seqlen, n_heads, qk_head_dim) q = q.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) @@ -278,12 +283,13 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): super().__init__() self.attention = Attention(model_args) self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - self.moe_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - self.moe = ( - FeedForward(model_args.dim, model_args.inter_dim) - if layer_id < model_args.n_dense_layers - else MoE(model_args) - ) + self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.moe_enabled = layer_id < model_args.n_dense_layers + + if self.moe_enabled: + self.moe = MoE(model_args) + else: + self.feed_forward = FeedForward(model_args.dim, model_args.inter_dim) def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): """ @@ -297,7 +303,10 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): torch.Tensor: Output tensor with the same shape as the input. """ x = x + self.attention(self.attention_norm(x), freqs_cis) - x = x + self.moe(self.moe_norm(x)) + if self.moe_enabled: + x = x + self.moe(self.ffn_norm(x)) + else: + x = x + self.feed_forward(self.ffn_norm(x)) return x diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py index c9217c8be..eebe09aa2 100644 --- a/torchtitan/models/deepseek_v3/model/moe.py +++ b/torchtitan/models/deepseek_v3/model/moe.py @@ -116,8 +116,7 @@ def __init__( self.top_k = top_k self.use_sigmoid = use_sigmoid self.route_sclaing_factor = route_sclaing_factor - - self.weight = nn.Parameter(torch.empty((self.num_experts, self.dim))) + self.gate = nn.Linear(self.dim, self.num_experts, bias=False) def forward( self, x: torch.Tensor, expert_bias: torch.Tensor = None @@ -138,7 +137,7 @@ def forward( Number of tokens assigned to each expert with shape ``(num_experts,)``. 
""" # scores shape (bs*slen, num_experts) - scores = F.linear(x, self.weight, bias=None) + scores = self.gate(x) # By default, sigmoid or softmax is performed in float32 to avoid loss explosion if self.use_sigmoid: diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 4f08fb098..60048a56d 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -49,6 +49,8 @@ dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 2 +disable_loss_parallel = true [checkpoint] enable_checkpoint = false From 9a4747fede3d109a32f580b5534d492cf392ae3d Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Wed, 25 Jun 2025 16:47:31 -0700 Subject: [PATCH 03/12] add TP for norm --- torchtitan/models/deepseek_v3/infra/parallelize.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 40b449311..bbd53b789 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -33,7 +33,7 @@ def parallelize_deepseekv3( if parallel_dims.tp_enabled: if job_config.parallelism.enable_async_tensor_parallel: raise NotImplementedError( - "Currently, async TP is not supported for deepseekv3" + "Currently, async TP is not tested for deepseekv3" ) enable_float8_linear = "float8" in job_config.model.converters @@ -45,12 +45,12 @@ def parallelize_deepseekv3( enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise if enable_float8_tensorwise_tp: raise NotImplementedError( - "Currently, float8 tensorwise TP is not supported for deepseekv3" + "Currently, float8 tensorwise TP is not tested for deepseekv3" ) if parallel_dims.loss_parallel_enabled: raise NotImplementedError( - "Currently, loss parallel is not supported for deepseekv3" + "Currently, loss parallel is not tested for deepseekv3" ) apply_tp( @@ -140,10 +140,12 @@ def apply_tp( input_layouts=(Shard(1), None), desired_input_layouts=(Replicate(), None), ), - "attention.wkv_a": NoParallel(), # Make ths a DTensor + "attention.wkv_a": NoParallel(), "attention.wkv_b": colwise_parallel(), + "attention.kv_norm": NoParallel(), "attention.wq_a": NoParallel(), "attention.wq_b": colwise_parallel(), + "attention.q_norm": NoParallel(), "attention.wq": colwise_parallel(), # This is only used when q_lora_rank==0 "attention.wo": rowwise_parallel(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), From 45649de1860a58936c7d4a1ff3b3d743d6db4574 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Wed, 25 Jun 2025 17:36:44 -0700 Subject: [PATCH 04/12] add TP v1 --- torchtitan/models/deepseek_v3/infra/parallelize.py | 10 ++++------ torchtitan/models/deepseek_v3/model/model.py | 1 - .../models/deepseek_v3/train_configs/debug_model.toml | 2 ++ .../deepseek_v3/train_configs/deepseek_v3_16b.toml | 4 ++-- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index bbd53b789..6e91b76f7 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -32,8 +32,10 @@ def 
parallelize_deepseekv3( if parallel_dims.tp_enabled: if job_config.parallelism.enable_async_tensor_parallel: + # TODO(jianiw): This branch needs to be tested and enabled raise NotImplementedError( - "Currently, async TP is not tested for deepseekv3" + "Currently, async TP is not tested for deepseekv3. \ + torch.compile is not supported yet, which is required for async TP." ) enable_float8_linear = "float8" in job_config.model.converters @@ -44,15 +46,11 @@ def parallelize_deepseekv3( enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise if enable_float8_tensorwise_tp: + # TODO(jianiw): This branch needs to be tested and enabled raise NotImplementedError( "Currently, float8 tensorwise TP is not tested for deepseekv3" ) - if parallel_dims.loss_parallel_enabled: - raise NotImplementedError( - "Currently, loss parallel is not tested for deepseekv3" - ) - apply_tp( model, world_mesh["tp"], diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index fbfe5a9d7..2db4163a5 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import math -from re import I from typing import Tuple import torch diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index eddca8849..66db16f9b 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -50,6 +50,8 @@ dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false [checkpoint] enable_checkpoint = false diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 60048a56d..03f8232af 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -49,8 +49,8 @@ dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 2 -disable_loss_parallel = true +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false [checkpoint] enable_checkpoint = false From 850ddadb08f6852113a48278811fcb2408b489e9 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Thu, 26 Jun 2025 15:34:34 -0700 Subject: [PATCH 05/12] fix bug --- torchtitan/models/deepseek_v3/model/model.py | 91 +++++++++++++++---- torchtitan/models/deepseek_v3/model/moe.py | 9 +- .../train_configs/debug_model.toml | 2 +- .../train_configs/deepseek_v3_16b.toml | 4 +- 4 files changed, 86 insertions(+), 20 deletions(-) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 2db4163a5..0fde3b965 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -152,17 +152,23 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): self.v_head_dim = model_args.v_head_dim if self.q_lora_rank == 0: - self.wq = nn.Linear(self.dim, self.n_heads * self.qk_head_dim) + self.wq = nn.Linear(self.dim, self.n_heads * 
self.qk_head_dim, bias=False) else: - self.wq_a = nn.Linear(self.dim, self.q_lora_rank) + self.wq_a = nn.Linear(self.dim, self.q_lora_rank, bias=False) self.q_norm = nn.RMSNorm(self.q_lora_rank, eps=model_args.norm_eps) - self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim) - self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim) + self.wq_b = nn.Linear( + self.q_lora_rank, self.n_heads * self.qk_head_dim, bias=False + ) + self.wkv_a = nn.Linear( + self.dim, self.kv_lora_rank + self.qk_rope_head_dim, bias=False + ) self.kv_norm = nn.RMSNorm(self.kv_lora_rank, eps=model_args.norm_eps) self.wkv_b = nn.Linear( - self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim) + self.kv_lora_rank, + self.n_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, ) - self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim) + self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim, bias=False) self.softmax_scale = self.qk_head_dim**-0.5 if model_args.max_seq_len > model_args.original_seq_len: @@ -192,8 +198,8 @@ def forward( if self.q_lora_rank == 0: q = self.wq(x) # (bsz, seqlen, n_heads * qk_head_dim) else: - q = self.wq_b(self.q_norm(self.wq_a(x))) - + q = self.wq_a(x) + q = self.wq_b(self.q_norm(q)) # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual # local heads from sizes of q and kv as TP may have sharded them after # the above linear ops. @@ -235,6 +241,24 @@ def forward( output = output.view(bsz, seqlen, -1) # (bsz, seqlen, n_heads * v_head_dim) return self.wo(output) # (bsz, seqlen, dim) + def init_weights(self, init_std: float): + linear_list = [ + self.wkv_a, + self.wkv_b, + ] + if self.q_lora_rank > 0: + linear_list.extend([self.wq_a, self.wq_b]) + else: + linear_list.append(self.wq) + + for linear in linear_list: + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + self.kv_norm.reset_parameters() + if self.q_lora_rank > 0: + self.q_norm.reset_parameters() + class FeedForward(nn.Module): """ @@ -266,7 +290,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) - def init_weights(self, init_std: float): + def init_weights(self, init_std: float = 0.02): nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) for linear in (self.w2, self.w3): nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) @@ -283,13 +307,16 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): self.attention = Attention(model_args) self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - self.moe_enabled = layer_id < model_args.n_dense_layers + self.moe_enabled = layer_id >= model_args.n_dense_layers if self.moe_enabled: self.moe = MoE(model_args) else: self.feed_forward = FeedForward(model_args.dim, model_args.inter_dim) + # TODO: Need to revisit the weight initialization for the TransformerBlock + self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): """ Forward pass for the Transformer block. 
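A minimal standalone sketch (illustration only, not part of the commit) of the dense-vs-MoE layer split and the depth-scaled init std set up above, assuming the small config used in this series (n_layers=3, n_dense_layers=1):

# Sketch only: which layers run MoE and what init std each TransformerBlock uses.
n_layers, n_dense_layers = 3, 1
for layer_id in range(n_layers):
    moe_enabled = layer_id >= n_dense_layers            # layer 0: dense FFN; layers 1-2: MoE
    weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5
    print(layer_id, moe_enabled, round(weight_init_std, 5))
# prints: 0 False 0.01414 / 1 True 0.01 / 2 True 0.00816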
@@ -308,6 +335,15 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): x = x + self.feed_forward(self.ffn_norm(x)) return x + def init_weights(self, buffer_device: torch.device): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + if self.moe_enabled: + self.moe.init_weights(self.weight_init_std, buffer_device) + else: + self.feed_forward.init_weights(self.weight_init_std) + class DeepSeekV3Model(nn.Module, ModelProtocol): """ @@ -319,7 +355,7 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): self.max_seq_len = model_args.max_seq_len self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) self.register_buffer( - "freqs_cis", precompute_freqs_cis(model_args), persistent=False + "freqs_cis", precompute_freqs_cis(model_args), persistent=True ) self.layers = torch.nn.ModuleDict() @@ -328,10 +364,36 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): self.norm = nn.RMSNorm(model_args.dim) self.output = nn.Linear( - model_args.dim, model_args.vocab_size, dtype=torch.get_default_dtype() + model_args.dim, + model_args.vocab_size, + dtype=torch.get_default_dtype(), + bias=False, ) + self.model_args = model_args self.init_weights() + def init_weights(self, buffer_device: torch.device | None = None) -> None: + buffer_device = buffer_device or self.freqs_cis.device + with torch.device(buffer_device): + self.freqs_cis = precompute_freqs_cis(self.model_args) + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights(buffer_device=buffer_device) + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + def forward(self, tokens: torch.Tensor): """ Forward pass for the Transformer model. @@ -347,8 +409,5 @@ def forward(self, tokens: torch.Tensor): for layer in self.layers.values(): h = layer(h, self.freqs_cis) h = self.norm(h) - output = self.output(h) # (batch_size, seq_len, dim) + output = self.output(h) return output - - def init_weights(self, buffer_device: torch.device | None = None) -> None: - pass diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py index eebe09aa2..8ea1fc875 100644 --- a/torchtitan/models/deepseek_v3/model/moe.py +++ b/torchtitan/models/deepseek_v3/model/moe.py @@ -89,6 +89,11 @@ def forward( return out + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std) + class TokenChoiceTopKRouter(nn.Module): """This class implements token-choice routing. 
In token-choice top-K routing, each token is @@ -173,6 +178,9 @@ def forward( return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) + class MoE(nn.Module): def __init__(self, model_args: DeepSeekV3ModelArgs): @@ -231,7 +239,6 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): if self.load_balance_coeff is not None and self.load_balance_coeff > 0: self.register_full_backward_hook(self._update_expert_bias) - # TODO: double check the bias update logic. It aligns with the paper. def _update_expert_bias(self, *_): expert_bias_delta = self.load_balance_coeff * torch.sign( self.tokens_per_expert.mean() - self.tokens_per_expert diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index 66db16f9b..0f9d4e74f 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -50,7 +50,7 @@ dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 1 +tensor_parallel_degree = 2 enable_async_tensor_parallel = false [checkpoint] diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 03f8232af..fc169fa32 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -38,7 +38,7 @@ decay_type = "linear" lr_min = 0.0 [training] -local_batch_size = 32 +local_batch_size = 16 seq_len = 2048 max_norm = 1.0 # grad norm clipping steps = 10 @@ -49,7 +49,7 @@ dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 1 +tensor_parallel_degree = 2 enable_async_tensor_parallel = false [checkpoint] From 9ae2eab4549d5d39bbf2d4ac63ab00fac68581f6 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Fri, 27 Jun 2025 15:40:42 -0700 Subject: [PATCH 06/12] test --- torchtitan/models/deepseek_v3/__init__.py | 10 +-- torchtitan/models/deepseek_v3/model/args.py | 4 +- torchtitan/models/deepseek_v3/model/model.py | 42 ++---------- torchtitan/models/deepseek_v3/model/moe.py | 66 ++++++++++++++++++- .../train_configs/deepseek_v3_16b.toml | 2 +- torchtitan/train.py | 2 +- 6 files changed, 77 insertions(+), 49 deletions(-) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 7eb16a1f3..da647e396 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -32,12 +32,12 @@ dim=256, inter_dim=10944, moe_inter_dim=1408, - n_layers=3, - n_dense_layers=1, + n_layers=1, + n_dense_layers=0, # no FFN layer, all MoE layers n_heads=16, - n_routed_experts=8, - n_shared_experts=2, - n_activated_experts=3, + n_routed_experts=2, # hang only happens when n_routed_experts > n_activated_experts + n_shared_experts=1, + n_activated_experts=1, route_scale=1.0, q_lora_rank=0, kv_lora_rank=512, diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 09e882764..769268999 100644 --- 
a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -75,8 +75,8 @@ class DeepSeekV3ModelArgs(BaseModelArgs): n_limited_groups: int = 1 score_func: Literal["softmax", "sigmoid"] = "softmax" route_scale: float = 1.0 - use_grouped_mm: bool = False - load_balance_coeff: float | None = 1e-3 + use_grouped_mm: bool = True + load_balance_coeff: float = 1e-3 # Multi-Head Latent Attention (MLA) q_lora_rank: int = 0 kv_lora_rank: int = 512 diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 0fde3b965..59573b98b 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -14,7 +14,7 @@ from torchtitan.protocols.train_spec import ModelProtocol from .args import DeepSeekV3ModelArgs -from .moe import MoE +from .moe import FeedForward, MoE # Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 @@ -260,42 +260,6 @@ def init_weights(self, init_std: float): self.q_norm.reset_parameters() -class FeedForward(nn.Module): - """ - FeedForward module - - Args: - dim (int): Input dimension. - hidden_dim (int): Hidden dimension of the feedforward layer. - multiple_of (int): Value to ensure hidden dimension is a multiple of this value. - ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. - - Attributes: - w1 (Linear): Linear transformation for the first layer. - w2 (Linear): Linear transformation for the second layer. - w3 (Linear): Linear transformation for the third layer. - - """ - - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - def init_weights(self, init_std: float = 0.02): - nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) - for linear in (self.w2, self.w3): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) - - class TransformerBlock(nn.Module): """ Transformer block with attention and feed-forward layers. 
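Since this patch flips `use_grouped_mm` to True by default, a minimal sketch (illustration only; the counts are made up) of what the `offs` argument passed to `torch._grouped_mm` in GroupedExperts encodes:

import torch

# Sketch only: per-expert token counts -> cumulative offsets that delimit the
# contiguous row group each expert consumes in the grouped matmul.
num_tokens_per_expert = torch.tensor([3, 0, 5, 8])   # hypothetical counts
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
# offsets == [3, 3, 8, 16]: expert 0 sees rows [0:3), expert 1 none ([3:3)),
# expert 2 rows [3:8), expert 3 rows [8:16) of the permuted routed_input.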
@@ -316,6 +280,7 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): # TODO: Need to revisit the weight initialization for the TransformerBlock self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + self.layer_id = layer_id def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): """ @@ -330,8 +295,10 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): """ x = x + self.attention(self.attention_norm(x), freqs_cis) if self.moe_enabled: + print(f"In TransformerBlock {self.layer_id}: MoE is enabled") x = x + self.moe(self.ffn_norm(x)) else: + print(f"In TransformerBlock {self.layer_id}: FFN is enabled") x = x + self.feed_forward(self.ffn_norm(x)) return x @@ -360,6 +327,7 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): self.layers = torch.nn.ModuleDict() for layer_id in range(model_args.n_layers): + print(f"Create layer: {layer_id}") self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) self.norm = nn.RMSNorm(model_args.dim) diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py index 8ea1fc875..846903cb8 100644 --- a/torchtitan/models/deepseek_v3/model/moe.py +++ b/torchtitan/models/deepseek_v3/model/moe.py @@ -11,6 +11,42 @@ from .args import DeepSeekV3ModelArgs +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. 
+ + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float = 0.02): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + # Reference: torchtitan/experiments/llama4/model/ class GroupedExperts(nn.Module): def __init__( @@ -212,11 +248,17 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): GroupedExperts( dim=dim, hidden_dim=hidden_dim * model_args.n_shared_experts, - num_experts=1, + num_experts=1, # Here needs to be 1 to make it equivalent to the MLP use_grouped_mm=self.use_grouped_mm, ) if model_args.n_shared_experts > 0 else None + # FeedForward( + # dim=dim, + # hidden_dim=hidden_dim * model_args.n_shared_experts, + # ) + # if model_args.n_shared_experts > 0 + # else None ) # auxiliary-loss-free load balancing @@ -266,6 +308,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: num_local_tokens_per_expert, ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) + print( + "In MoE, top_scores shape: ", + top_scores.shape, + "token_indices: ", + token_indices.shape, + "num_local_tokens: ", + num_local_tokens_per_expert.shape, + ) + # will be used to update the expert bias for load balancing self.tokens_per_expert += num_local_tokens_per_expert @@ -299,6 +350,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: num_local_tokens_per_expert, self.experts.num_experts, 1, + token_indices[0] + self.experts.num_experts * ALIGN_SIZE_M, ALIGN_SIZE_M, ) token_indices = torch.vstack( @@ -311,8 +363,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # NOTE: this would incur a synchronization between device and host num_local_tokens_per_expert = num_local_tokens_per_expert.tolist() + print("Num local tokens per expert: ", num_local_tokens_per_expert) # shape (bs*slen*top_k, dim) - routed_output = self.experts(routed_input, num_local_tokens_per_expert) + routed_output = self.experts( + routed_input, num_local_tokens_per_expert + ) # torch.Size([16384(bsz), 256]) + print("Routed output shape: ", routed_output.shape) routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to( x.dtype ) @@ -321,10 +377,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.shared_expert is not None: out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape( bs * slen, dim - ) + ) # torch.Size([16384, 256]) None else: out = torch.zeros_like(x.reshape(bs * slen, dim)) + print( + "Out shape: ", out.shape, out.grad.shape if out.grad is not None else None + ) + out = out.scatter_add(dim=0, index=token_indices, src=routed_output) out = out.reshape(bs, slen, dim) return out diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index fc169fa32..19267f036 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -49,7 +49,7 @@ dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 2 
+tensor_parallel_degree = 1 enable_async_tensor_parallel = false [checkpoint] diff --git a/torchtitan/train.py b/torchtitan/train.py index e6a1ffa7d..08f55e8dd 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -11,10 +11,10 @@ from typing import Any, Generator, Iterable, Optional import torch -from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module +from torch.distributed.elastic.multiprocessing.errors import record from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderStopIteration from torchtitan.components.loss import rescale_accumulated_loss From 9306d80ccd0e944035aebd40c06c68a82e84c8a6 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Sat, 28 Jun 2025 10:27:32 -0700 Subject: [PATCH 07/12] tp on groupped_mm finished --- torchtitan/experiments/kernels/moe/indices.py | 1 + torchtitan/models/deepseek_v3/model/model.py | 3 - torchtitan/models/deepseek_v3/model/moe.py | 58 +++++++++---------- .../train_configs/debug_model.toml | 1 + .../train_configs/deepseek_v3_16b.toml | 4 +- 5 files changed, 33 insertions(+), 34 deletions(-) diff --git a/torchtitan/experiments/kernels/moe/indices.py b/torchtitan/experiments/kernels/moe/indices.py index 30f7d98c2..cf2e006e7 100644 --- a/torchtitan/experiments/kernels/moe/indices.py +++ b/torchtitan/experiments/kernels/moe/indices.py @@ -77,6 +77,7 @@ def fill_indices_wrapper( max_blocks: int = 1024, # cap on total number of blocks to launch ): # preallocate output + print("max_len: ", max_len, "block_size: ", block_size, "max_blocks: ", max_blocks) permuted_indices = torch.full( (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device ) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 59573b98b..10454edf2 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -295,10 +295,8 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): """ x = x + self.attention(self.attention_norm(x), freqs_cis) if self.moe_enabled: - print(f"In TransformerBlock {self.layer_id}: MoE is enabled") x = x + self.moe(self.ffn_norm(x)) else: - print(f"In TransformerBlock {self.layer_id}: FFN is enabled") x = x + self.feed_forward(self.ffn_norm(x)) return x @@ -327,7 +325,6 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): self.layers = torch.nn.ModuleDict() for layer_id in range(model_args.n_layers): - print(f"Create layer: {layer_id}") self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) self.norm = nn.RMSNorm(model_args.dim) diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py index 846903cb8..fa58900cf 100644 --- a/torchtitan/models/deepseek_v3/model/moe.py +++ b/torchtitan/models/deepseek_v3/model/moe.py @@ -211,7 +211,7 @@ def forward( top_scores = ( top_scores * self.route_sclaing_factor ) # must multiply the scaling factor - + print("In TokenChoiceTopKRouter, top_scores shape: ", top_scores) return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert def init_weights(self, init_std: float): @@ -253,12 +253,6 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): ) if model_args.n_shared_experts > 0 else None - # FeedForward( - # dim=dim, - # hidden_dim=hidden_dim * model_args.n_shared_experts, - # ) - # if model_args.n_shared_experts > 0 - # else None ) # 
auxiliary-loss-free load balancing @@ -298,6 +292,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. """ + print("In MoE input, x shape: ", x) bs, slen, dim = x.shape # top_scores and selected_indices shape (bs*slen*top_k,) @@ -308,14 +303,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: num_local_tokens_per_expert, ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) - print( - "In MoE, top_scores shape: ", - top_scores.shape, - "token_indices: ", - token_indices.shape, - "num_local_tokens: ", - num_local_tokens_per_expert.shape, - ) + # print( + # "In MoE, top_scores shape: ", + # top_scores.shape, + # "token_indices: ", + # token_indices.shape, + # "num_local_tokens: ", + # num_local_tokens_per_expert.shape, + # ) # will be used to update the expert bias for load balancing self.tokens_per_expert += num_local_tokens_per_expert @@ -329,6 +324,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dim=0, index=token_indices, ) + print("Routed input: ", routed_input) + + # TODO: remove this line, this is a temporary test + routed_input = (routed_input.to(torch.float32) * top_scores.reshape(-1, 1)).to( + x.dtype + ) if self.use_grouped_mm: # NOTE: In order to use torch._grouped_mm, we need to make sure @@ -350,28 +351,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: num_local_tokens_per_expert, self.experts.num_experts, 1, - token_indices[0] + self.experts.num_experts * ALIGN_SIZE_M, + token_indices.shape[0] + self.experts.num_experts * ALIGN_SIZE_M, ALIGN_SIZE_M, ) - token_indices = torch.vstack( - (token_indices, token_indices.new_zeros((dim))) - ) - token_indices = token_indices[permuted_indices, :] + routed_input = torch.vstack((routed_input, routed_input.new_zeros((dim)))) + input_shape = routed_input.shape routed_input = routed_input[permuted_indices, :] else: # NOTE: this would incur a synchronization between device and host num_local_tokens_per_expert = num_local_tokens_per_expert.tolist() + input_shape, permuted_indices = None, None - print("Num local tokens per expert: ", num_local_tokens_per_expert) # shape (bs*slen*top_k, dim) routed_output = self.experts( routed_input, num_local_tokens_per_expert ) # torch.Size([16384(bsz), 256]) - print("Routed output shape: ", routed_output.shape) - routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to( - x.dtype - ) + + routed_output_unpermuted = routed_output.new_empty(input_shape) + routed_output_unpermuted[permuted_indices, :] = routed_output + routed_output = routed_output_unpermuted[:-1] + + # TODO: Use this line instead if routed_input*top_scores, need to pad top_scores to be multiple of 16 + # routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to( + # x.dtype + # ) # shared expert if self.shared_expert is not None: @@ -381,10 +385,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: else: out = torch.zeros_like(x.reshape(bs * slen, dim)) - print( - "Out shape: ", out.shape, out.grad.shape if out.grad is not None else None - ) - out = out.scatter_add(dim=0, index=token_indices, src=routed_output) out = out.reshape(bs, slen, dim) return out diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index 0f9d4e74f..80566b04d 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -45,6 +45,7 @@ max_norm = 1.0 # 
grad norm clipping steps = 10 compile = false dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) +seed = 0 [parallelism] data_parallel_replicate_degree = 1 diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 19267f036..64782a887 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -41,7 +41,7 @@ lr_min = 0.0 local_batch_size = 16 seq_len = 2048 max_norm = 1.0 # grad norm clipping -steps = 10 +steps = 2 compile = false dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) @@ -49,7 +49,7 @@ dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 1 +tensor_parallel_degree = 2 enable_async_tensor_parallel = false [checkpoint] From 6ceff837486267b0bd88a90a9202caff2080bf1d Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 1 Jul 2025 13:08:30 -0700 Subject: [PATCH 08/12] TP gemm works --- torchtitan/experiments/kernels/moe/indices.py | 1 - torchtitan/models/deepseek_v3/__init__.py | 10 ++--- torchtitan/models/deepseek_v3/model/args.py | 2 +- torchtitan/models/deepseek_v3/model/moe.py | 45 +++++++------------ .../train_configs/deepseek_v3_16b.toml | 7 +-- 5 files changed, 26 insertions(+), 39 deletions(-) diff --git a/torchtitan/experiments/kernels/moe/indices.py b/torchtitan/experiments/kernels/moe/indices.py index cf2e006e7..30f7d98c2 100644 --- a/torchtitan/experiments/kernels/moe/indices.py +++ b/torchtitan/experiments/kernels/moe/indices.py @@ -77,7 +77,6 @@ def fill_indices_wrapper( max_blocks: int = 1024, # cap on total number of blocks to launch ): # preallocate output - print("max_len: ", max_len, "block_size: ", block_size, "max_blocks: ", max_blocks) permuted_indices = torch.full( (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device ) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index da647e396..7eb16a1f3 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -32,12 +32,12 @@ dim=256, inter_dim=10944, moe_inter_dim=1408, - n_layers=1, - n_dense_layers=0, # no FFN layer, all MoE layers + n_layers=3, + n_dense_layers=1, n_heads=16, - n_routed_experts=2, # hang only happens when n_routed_experts > n_activated_experts - n_shared_experts=1, - n_activated_experts=1, + n_routed_experts=8, + n_shared_experts=2, + n_activated_experts=3, route_scale=1.0, q_lora_rank=0, kv_lora_rank=512, diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 769268999..51288000f 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -75,7 +75,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs): n_limited_groups: int = 1 score_func: Literal["softmax", "sigmoid"] = "softmax" route_scale: float = 1.0 - use_grouped_mm: bool = True + use_grouped_mm: bool = False load_balance_coeff: float = 1e-3 # Multi-Head Latent Attention (MLA) q_lora_rank: int = 0 diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py index fa58900cf..16bad16ce 100644 --- a/torchtitan/models/deepseek_v3/model/moe.py +++ b/torchtitan/models/deepseek_v3/model/moe.py @@ -201,17 +201,20 @@ def forward( 
min=0, max=self.num_experts, ) + + # Reorder the token indices to match the order of the experts # token_indices_experts_sorted shape (bs*slen*top_k,) token_indices_experts_sorted = torch.argsort( selected_experts_indices.view(-1), stable=True ) + + # reorder the scores to match the order of the token indices top_scores = top_scores.view(-1)[token_indices_experts_sorted] token_indices_experts_sorted = token_indices_experts_sorted // self.top_k top_scores = ( top_scores * self.route_sclaing_factor ) # must multiply the scaling factor - print("In TokenChoiceTopKRouter, top_scores shape: ", top_scores) return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert def init_weights(self, init_std: float): @@ -292,7 +295,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. """ - print("In MoE input, x shape: ", x) bs, slen, dim = x.shape # top_scores and selected_indices shape (bs*slen*top_k,) @@ -303,15 +305,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: num_local_tokens_per_expert, ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) - # print( - # "In MoE, top_scores shape: ", - # top_scores.shape, - # "token_indices: ", - # token_indices.shape, - # "num_local_tokens: ", - # num_local_tokens_per_expert.shape, - # ) - # will be used to update the expert bias for load balancing self.tokens_per_expert += num_local_tokens_per_expert @@ -324,12 +317,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dim=0, index=token_indices, ) - print("Routed input: ", routed_input) - - # TODO: remove this line, this is a temporary test - routed_input = (routed_input.to(torch.float32) * top_scores.reshape(-1, 1)).to( - x.dtype - ) if self.use_grouped_mm: # NOTE: In order to use torch._grouped_mm, we need to make sure @@ -361,30 +348,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: else: # NOTE: this would incur a synchronization between device and host num_local_tokens_per_expert = num_local_tokens_per_expert.tolist() - input_shape, permuted_indices = None, None + permuted_indices, input_shape = None, None # shape (bs*slen*top_k, dim) - routed_output = self.experts( - routed_input, num_local_tokens_per_expert - ) # torch.Size([16384(bsz), 256]) + routed_output = self.experts(routed_input, num_local_tokens_per_expert) - routed_output_unpermuted = routed_output.new_empty(input_shape) - routed_output_unpermuted[permuted_indices, :] = routed_output - routed_output = routed_output_unpermuted[:-1] + if self.use_grouped_mm: + # NOTE: Reverese the permutation to get the original order as inputs + routed_output_unpermuted = routed_output.new_empty(input_shape) + routed_output_unpermuted[permuted_indices, :] = routed_output + routed_output = routed_output_unpermuted[:-1] # remove padding - # TODO: Use this line instead if routed_input*top_scores, need to pad top_scores to be multiple of 16 - # routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to( - # x.dtype - # ) + routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to( + x.dtype + ) # shared expert if self.shared_expert is not None: out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape( bs * slen, dim - ) # torch.Size([16384, 256]) None + ) else: out = torch.zeros_like(x.reshape(bs * slen, dim)) + # Accumulate multiple expert results becase each token can be routed to multiple experts out = out.scatter_add(dim=0, index=token_indices, src=routed_output) out = out.reshape(bs, slen, dim) 
return out diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 64782a887..8e841628e 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -15,7 +15,7 @@ save_memory_snapshot_folder = "memory_snapshot" [metrics] log_freq = 1 disable_color_printing = false -enable_tensorboard = false +enable_tensorboard = true save_tb_folder = "tb" enable_wandb = false @@ -41,15 +41,16 @@ lr_min = 0.0 local_batch_size = 16 seq_len = 2048 max_norm = 1.0 # grad norm clipping -steps = 2 +steps = 10 compile = false dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) +seed = 0 [parallelism] data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 2 +tensor_parallel_degree = 1 enable_async_tensor_parallel = false [checkpoint] From 9785b50211dcaf2938910980904b05e1b75683d4 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 1 Jul 2025 14:57:34 -0700 Subject: [PATCH 09/12] rebase onto #1324 --- torchtitan/models/deepseek_v3/__init__.py | 4 +- .../models/deepseek_v3/infra/parallelize.py | 13 +- torchtitan/models/deepseek_v3/model/args.py | 2 +- torchtitan/models/deepseek_v3/model/model.py | 6 +- torchtitan/models/deepseek_v3/model/moe.py | 250 ++++++++---------- .../train_configs/debug_model.toml | 5 +- .../train_configs/deepseek_v3_16b.toml | 3 +- torchtitan/train.py | 2 +- 8 files changed, 139 insertions(+), 146 deletions(-) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 7eb16a1f3..918215b73 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -8,9 +8,9 @@ from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers -from torchtitan.components.optimizer import build_optimizers from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer +from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers from torchtitan.protocols.train_spec import register_train_spec, TrainSpec @@ -117,7 +117,7 @@ config=deepseekv3_configs, parallelize_fn=parallelize_deepseekv3, pipelining_fn=None, - build_optimizers_fn=build_optimizers, + build_optimizers_fn=build_llama4_optimizers, # use optimizer hooks to update expert weights build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_hf_dataloader, build_tokenizer_fn=build_tiktoken_tokenizer, diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 6e91b76f7..d387f1a02 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -18,7 +18,7 @@ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.experiments.llama4.infra.expert_parallel import NoParallel -from torchtitan.experiments.llama4.infra.parallelize import apply_moe_tp +from torchtitan.experiments.llama4.infra.parallelize import apply_moe_ep_tp from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_fsdp from torchtitan.tools.logging import logger @@ -59,7 +59,16 @@ def parallelize_deepseekv3( enable_async_tp=False, ) - 
apply_moe_tp(model, world_mesh["tp"]) + apply_moe_ep_tp( + model, + tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, + ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, + ep_tp_mesh=( + world_mesh["ep", "tp"] + if parallel_dims.tp_enabled and parallel_dims.ep_enabled + else None + ), + ) if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 51288000f..769268999 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -75,7 +75,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs): n_limited_groups: int = 1 score_func: Literal["softmax", "sigmoid"] = "softmax" route_scale: float = 1.0 - use_grouped_mm: bool = False + use_grouped_mm: bool = True load_balance_coeff: float = 1e-3 # Multi-Head Latent Attention (MLA) q_lora_rank: int = 0 diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 10454edf2..684ce3c5f 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -8,7 +8,6 @@ from typing import Tuple import torch -import torch.nn.functional as F from torch import nn from torchtitan.models.attention import build_attention from torchtitan.protocols.train_spec import ModelProtocol @@ -369,10 +368,15 @@ def forward(self, tokens: torch.Tensor): Returns: torch.Tensor: Logits tensor of shape (batch_size, vocab_size). """ + print("Input tokens:", tokens) h = self.tok_embeddings(tokens) + print("After token embedding:", h) for layer in self.layers.values(): h = layer(h, self.freqs_cis) + print(f"After layer {layer}: ", h) h = self.norm(h) + print("After normalization:", h) output = self.output(h) + print("Output logits:", output) return output diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py index 16bad16ce..1e8e99cbd 100644 --- a/torchtitan/models/deepseek_v3/model/moe.py +++ b/torchtitan/models/deepseek_v3/model/moe.py @@ -7,6 +7,7 @@ import torch import torch.nn.functional as F from torch import nn +from torchtitan.experiments.llama4.infra.expert_parallel import expert_parallel from .args import DeepSeekV3ModelArgs @@ -47,7 +48,6 @@ def init_weights(self, init_std: float = 0.02): nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) -# Reference: torchtitan/experiments/llama4/model/ class GroupedExperts(nn.Module): def __init__( self, @@ -66,49 +66,73 @@ def __init__( def forward( self, x: torch.Tensor, - num_local_tokens_per_expert: torch.Tensor | list[int] | None = None, + num_tokens_per_expert: torch.Tensor | None = None, ) -> torch.Tensor: - # TODO: keeping this for loop implementation for comparison - # and readability, will remove later - if not self.use_grouped_mm: - if num_local_tokens_per_expert is not None: - # a tuple of tensors indexed by experts - # each with shape (tokens_per_expert(varying), dim) - x = torch.split( - x, - split_size_or_sections=num_local_tokens_per_expert, - dim=0, - ) - out_experts_splits = [] - for expert_idx, x_expert in enumerate(x): - w1, w2, w3 = ( - self.w1[expert_idx], - self.w2[expert_idx], - self.w3[expert_idx], - ) - h = F.silu(torch.matmul(x_expert, w1)) - h = h * torch.matmul(x_expert, w3) - h = torch.matmul(h, w2) - # h shape (tokens_per_expert(varying), dim) - out_experts_splits.append(h) - out = torch.cat(out_experts_splits, dim=0) - else: - # x shape 
(num_experts, tokens_per_expert, dim) - h = F.silu(torch.bmm(x, self.w1)) - h = h * torch.bmm(x, self.w3) - # out shape (num_experts, tokens_per_expert, dim) - out = torch.bmm(h, self.w2) - - return out - - # grouped mm implementation - if num_local_tokens_per_expert is not None: - # https://github.com/pytorch/pytorch/pull/150374 - # NOTE: torch._gouped_mm requires bf16 dtypes - # and shapes to be multiple of 8 - offsets = torch.cumsum( - num_local_tokens_per_expert, dim=0, dtype=torch.int32 + if self.use_grouped_mm: + return GroupedExperts._run_experts_grouped_mm( + self.w1, self.w2, self.w3, x, num_tokens_per_expert + ) + else: + return GroupedExperts._run_experts_for_loop( + self.w1, self.w2, self.w3, x, num_tokens_per_expert ) + + # TODO: keeping this for-loop implementation for comparison + # and readability, may remove later + @expert_parallel + @staticmethod + def _run_experts_for_loop( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_tokens_per_expert is not None: + # NOTE: this would incur a synchronization between device and host + num_tokens_per_expert = num_tokens_per_expert.tolist() + + # side-effect code due to the usage of generate_permute_indices + num_padding = x.shape[0] - sum(num_tokens_per_expert) + + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x[: sum(num_tokens_per_expert)], + split_size_or_sections=num_tokens_per_expert, + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + h = F.silu(torch.matmul(x_expert, w1[expert_idx])) + h = h * torch.matmul(x_expert, w3[expert_idx]) + h = torch.matmul(h, w2[expert_idx]) + # h shape (tokens_per_expert(varying), dim) + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + # side-effect code due to the usage of generate_permute_indices + out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) + else: + # x shape (num_experts, tokens_per_expert, dim) + h = F.silu(torch.bmm(x, w1)) + h = h * torch.bmm(x, w3) + # out shape (num_experts, tokens_per_expert, dim) + out = torch.bmm(h, w2) + + return out + + @expert_parallel + @staticmethod + def _run_experts_grouped_mm( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_tokens_per_expert is not None: + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) # grouped mm between a 2D tensor and a 3D tensor assert x.dim() == 2 else: @@ -117,11 +141,12 @@ def forward( assert x.dim() == 3 assert ( - x.dtype == self.w1.dtype == self.w2.dtype == self.w3.dtype == torch.bfloat16 + x.dtype == w1.dtype == w2.dtype == w3.dtype == torch.bfloat16 ), "torch._grouped_mm only supports bf16 dtypes" - h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets)) - h = h * torch._grouped_mm(x, self.w3, offs=offsets) - out = torch._grouped_mm(h, self.w2, offs=offsets) + + h = F.silu(torch._grouped_mm(x, w1, offs=offsets)) + h = h * torch._grouped_mm(x, w3, offs=offsets) + out = torch._grouped_mm(h, w2, offs=offsets) return out @@ -160,7 +185,7 @@ def __init__( self.gate = nn.Linear(self.dim, self.num_experts, bias=False) def forward( - self, x: torch.Tensor, expert_bias: torch.Tensor = None + self, x: torch.Tensor, expert_bias: torch.Tensor | None = None ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ TODO: We haven't implement 
the group-based routing (node limit routing), @@ -174,7 +199,7 @@ def forward( Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``. token_indices (torch.Tensor): Token indices for routed_input with shape ``(bs*slen*top_k,)``. - num_local_tokens_per_expert (torch.Tensor): + num_tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert with shape ``(num_experts,)``. """ # scores shape (bs*slen, num_experts) @@ -189,13 +214,18 @@ def forward( # top scores shape (bs*slen, top_k) # NOTE: The expert_bias is only used for routing. The gating value # top_scores is still derived from the original scores. - _, selected_experts_indices = torch.topk( - scores + expert_bias, k=self.top_k, dim=1 - ) - top_scores = scores.gather(dim=1, index=selected_experts_indices) + if expert_bias is not None: + _, selected_experts_indices = torch.topk( + scores + expert_bias, k=self.top_k, dim=1 + ) + top_scores = scores.gather(dim=1, index=selected_experts_indices) + else: + top_scores, selected_experts_indices = torch.topk( + scores, k=self.top_k, dim=1 + ) # group tokens together by expert indices from 0 to num_experts and pass that to experts forward - num_local_tokens_per_expert = torch.histc( + num_tokens_per_expert = torch.histc( selected_experts_indices.view(-1), bins=self.num_experts, min=0, @@ -215,7 +245,7 @@ def forward( top_scores = ( top_scores * self.route_sclaing_factor ) # must multiply the scaling factor - return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert + return top_scores, token_indices_experts_sorted, num_tokens_per_expert def init_weights(self, init_std: float): nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) @@ -232,12 +262,11 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): top_k = model_args.n_activated_experts route_scaling_factor = model_args.route_scale - self.use_grouped_mm = model_args.use_grouped_mm self.experts = GroupedExperts( dim=dim, hidden_dim=hidden_dim, num_experts=num_experts, - use_grouped_mm=self.use_grouped_mm, + use_grouped_mm=model_args.use_grouped_mm, ) self.router = TokenChoiceTopKRouter( dim=dim, @@ -252,7 +281,7 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): dim=dim, hidden_dim=hidden_dim * model_args.n_shared_experts, num_experts=1, # Here needs to be 1 to make it equivalent to the MLP - use_grouped_mm=self.use_grouped_mm, + use_grouped_mm=model_args.use_grouped_mm, ) if model_args.n_shared_experts > 0 else None @@ -260,32 +289,20 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): # auxiliary-loss-free load balancing self.load_balance_coeff = model_args.load_balance_coeff - # the fields below are defined even when load_balance_coeff is None - # to make initialization and checkpointing code simpler - self.register_buffer( - "expert_bias", - torch.zeros(num_experts, dtype=torch.float32), - persistent=True, - ) - self.register_buffer( - "tokens_per_expert", - torch.zeros(num_experts, dtype=torch.float32), - persistent=True, - ) - - # NOTE: forward hook, forward pre hook, or backward pre hook - # would conflict with activation checkpointing - if self.load_balance_coeff is not None and self.load_balance_coeff > 0: - self.register_full_backward_hook(self._update_expert_bias) - - def _update_expert_bias(self, *_): - expert_bias_delta = self.load_balance_coeff * torch.sign( - self.tokens_per_expert.mean() - self.tokens_per_expert - ) - expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() - self.expert_bias.add_(expert_bias_delta) - - 
self.tokens_per_expert.zero_() + if self.load_balance_coeff is not None: + assert self.load_balance_coeff > 0.0 + self.register_buffer( + "expert_bias", + torch.zeros(num_experts, dtype=torch.float32), + persistent=True, + ) + self.register_buffer( + "tokens_per_expert", + torch.zeros(num_experts, dtype=torch.float32), + persistent=True, + ) + else: + self.expert_bias = None def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -298,16 +315,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: bs, slen, dim = x.shape # top_scores and selected_indices shape (bs*slen*top_k,) - # num_local_tokens_per_expert shape (num_experts,) + # num_tokens_per_expert shape (num_experts,) ( top_scores, token_indices, - num_local_tokens_per_expert, + num_tokens_per_expert, ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) - # will be used to update the expert bias for load balancing - self.tokens_per_expert += num_local_tokens_per_expert - + # tokens_per_expert will be used to update the expert bias for load balancing. + # Prevent extra local tokens accumulation on evaluation or activation recomputation. + if self.load_balance_coeff is not None and torch.is_grad_enabled(): + with torch.no_grad(): + self.tokens_per_expert.add_(num_tokens_per_expert) # shape (bs*slen*top_k, dim) token_indices = token_indices.reshape(-1, 1).expand(-1, dim) @@ -318,46 +337,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: index=token_indices, ) - if self.use_grouped_mm: - # NOTE: In order to use torch._grouped_mm, we need to make sure - # the number of tokens each expert gets is a multiple of 16. - # The following kernel helps achieve this via padding, without - # incurring synchronization between device and host. - from torchtitan.experiments.kernels.moe.indices import ( - generate_permute_indices, - ) - - ALIGN_SIZE_M = 16 - - with torch.no_grad(): - ( - permuted_indices, - num_local_tokens_per_expert, - _, - ) = generate_permute_indices( - num_local_tokens_per_expert, - self.experts.num_experts, - 1, - token_indices.shape[0] + self.experts.num_experts * ALIGN_SIZE_M, - ALIGN_SIZE_M, - ) - - routed_input = torch.vstack((routed_input, routed_input.new_zeros((dim)))) - input_shape = routed_input.shape - routed_input = routed_input[permuted_indices, :] - else: - # NOTE: this would incur a synchronization between device and host - num_local_tokens_per_expert = num_local_tokens_per_expert.tolist() - permuted_indices, input_shape = None, None - # shape (bs*slen*top_k, dim) - routed_output = self.experts(routed_input, num_local_tokens_per_expert) - - if self.use_grouped_mm: - # NOTE: Reverese the permutation to get the original order as inputs - routed_output_unpermuted = routed_output.new_empty(input_shape) - routed_output_unpermuted[permuted_indices, :] = routed_output - routed_output = routed_output_unpermuted[:-1] # remove padding + routed_output = self.experts(routed_input, num_tokens_per_expert) routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to( x.dtype @@ -386,10 +367,11 @@ def init_weights( if self.shared_expert is not None: self.shared_expert.init_weights(init_std) - with torch.device(buffer_device): - self.expert_bias = torch.zeros( - self.experts.num_experts, dtype=torch.float32 - ) - self.tokens_per_expert = torch.zeros( - self.experts.num_experts, dtype=torch.float32 - ) + if self.load_balance_coeff is not None: + with torch.device(buffer_device): + self.expert_bias = torch.zeros( + self.experts.num_experts, dtype=torch.float32 + ) + self.tokens_per_expert = 
torch.zeros( + self.experts.num_experts, dtype=torch.float32 + ) diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index 80566b04d..8a003a417 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -42,16 +42,15 @@ lr_min = 0.0 local_batch_size = 8 seq_len = 2048 max_norm = 1.0 # grad norm clipping -steps = 10 +steps = 1 compile = false dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) -seed = 0 [parallelism] data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 2 +tensor_parallel_degree = 1 enable_async_tensor_parallel = false [checkpoint] diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 8e841628e..19267f036 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -15,7 +15,7 @@ save_memory_snapshot_folder = "memory_snapshot" [metrics] log_freq = 1 disable_color_printing = false -enable_tensorboard = true +enable_tensorboard = false save_tb_folder = "tb" enable_wandb = false @@ -44,7 +44,6 @@ max_norm = 1.0 # grad norm clipping steps = 10 compile = false dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) -seed = 0 [parallelism] data_parallel_replicate_degree = 1 diff --git a/torchtitan/train.py b/torchtitan/train.py index 08f55e8dd..e6a1ffa7d 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -11,10 +11,10 @@ from typing import Any, Generator, Iterable, Optional import torch +from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module -from torch.distributed.elastic.multiprocessing.errors import record from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderStopIteration from torchtitan.components.loss import rescale_accumulated_loss From 8396a61fb666d15aa6cd87102e3fc8bcd9a93e10 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 1 Jul 2025 17:42:32 -0700 Subject: [PATCH 10/12] clean --- torchtitan/models/deepseek_v3/model/model.py | 5 ----- torchtitan/models/deepseek_v3/train_configs/debug_model.toml | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 684ce3c5f..3c2bb9e84 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -368,15 +368,10 @@ def forward(self, tokens: torch.Tensor): Returns: torch.Tensor: Logits tensor of shape (batch_size, vocab_size). 
""" - print("Input tokens:", tokens) h = self.tok_embeddings(tokens) - print("After token embedding:", h) for layer in self.layers.values(): h = layer(h, self.freqs_cis) - print(f"After layer {layer}: ", h) h = self.norm(h) - print("After normalization:", h) output = self.output(h) - print("Output logits:", output) return output diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index 8a003a417..66db16f9b 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -42,7 +42,7 @@ lr_min = 0.0 local_batch_size = 8 seq_len = 2048 max_norm = 1.0 # grad norm clipping -steps = 1 +steps = 10 compile = false dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) From 90cf81efb3c31db5432d606a06a1a4dbc0aeba30 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Wed, 2 Jul 2025 13:32:13 -0700 Subject: [PATCH 11/12] clean up --- torchtitan/components/tokenizer.py | 11 ++++++----- torchtitan/models/deepseek_v3/README.md | 6 ++++++ torchtitan/models/deepseek_v3/__init__.py | 4 ++-- torchtitan/models/deepseek_v3/model/args.py | 2 +- .../models/deepseek_v3/train_configs/debug_model.toml | 2 +- 5 files changed, 16 insertions(+), 9 deletions(-) create mode 100644 torchtitan/models/deepseek_v3/README.md diff --git a/torchtitan/components/tokenizer.py b/torchtitan/components/tokenizer.py index def7594ae..dbdc8d5df 100644 --- a/torchtitan/components/tokenizer.py +++ b/torchtitan/components/tokenizer.py @@ -11,6 +11,8 @@ from typing import Any, Optional from tokenizers import AddedToken, Tokenizer as HfTokenizer + +from torchtitan.config_manager import JobConfig from typing_extensions import override @@ -21,12 +23,10 @@ def __init__(self): self.eos_id = 0 @abstractmethod - def encode(self, *args, **kwargs) -> list[int]: - ... + def encode(self, *args, **kwargs) -> list[int]: ... @abstractmethod - def decode(self, *args, **kwargs) -> str: - ... + def decode(self, *args, **kwargs) -> str: ... @property def n_words(self) -> int: @@ -406,7 +406,7 @@ def id_to_token(self, token_id: int) -> Optional[str]: return self.tokenizer.id_to_token(token_id) -def build_hf_tokenizer(tokenizer_path: str) -> HuggingFaceTokenizer: +def build_hf_tokenizer(job_config: JobConfig) -> HuggingFaceTokenizer: """ Builds a HuggingFaceTokenizer from the specified path. 
@@ -421,5 +421,6 @@ def build_hf_tokenizer(tokenizer_path: str) -> HuggingFaceTokenizer: Returns: tokenizer (HuggingFaceTokenizer): Loaded tokenizer instance with intelligent BOS/EOS handling """ + tokenizer_path = job_config.model.tokenizer_path tokenizer = HuggingFaceTokenizer(tokenizer_path) return tokenizer diff --git a/torchtitan/models/deepseek_v3/README.md b/torchtitan/models/deepseek_v3/README.md new file mode 100644 index 000000000..e0271bf82 --- /dev/null +++ b/torchtitan/models/deepseek_v3/README.md @@ -0,0 +1,6 @@ +Download tokenizer: + +``` +# DeepSeek tokenizer (automatically downloads tokenizer.json and tokenizer_config.json) +python scripts/download_tokenizer.py --repo_id deepseek-ai/DeepSeek-V3 +``` diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 918215b73..3ab9fc3c2 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -8,8 +8,8 @@ from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.datasets.hf_datasets import build_hf_dataloader -from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers from torchtitan.protocols.train_spec import register_train_spec, TrainSpec @@ -120,7 +120,7 @@ build_optimizers_fn=build_llama4_optimizers, # use optimizer hooks to update expert weights build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_hf_dataloader, - build_tokenizer_fn=build_tiktoken_tokenizer, + build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, ) ) diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 769268999..ea469c672 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -97,7 +97,7 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> Non """ Update the model_config config from the given job config. 
""" - self.vocab_size = tokenizer.n_words + self.vocab_size = tokenizer.vocab_size self.max_seq_len = job_config.training.seq_len def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index 66db16f9b..d160d02db 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -24,7 +24,7 @@ enable_wandb = false name = "deepseek_v3" flavor = "debugmodel" # test tokenizer.model, for debug purpose only -tokenizer_path = "./tests/assets/test_tiktoken.model" +tokenizer_path = "./assets/tokenizer/DeepSeek-V3" # converters = ["float8"] [optimizer] From 8b9552511cf03735564c5dcd295b6e03ac011de9 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Wed, 2 Jul 2025 14:11:36 -0700 Subject: [PATCH 12/12] lint --- torchtitan/components/tokenizer.py | 10 +++++----- .../deepseek_v3/train_configs/deepseek_v3_16b.toml | 8 ++++---- torchtitan/train.py | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/torchtitan/components/tokenizer.py b/torchtitan/components/tokenizer.py index dbdc8d5df..abeb29219 100644 --- a/torchtitan/components/tokenizer.py +++ b/torchtitan/components/tokenizer.py @@ -12,7 +12,6 @@ from tokenizers import AddedToken, Tokenizer as HfTokenizer -from torchtitan.config_manager import JobConfig from typing_extensions import override @@ -23,10 +22,12 @@ def __init__(self): self.eos_id = 0 @abstractmethod - def encode(self, *args, **kwargs) -> list[int]: ... + def encode(self, *args, **kwargs) -> list[int]: + ... @abstractmethod - def decode(self, *args, **kwargs) -> str: ... + def decode(self, *args, **kwargs) -> str: + ... @property def n_words(self) -> int: @@ -406,7 +407,7 @@ def id_to_token(self, token_id: int) -> Optional[str]: return self.tokenizer.id_to_token(token_id) -def build_hf_tokenizer(job_config: JobConfig) -> HuggingFaceTokenizer: +def build_hf_tokenizer(tokenizer_path: str) -> HuggingFaceTokenizer: """ Builds a HuggingFaceTokenizer from the specified path. 
@@ -421,6 +422,5 @@ def build_hf_tokenizer(job_config: JobConfig) -> HuggingFaceTokenizer: Returns: tokenizer (HuggingFaceTokenizer): Loaded tokenizer instance with intelligent BOS/EOS handling """ - tokenizer_path = job_config.model.tokenizer_path tokenizer = HuggingFaceTokenizer(tokenizer_path) return tokenizer diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 19267f036..ad043827b 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -13,7 +13,7 @@ enable_memory_snapshot = false save_memory_snapshot_folder = "memory_snapshot" [metrics] -log_freq = 1 +log_freq = 10 disable_color_printing = false enable_tensorboard = false save_tb_folder = "tb" @@ -23,7 +23,7 @@ enable_wandb = false name = "deepseek_v3" flavor = "16B" # test tokenizer.model, for debug purpose only -tokenizer_path = "./tests/assets/test_tiktoken.model" +tokenizer_path = "./assets/tokenizer/DeepSeek-V3" # converters = ["float8"] [optimizer] @@ -41,7 +41,7 @@ lr_min = 0.0 local_batch_size = 16 seq_len = 2048 max_norm = 1.0 # grad norm clipping -steps = 10 +steps = 100 compile = false dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) @@ -49,7 +49,7 @@ dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 1 +tensor_parallel_degree = 2 enable_async_tensor_parallel = false [checkpoint] diff --git a/torchtitan/train.py b/torchtitan/train.py index e6a1ffa7d..f4b4062d8 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -126,7 +126,7 @@ def __init__(self, job_config: JobConfig): # build dataloader tokenizer = ( - self.train_spec.build_tokenizer_fn(job_config) + self.train_spec.build_tokenizer_fn(job_config.model.tokenizer_path) if self.train_spec.build_tokenizer_fn is not None else None )
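
Note for readers following the MoE changes in this series: below is a minimal, self-contained sketch of the auxiliary-loss-free load balancing that the DeepSeek-V3 router in the patch relies on. The `expert_bias` buffer only steers which experts are selected, while the gating weights still come from the unbiased scores, and after each step the bias is nudged toward under-loaded experts. The tensor sizes, the random `scores`, and the `coeff` value here are invented for illustration; this is a standalone sketch of the idea, not the torchtitan module or its API.

```python
import torch

# Illustrative hyperparameters (made up for this sketch).
num_experts, top_k, num_tokens, coeff = 8, 2, 32, 1e-3

expert_bias = torch.zeros(num_experts)        # persistent buffer in the real module
tokens_per_expert = torch.zeros(num_experts)  # per-step routing statistics

# Stand-in for the gate's output, shape (num_tokens, num_experts).
scores = torch.rand(num_tokens, num_experts).softmax(dim=-1)

# The bias only affects which experts get picked...
_, selected = torch.topk(scores + expert_bias, k=top_k, dim=1)
# ...while the gating values are gathered from the unbiased scores.
top_scores = scores.gather(dim=1, index=selected)

# Count how many tokens were routed to each expert this step.
tokens_per_expert += torch.histc(
    selected.float().view(-1), bins=num_experts, min=0, max=num_experts
)

# After the backward pass: raise the bias of under-loaded experts,
# lower it for over-loaded ones, then reset the counters.
delta = coeff * torch.sign(tokens_per_expert.mean() - tokens_per_expert)
expert_bias += delta - delta.mean()
tokens_per_expert.zero_()
```

In the patch itself this update is no longer an inline backward hook; per the train spec comment ("use optimizer hooks to update expert weights"), it runs through the optimizer built by `build_llama4_optimizers`, which keeps the routing statistics and bias update out of the activation-checkpointed forward/backward path.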