[DSv3] Compile support for single GPU #1364

Closed · wants to merge 6 commits
10 changes: 0 additions & 10 deletions .github/CODEOWNERS

This file was deleted.

1 change: 1 addition & 0 deletions torchtitan/models/__init__.py
@@ -7,4 +7,5 @@

# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models.deepseek_v3 # noqa: F401
import torchtitan.models.llama3 # noqa: F401
126 changes: 126 additions & 0 deletions torchtitan/models/deepseek_v3/__init__.py
@@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer

from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .infra.parallelize import parallelize_deepseekv3
from .model.args import DeepSeekV3ModelArgs
from .model.model import DeepSeekV3Model

__all__ = [
"parallelize_deepseekv3",
"DeepseekV3ModelArgs",
"DeepseekV3Model",
"deepseekv3_configs",
]


deepseekv3_configs = {
"debugmodel": DeepSeekV3ModelArgs(
vocab_size=102400,
dim=256,
inter_dim=10944,
moe_inter_dim=1408,
n_layers=3,
n_dense_layers=1,
n_heads=16,
n_routed_experts=8,
n_shared_experts=2,
n_activated_experts=3,
route_scale=1.0,
q_lora_rank=0,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
),
"16B": DeepSeekV3ModelArgs(
vocab_size=102400,
dim=2048,
inter_dim=10944,
moe_inter_dim=1408,
n_layers=27,
n_dense_layers=1,
n_heads=16,
n_routed_experts=64,
n_shared_experts=2,
n_activated_experts=6,
route_scale=1.0,
q_lora_rank=0,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
),
"236B": DeepSeekV3ModelArgs(
vocab_size=102400,
dim=5120,
inter_dim=12288,
moe_inter_dim=1536,
n_layers=60,
n_dense_layers=1,
n_heads=128,
n_routed_experts=160,
n_shared_experts=2,
n_activated_experts=6,
n_expert_groups=8,
n_limited_groups=3,
route_scale=16.0,
q_lora_rank=1536,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
),
"671B": DeepSeekV3ModelArgs(
vocab_size=129280,
dim=7168,
inter_dim=18432,
moe_inter_dim=2048,
n_layers=61,
n_dense_layers=3,
n_heads=128,
n_routed_experts=256,
n_shared_experts=1,
n_activated_experts=8,
n_expert_groups=8,
n_limited_groups=4,
route_scale=2.5,
score_func="sigmoid",
q_lora_rank=1536,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
dtype="fp8",
),
}


register_train_spec(
TrainSpec(
name="deepseek_v3",
cls=DeepSeekV3Model,
config=deepseekv3_configs,
parallelize_fn=parallelize_deepseekv3,
pipelining_fn=None,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
build_tokenizer_fn=build_tiktoken_tokenizer,
build_loss_fn=build_cross_entropy_loss,
)
)
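
For reference, a minimal usage sketch of the configs registered above (assumed usage, not part of this diff; only names defined in this file are relied on):

from torchtitan.models.deepseek_v3 import deepseekv3_configs

# Pick the small debug flavor and inspect a few derived MLA dimensions.
args = deepseekv3_configs["debugmodel"]
per_head_qk_dim = args.qk_nope_head_dim + args.qk_rope_head_dim  # 128 + 64 = 192
print(args.n_layers, args.n_routed_experts, per_head_qk_dim)  # 3 8 192
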
75 changes: 75 additions & 0 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.models.deepseek_v3.model.moe import MoE
from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_fsdp
from torchtitan.tools.logging import logger


def apply_compile(model: nn.Module):
"""
Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
the repeated structure. Alternatively, one can compile the whole model (after applying DP).
"""
# Fail loudly if we exceed the recompile limit (default 8)
torch._dynamo.config.fail_on_recompile_limit_hit = True

for layer_id, transformer_block in model.layers.named_children():
fullgraph = True
if isinstance(transformer_block.moe, MoE):
# Allow graph break for MoE layers
fullgraph = False
transformer_block = torch.compile(transformer_block, fullgraph=fullgraph)
model.layers.register_module(layer_id, transformer_block)

logger.info("Compiling each TransformerBlock with torch.compile")


def parallelize_deepseekv3(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: ParallelDims,
job_config: JobConfig,
):
if job_config.activation_checkpoint.mode != "none":
apply_ac(model, job_config.activation_checkpoint)

# turn on per-TransformerBlock compile after AC wrapping and before FSDP
if job_config.training.compile:
apply_compile(model)

dp_mesh: DeviceMesh | None = None
if (
parallel_dims.dp_shard_enabled
): # apply FSDP or HSDP, potentially with Context Parallel
if parallel_dims.dp_replicate_enabled:
dp_mesh_dim_names = ("dp_replicate", "dp_shard")
else:
dp_mesh_dim_names = ("dp_shard",)
dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]

apply_fsdp(
model,
dp_mesh,
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
pp_enabled=parallel_dims.pp_enabled,
cpu_offload=job_config.training.enable_cpu_offload,
reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
)

if parallel_dims.dp_replicate_enabled:
logger.info("Applied HSDP to the model")
else:
logger.info("Applied FSDP to the model")

return model
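
As a standalone illustration of the per-block compile pattern in apply_compile above, here is a sketch on a toy model (ToyBlock and ToyModel are hypothetical stand-ins; torch.compile and nn.Module.register_module are the only real APIs relied on):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, dim: int, is_moe: bool):
        super().__init__()
        self.is_moe = is_moe  # stands in for isinstance(block.moe, MoE)
        self.ffn = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.ffn(x)

class ToyModel(nn.Module):
    def __init__(self, dim: int = 64, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleDict(
            {str(i): ToyBlock(dim, is_moe=(i % 2 == 1)) for i in range(n_layers)}
        )

    def forward(self, x):
        for block in self.layers.values():
            x = block(x)
        return x

model = ToyModel()
for layer_id, block in model.layers.named_children():
    # MoE-style blocks get fullgraph=False so data-dependent routing may graph-break;
    # dense blocks are required to compile as a single graph.
    compiled = torch.compile(block, fullgraph=not block.is_moe)
    model.layers.register_module(layer_id, compiled)

out = model(torch.randn(2, 64))
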
156 changes: 156 additions & 0 deletions torchtitan/models/deepseek_v3/model/args.py
@@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass
from typing import Literal

from torch import nn

from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.protocols.train_spec import BaseModelArgs
from torchtitan.tools.logging import logger


# Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
@dataclass
class DeepSeekV3ModelArgs(BaseModelArgs):
"""
Data class for defining model arguments and hyperparameters.

Attributes:
max_batch_size (int): Maximum batch size.
max_seq_len (int): Maximum sequence length.
dtype (Literal["bf16", "fp8"]): Data type for computations.
vocab_size (int): Vocabulary size.
dim (int): Model dimension.
inter_dim (int): Intermediate dimension for MLP layers.
moe_inter_dim (int): Intermediate dimension for MoE layers.
n_layers (int): Number of transformer layers.
n_dense_layers (int): Number of dense layers in the model.
n_heads (int): Number of attention heads.
n_routed_experts (int): Number of routed experts for MoE layers.
n_shared_experts (int): Number of shared experts for MoE layers.
n_activated_experts (int): Number of activated experts in MoE layers.
n_expert_groups (int): Number of expert groups.
n_limited_groups (int): Number of limited groups for MoE routing.
score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
route_scale (float): Scaling factor for routing scores.
use_grouped_mm (bool): Whether to use grouped matrix multiplication for MoE layers.
load_balance_coeff (float | None): Auxiliary-loss-free load-balancing coefficient for MoE layers.
q_lora_rank (int): LoRA rank for query projections.
kv_lora_rank (int): LoRA rank for key-value projections.
qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
v_head_dim (int): Dimension for value projections.
original_seq_len (int): Original sequence length.
rope_theta (float): Base for rotary positional encoding.
rope_factor (float): Scaling factor for extended sequence lengths.
beta_fast (int): Fast beta correction factor.
beta_slow (int): Slow beta correction factor.
"""

max_batch_size: int = 8
max_seq_len: int = 4096 * 4
dtype: Literal["bf16", "fp8"] = "bf16"
vocab_size: int = 102400
dim: int = 2048
inter_dim: int = 10944
moe_inter_dim: int = 1408
n_layers: int = 27
n_dense_layers: int = 1
n_heads: int = 16
norm_eps: float = 1e-5 # eps used for RMSNorm
# MoE
n_routed_experts: int = 64
n_shared_experts: int = 2
n_activated_experts: int = 6
n_expert_groups: int = 1
n_limited_groups: int = 1
score_func: Literal["softmax", "sigmoid"] = "softmax"
route_scale: float = 1.0
use_grouped_mm: bool = False
load_balance_coeff: float | None = 1e-3
# Multi-Head Latent Attention (MLA)
q_lora_rank: int = 0
kv_lora_rank: int = 512
qk_nope_head_dim: int = 128
qk_rope_head_dim: int = 64
v_head_dim: int = 128
use_flex_attn: bool = False
attn_mask_type: str = "causal"
# yarn
original_seq_len: int = 4096
rope_theta: float = 10000.0
rope_factor: float = 40
beta_fast: int = 32
beta_slow: int = 1
mscale: float = 1.0

def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
"""
Update the model args from the given job config and tokenizer.
"""
self.vocab_size = tokenizer.n_words
self.max_seq_len = job_config.training.seq_len

def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
"""
Adapted from the Llama 4 implementation.
"""
nparams_embedding = 0
nparams_moe_router = 0
nparams_shared_expert = 0
nparams_experts = 0
nparams_dense = 0

for name, p in model.named_parameters():
if "embedding" in name:
nparams_embedding += p.numel()
nparams_dense += p.numel()
elif "moe.shared_expert" in name:
nparams_shared_expert += p.numel()
elif "moe.router" in name:
nparams_moe_router += p.numel()
elif "moe.experts" in name:
nparams_experts += p.numel()
else:
nparams_dense += p.numel()

nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
nparams = nparams_dense + nparams_sparse
nparams_sparse_active = (
nparams_moe_router
+ nparams_shared_expert
+ nparams_experts * self.n_activated_experts // self.n_routed_experts
)

logger.info(
f"Total parameter count: dense {nparams_dense:,}, "
f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
)

l, h, q, t = (
self.n_layers,
self.n_heads,
self.dim // self.n_heads,
seq_len,
)
# Reasoning behind the factor of 12 for the self-attention part of the formula:
# 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
# 2. the flash attention does 1 more matmul recomputation in the backward
# but recomputation should not be counted in calculating MFU (+0)
# 3. each matmul performs 1 multiplication and 1 addition (*2)
# 4. we follow the convention and do not account for sparsity in causal attention
num_flops_per_token = (
6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
+ 12 * l * h * q * t
)

return nparams, num_flops_per_token
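
To make the attention term concrete, a worked example (assuming seq_len = 4096) for the "16B" config, using the same q = dim // n_heads convention as the code above:

l, h, q, t = 27, 16, 2048 // 16, 4096  # layers, heads, per-head dim, tokens
attn_flops_per_token = 12 * l * h * q * t
print(f"{attn_flops_per_token:,}")  # 2,717,908,992, i.e. ~2.7 GFLOPs per token from attention alone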