From b433f34af2b8449b84afe6851763bf1eceb9b06d Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 17 Jun 2025 17:27:44 -0700 Subject: [PATCH 1/6] remove CODEOWNERS --- .github/CODEOWNERS | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 .github/CODEOWNERS diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS deleted file mode 100644 index 000cc1af1..000000000 --- a/.github/CODEOWNERS +++ /dev/null @@ -1,10 +0,0 @@ -# This is a CODEOWNERS file. -# Each line is a file pattern followed by one or more owners. - -# These owners will be the default owners for everything in -# the repo. Unless a later match takes precedence, -# they will be requested for review when someone opens a pull request. -* @tianyu-l @fegin @wwwjn @wconstab - -# Exclude the experiments directory by adding a pattern without owners -/torchtitan/experiments/ From a5d720eecfd03a3232dc8e2e368bad470e54524f Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Wed, 18 Jun 2025 12:21:42 -0700 Subject: [PATCH 2/6] Implement Deepseek-V3 model skeleton (#1315) ## Contents 1. Attention module 2. MoE module (note: I only implemented the naive routing, not the "node limit routing" strategy) 3. Deepseek-V3 model Reference: 1. Deepseek-ai: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py 4. Huggingface: https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py 5. torchtitan/experiment/deepseek-v3 6. torchtitan/experiment/llama4 ## TODO - [ ] Further clean up the DeepseekV3ModelArgs class, remove unused model args - [ ] Test forward pass w/ torchtitan --- torchtitan/models/deepseek-v3/model/args.py | 106 ++++++ torchtitan/models/deepseek-v3/model/model.py | 337 ++++++++++++++++++ torchtitan/models/deepseek-v3/model/moe.py | 339 +++++++++++++++++++ 3 files changed, 782 insertions(+) create mode 100644 torchtitan/models/deepseek-v3/model/args.py create mode 100644 torchtitan/models/deepseek-v3/model/model.py create mode 100644 torchtitan/models/deepseek-v3/model/moe.py diff --git a/torchtitan/models/deepseek-v3/model/args.py b/torchtitan/models/deepseek-v3/model/args.py new file mode 100644 index 000000000..845d6b83e --- /dev/null +++ b/torchtitan/models/deepseek-v3/model/args.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + + +from dataclasses import dataclass +from typing import Literal + +from torch import nn + +from torchtitan.components.tokenizer import Tokenizer +from torchtitan.config_manager import JobConfig +from torchtitan.protocols.train_spec import BaseModelArgs + + +# Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +@dataclass +class DeepseekV3ModelArgs(BaseModelArgs): + """ + Data class for defining model arguments and hyperparameters. + + Attributes: + max_batch_size (int): Maximum batch size. + max_seq_len (int): Maximum sequence length. + dtype (Literal["bf16", "fp8"]): Data type for computations. + vocab_size (int): Vocabulary size. + dim (int): Model dimension. + inter_dim (int): Intermediate dimension for MLP layers. + moe_inter_dim (int): Intermediate dimension for MoE layers. + n_layers (int): Number of transformer layers. + n_dense_layers (int): Number of dense layers in the model. + n_heads (int): Number of attention heads. 
+ n_routed_experts (int): Number of routed experts for MoE layers. + n_shared_experts (int): Number of shared experts for MoE layers. + n_activated_experts (int): Number of activated experts in MoE layers. + n_expert_groups (int): Number of expert groups. + n_limited_groups (int): Number of limited groups for MoE routing. + score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing. + route_scale (float): Scaling factor for routing scores. + use_grouped_mm (bool): Whether to use grouped matrix multiplication for MoE layers. + load_balance_coeff (float | None): Auxiliary-Loss-Free Load balancing coefficient for MoE layers. + q_lora_rank (int): LoRA rank for query projections. + kv_lora_rank (int): LoRA rank for key-value projections. + qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. + qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings. + v_head_dim (int): Dimension for value projections. + original_seq_len (int): Original sequence length. + rope_theta (float): Base for rotary positional encoding. + rope_factor (float): Scaling factor for extended sequence lengths. + beta_fast (int): Fast beta correction factor. + beta_slow (int): Slow beta correction factor. + mscale (float): Scaling factor for extended attention. + """ + + max_batch_size: int = 8 + max_seq_len: int = 4096 * 4 + dtype: Literal["bf16", "fp8"] = "bf16" + vocab_size: int = 102400 + dim: int = 2048 + inter_dim: int = 10944 + moe_inter_dim: int = 1408 + n_layers: int = 27 + n_dense_layers: int = 1 + n_heads: int = 16 + norm_eps: float = 1e-5 # eps used for RMSNorm + # MoE + n_routed_experts: int = 64 + n_shared_experts: int = 2 + n_activated_experts: int = 6 + n_expert_groups: int = 1 + n_limited_groups: int = 1 + score_func: Literal["softmax", "sigmoid"] = "softmax" + route_scale: float = 1.0 + use_grouped_mm: bool = False + load_balance_coeff: float | None = 1e-3 + # Multi-Head Latent Attention (MLA) + q_lora_rank: int = 0 + kv_lora_rank: int = 512 + qk_nope_head_dim: int = 128 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + use_flex_attn: bool = False + attn_mask_type: str = "causal" + # yarn + original_seq_len: int = 4096 + rope_theta: float = 10000.0 + rope_factor: float = 40 + beta_fast: int = 32 + beta_slow: int = 1 + mscale: float = 1.0 + + def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None: + """ + TODO: Placeholder for now + """ + pass + + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: + """ + TODO: Placeholder for now + """ + return 0, 0 diff --git a/torchtitan/models/deepseek-v3/model/model.py b/torchtitan/models/deepseek-v3/model/model.py new file mode 100644 index 000000000..dd6c44319 --- /dev/null +++ b/torchtitan/models/deepseek-v3/model/model.py @@ -0,0 +1,337 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
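Before the model definition below, a minimal standalone sketch (not part of this patch) of the per-layer Multi-head Latent Attention (MLA) projection sizes implied by the default `DeepseekV3ModelArgs` above; the variable names simply mirror the dataclass fields and the `Attention` module that follows, and bias parameters are ignored.

```python
# Hedged sketch: MLA projection weight count per layer for the default args above.
dim, n_heads = 2048, 16
q_lora_rank, kv_lora_rank = 0, 512
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 192

# Query path: a full-rank projection when q_lora_rank == 0, otherwise factored through q_lora_rank.
wq = (
    dim * n_heads * qk_head_dim
    if q_lora_rank == 0
    else dim * q_lora_rank + q_lora_rank * n_heads * qk_head_dim
)
# Key/value path is always factored through the kv_lora_rank latent plus a decoupled RoPE slice.
wkv_a = dim * (kv_lora_rank + qk_rope_head_dim)
wkv_b = kv_lora_rank * n_heads * (qk_nope_head_dim + v_head_dim)
wo = n_heads * v_head_dim * dim

print(f"{wq + wkv_a + wkv_b + wo:,}")  # 13,762,560 attention weights per layer
```

For comparison, a vanilla multi-head attention at the same dim with 16 heads of size 128 uses 4 * 2048 * 2048 (about 16.8M) weights per layer.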
+ +import math +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from torchtitan.models.attention import build_attention +from torchtitan.protocols.train_spec import ModelProtocol + +from .args import DeepseekV3ModelArgs +from .moe import MoE + + +# Adopted from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 +def precompute_freqs_cis(args: DeepseekV3ModelArgs) -> torch.Tensor: + """ + Precomputes frequency-based complex exponential values for rotary positional embeddings. + + Args: + args (DeepseekV3ModelArgs): Model arguments containing positional embedding parameters. + + Returns: + torch.Tensor: Precomputed complex exponential values for positional embeddings. + """ + dim = args.qk_rope_head_dim + seqlen = args.max_seq_len + beta_fast = args.beta_fast + beta_slow = args.beta_slow + base = args.rope_theta + factor = args.rope_factor + + def find_correction_dim( + num_rotations: float, dim: int, base: float, max_seq_len: int + ) -> float: + """ + Computes the correction dimension for a given number of rotations in the rotary positional embedding. + + Args: + num_rotations (float): Number of rotations to compute the correction for. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + float: The correction dimension based on the input parameters. + """ + return ( + dim + * math.log(max_seq_len / (num_rotations * 2 * math.pi)) + / (2 * math.log(base)) + ) + + def find_correction_range( + low_rot: float, high_rot: float, dim: int, base: float, max_seq_len: int + ) -> Tuple[int, int]: + """ + Computes the range of correction dimensions for rotary positional embeddings. + + Args: + low_rot (float): Lower bound for the number of rotations. + high_rot (float): Upper bound for the number of rotations. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. + """ + low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor: + """ + Computes a linear ramp function used to smooth values between a minimum and maximum range. + + Args: + min (float): Minimum value for the ramp function. + max (float): Maximum value for the ramp function. + dim (int): Dimensionality of the ramp tensor. + + Returns: + torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1, + clamped to the range [0, 1]. 
+ """ + if min == max: + max += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Basic RoPE frequency calculation + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + # YaRN scaling for extended context + if seqlen > args.original_seq_len: + low, high = find_correction_range( + beta_fast, beta_slow, dim, base, args.original_seq_len + ) + smooth = 1 - linear_ramp_factor(low, high, dim // 2) + freqs = freqs / factor * (1 - smooth) + freqs * smooth + + t = torch.arange(seqlen) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """ + Applies rotary positional embeddings to the input tensor. + + Args: + x (torch.Tensor): Input tensor with positional embeddings to be applied. + freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings. + + Returns: + torch.Tensor: Tensor with rotary embeddings applied. + """ + dtype = x.dtype + x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) + y = torch.view_as_real(x * freqs_cis).flatten(3) + return y.to(dtype) + + +class Attention(nn.Module): + """ + Multi-head attention (MLA) module. + """ + + def __init__(self, model_args: DeepseekV3ModelArgs): + super().__init__() + self.dim = model_args.dim + self.n_heads = model_args.n_heads + self.q_lora_rank = model_args.q_lora_rank + self.kv_lora_rank = model_args.kv_lora_rank + self.qk_nope_head_dim = model_args.qk_nope_head_dim + self.qk_rope_head_dim = model_args.qk_rope_head_dim + self.qk_head_dim = model_args.qk_nope_head_dim + model_args.qk_rope_head_dim + self.v_head_dim = model_args.v_head_dim + + if self.q_lora_rank == 0: + self.wq = nn.Linear(self.dim, self.n_heads * self.qk_head_dim) + else: + self.wq_a = nn.Linear(self.dim, self.q_lora_rank) + self.q_norm = nn.RMSNorm(self.q_lora_rank, eps=model_args.norm_eps) + self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim) + self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim) + self.kv_norm = nn.RMSNorm(self.kv_lora_rank, eps=model_args.norm_eps) + self.wkv_b = nn.Linear( + self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim) + ) + self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim) + self.softmax_scale = self.qk_head_dim**-0.5 + + if model_args.max_seq_len > model_args.original_seq_len: + mscale = 0.1 * model_args.mscale * math.log(model_args.rope_factor) + 1.0 + self.softmax_scale = self.softmax_scale * mscale * mscale + + self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Forward pass for the Multi-Head Latent Attention (MLA) Layer. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. 
+ """ + bsz, seqlen, _ = x.size() + + # Query projection + if self.q_lora_rank == 0: + q = self.wq(x) # (bsz, seqlen, n_heads * qk_head_dim) + else: + q = self.wq_b(self.q_norm(self.wq_a(x))) + + q = q.view(bsz, seqlen, self.n_heads, self.qk_head_dim) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + q_pe = apply_rotary_emb(q_pe, freqs_cis) + q = torch.cat([q_nope, q_pe], dim=-1) # (bsz, seqlen, n_heads, qk_head_dim) + + # Key-value projection + kv = self.wkv_a(x) # (bsz, seqlen, kv_lora_rank + qk_rope_head_dim) + kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + k_pe = apply_rotary_emb( + k_pe.unsqueeze(2), freqs_cis + ) # (bsz, seqlen, 1, qk_rope_head_dim) + + kv = self.wkv_b( + self.kv_norm(kv) + ) # (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim)) + kv = kv.view(bsz, seqlen, self.n_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = torch.cat( + [k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1 + ) # (bsz, seqlen, n_heads, qk_head_dim) + + # TODO: Need to pass softmax_scale to sdpa() interface. + # For mask, DeepseekV3 uses causal mask, so we can use the default mask in sdpa + # https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L17 + output = self.sdpa(q, k, v) + + # Reshape and project output + output = output.transpose(1, 2) # (bsz, seqlen, n_heads, v_head_dim) + output = output.view(bsz, seqlen, -1) # (bsz, seqlen, n_heads * v_head_dim) + return self.wo(output) # (bsz, seqlen, dim) + + +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. + + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class TransformerBlock(nn.Module): + """ + Transformer block with attention and feed-forward layers. + """ + + def __init__(self, layer_id: int, model_args: DeepseekV3ModelArgs): + + super().__init__() + self.attention = Attention(model_args) + self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.ffn = ( + FeedForward(model_args.dim, model_args.inter_dim) + if layer_id < model_args.n_dense_layers + else MoE(model_args) + ) + + def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): + """ + Forward pass for the Transformer block. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. 
+ + Returns: + torch.Tensor: Output tensor with the same shape as the input. + """ + x = x + self.attention(self.attention_norm(x), freqs_cis) + x = x + self.ffn(self.ffn_norm(x)) + return x + + +class Transformer(nn.Module, ModelProtocol): + """ + Deepseek-V3 Transformer model with attention and feed-forward layers. + """ + + def __init__(self, model_args: DeepseekV3ModelArgs): + super().__init__() + self.max_seq_len = model_args.max_seq_len + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + self.register_buffer( + "freqs_cis", precompute_freqs_cis(model_args), persistent=False + ) + + self.layers = torch.nn.ModuleList() + for layer_id in range(model_args.n_layers): + self.layers.append( + TransformerBlock(layer_id=layer_id, model_args=model_args) + ) + self.norm = nn.RMSNorm(model_args.dim) + self.output = nn.Linear( + model_args.dim, model_args.vocab_size, dtype=torch.get_default_dtype() + ) + self.init_weights() + + def forward(self, tokens: torch.Tensor): + """ + Forward pass for the Transformer model. + + Args: + tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len). + + Returns: + torch.Tensor: Logits tensor of shape (batch_size, vocab_size). + """ + h = self.tok_embeddings(tokens) + for layer in self.layers: + h = layer(h, self.freqs_cis) + h = self.norm(h)[:, -1] + output = self.output(h) + return output + + def init_weights(self, buffer_device: torch.device | None = None) -> None: + pass diff --git a/torchtitan/models/deepseek-v3/model/moe.py b/torchtitan/models/deepseek-v3/model/moe.py new file mode 100644 index 000000000..b224b1097 --- /dev/null +++ b/torchtitan/models/deepseek-v3/model/moe.py @@ -0,0 +1,339 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
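Before the MoE definition below, a self-contained toy run (not part of this patch) of the complex-multiplication rotary embedding that `apply_rotary_emb` in model.py above performs; the sizes are made up, and only the tensor manipulation mirrors the model code.

```python
import torch

bsz, seqlen, n_heads, rope_dim = 2, 8, 4, 16  # rope_dim plays the role of qk_rope_head_dim
base = 10000.0

# Same recipe as precompute_freqs_cis without the YaRN correction: one complex phase per (position, pair).
inv_freq = 1.0 / (base ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))
freqs = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # (seqlen, rope_dim // 2), complex64

x = torch.randn(bsz, seqlen, n_heads, rope_dim)
xc = torch.view_as_complex(x.float().view(bsz, seqlen, n_heads, -1, 2))
out = torch.view_as_real(xc * freqs_cis.view(1, seqlen, 1, -1)).flatten(3)

# Multiplying by a unit complex number rotates each (even, odd) pair, so norms are preserved.
assert torch.allclose(out.norm(dim=-1), x.norm(dim=-1), atol=1e-4)
print(out.shape)  # torch.Size([2, 8, 4, 16])
```

The YaRN branch in `precompute_freqs_cis` only rescales `freqs` before this step, so the application itself is unchanged for extended contexts.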
+ +import torch +import torch.nn.functional as F +from torch import nn + +from .args import DeepseekV3ModelArgs + + +# Reference: torchtitan/experiments/llama4/model/ +class GroupedExperts(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + num_experts: int, + use_grouped_mm: bool, + ): + super().__init__() + self.num_experts = num_experts + self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) + self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim)) + self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) + self.use_grouped_mm = use_grouped_mm + + def forward( + self, + x: torch.Tensor, + num_local_tokens_per_expert: torch.Tensor | list[int] | None = None, + ) -> torch.Tensor: + # TODO: keeping this for loop implementation for comparison + # and readability, will remove later + if not self.use_grouped_mm: + if num_local_tokens_per_expert is not None: + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x, + split_size_or_sections=num_local_tokens_per_expert, + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + w1, w2, w3 = ( + self.w1[expert_idx], + self.w2[expert_idx], + self.w3[expert_idx], + ) + h = F.silu(torch.matmul(x_expert, w1)) + h = h * torch.matmul(x_expert, w3) + h = torch.matmul(h, w2) + # h shape (tokens_per_expert(varying), dim) + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + else: + # x shape (num_experts, tokens_per_expert, dim) + h = F.silu(torch.bmm(x, self.w1)) + h = h * torch.bmm(x, self.w3) + # out shape (num_experts, tokens_per_expert, dim) + out = torch.bmm(h, self.w2) + + return out + + # grouped mm implementation + if num_local_tokens_per_expert is not None: + # https://github.com/pytorch/pytorch/pull/150374 + # NOTE: torch._gouped_mm requires bf16 dtypes + # and shapes to be multiple of 8 + offsets = torch.cumsum( + num_local_tokens_per_expert, dim=0, dtype=torch.int32 + ) + # grouped mm between a 2D tensor and a 3D tensor + assert x.dim() == 2 + else: + offsets = None + # fall back to regular bmm between 3D tensors + assert x.dim() == 3 + + assert ( + x.dtype == self.w1.dtype == self.w2.dtype == self.w3.dtype == torch.bfloat16 + ), "torch._grouped_mm only supports bf16 dtypes" + h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets)) + h = h * torch._grouped_mm(x, self.w3, offs=offsets) + out = torch._grouped_mm(h, self.w2, offs=offsets) + + return out + + +class TokenChoiceTopKRouter(nn.Module): + """This class implements token-choice routing. In token-choice top-K routing, each token is + routed to top K experts based on the router scores. + + Args: + gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts). + num_experts (int): Number of experts in each moe layer. + top_k (int): Number of experts each token will be routed to in token-choice routing. + use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False. 
+ """ + + def __init__( + self, + num_experts: int, + top_k: int, + use_sigmoid: bool = False, + route_sclaing_factor: float = 1.0, + ): + super().__init__() + + self.num_experts = num_experts + self.top_k = top_k + self.use_sigmoid = use_sigmoid + self.route_sclaing_factor = route_sclaing_factor + + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, self.gating_dim)) + ) + + def forward( + self, x: torch.Tensor, expert_bias: torch.Tensor = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + TODO: We haven't implement the group-based routing (node limit routing), + and currently EP is not supporting node limit routing yet. + + Args: + x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``. + + Returns: + routed_input (torch.Tensor): + Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``. + token_indices (torch.Tensor): + Token indices for routed_input with shape ``(bs*slen*top_k,)``. + num_local_tokens_per_expert (torch.Tensor): + Number of tokens assigned to each expert with shape ``(num_experts,)``. + """ + # scores shape (bs*slen, num_experts) + scores = F.linear(x.type, self.weight, None) + + # By default, sigmoid or softmax is performed in float32 to avoid loss explosion + if self.use_sigmoid: + scores = torch.sigmoid(scores.to(torch.float32)) + else: + scores = F.softmax(scores.to(torch.float32), dim=1) + + # top scores shape (bs*slen, top_k) + # NOTE: The expert_bias is only used for routing. The gating value + # top_scores is still derived from the original scores. + _, selected_experts_indices = torch.topk( + scores + expert_bias, k=self.top_k, dim=1 + ) + top_scores = scores.gather(dim=1, index=selected_experts_indices) + + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward + num_local_tokens_per_expert = torch.histc( + selected_experts_indices.view(-1), + bins=self.num_experts, + min=0, + max=self.num_experts, + ) + # token_indices_experts_sorted shape (bs*slen*top_k,) + token_indices_experts_sorted = torch.argsort( + selected_experts_indices.view(-1), stable=True + ) + top_scores = top_scores.view(-1)[token_indices_experts_sorted] + token_indices_experts_sorted = token_indices_experts_sorted // self.top_k + + top_scores = ( + top_scores * self.route_sclaing_factor + ) # must multiply the scaling factor + + return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert + + +class MoE(nn.Module): + def __init__(self, model_args: DeepseekV3ModelArgs): + + super().__init__() + dim = model_args.dim + + num_experts = model_args.n_routed_experts + hidden_dim = model_args.moe_inter_dim + top_k = model_args.n_activated_experts + route_scaling_factor = model_args.route_scale + + self.use_grouped_mm = model_args.use_grouped_mm + self.experts = GroupedExperts( + dim=dim, + hidden_dim=hidden_dim, + num_experts=num_experts, + use_grouped_mm=self.use_grouped_mm, + ) + self.router = TokenChoiceTopKRouter( + num_experts=num_experts, + top_k=top_k, + use_sigmoid=model_args.score_func == "sigmoid", + route_sclaing_factor=route_scaling_factor, + ) + self.shared_expert = ( + # Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py#L517 + GroupedExperts( + dim=dim, + hidden_dim=hidden_dim * model_args.n_shared_experts, + num_experts=1, + use_grouped_mm=self.use_grouped_mm, + ) + if model_args.n_shared_experts > 0 + else None + ) + + # auxiliary-loss-free load balancing + self.load_balance_coeff = model_args.load_balance_coeff + # the fields 
below are defined even when load_balance_coeff is None + # to make initialization and checkpointing code simpler + self.register_buffer( + "expert_bias", + torch.zeros(num_experts, dtype=torch.float32), + persistent=True, + ) + self.register_buffer( + "tokens_per_expert", + torch.zeros(num_experts, dtype=torch.float32), + persistent=True, + ) + + # NOTE: forward hook, forward pre hook, or backward pre hook + # would conflict with activation checkpointing + if self.load_balance_coeff is not None and self.load_balance_coeff > 0: + self.register_full_backward_hook(self._update_expert_bias) + + # TODO: double check the bias update logic. It aligns with the paper. + def _update_expert_bias(self, *_): + expert_bias_delta = self.load_balance_coeff * torch.sign( + self.tokens_per_expert.mean() - self.tokens_per_expert + ) + expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() + self.expert_bias.add_(expert_bias_delta) + + self.tokens_per_expert.zero_() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. + + Returns: + out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. + """ + bs, slen, dim = x.shape + + # top_scores and selected_indices shape (bs*slen*top_k,) + # num_local_tokens_per_expert shape (num_experts,) + ( + top_scores, + token_indices, + num_local_tokens_per_expert, + ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) + + # will be used to update the expert bias for load balancing + self.tokens_per_expert += num_local_tokens_per_expert + + # shape (bs*slen*top_k, dim) + token_indices = token_indices.reshape(-1, 1).expand(-1, dim) + + # shape (bs*slen*top_k, dim) + routed_input = torch.gather( + x.view(-1, dim), + dim=0, + index=token_indices, + ) + + if self.use_grouped_mm: + # NOTE: In order to use torch._grouped_mm, we need to make sure + # the number of tokens each expert gets is a multiple of 16. + # The following kernel helps achieve this via padding, without + # incurring synchronization between device and host. 
+ from torchtitan.experiments.kernels.moe.indices import ( + generate_permute_indices, + ) + + ALIGN_SIZE_M = 16 + + with torch.no_grad(): + ( + permuted_indices, + num_local_tokens_per_expert, + _, + ) = generate_permute_indices( + num_local_tokens_per_expert, + self.experts.num_experts, + 1, + ALIGN_SIZE_M, + ) + token_indices = torch.vstack( + (token_indices, token_indices.new_zeros((dim))) + ) + token_indices = token_indices[permuted_indices, :] + routed_input = torch.vstack((routed_input, routed_input.new_zeros((dim)))) + routed_input = routed_input[permuted_indices, :] + else: + # NOTE: this would incur a synchronization between device and host + num_local_tokens_per_expert = num_local_tokens_per_expert.tolist() + + # shape (bs*slen*top_k, dim) + routed_output = self.experts(routed_input, num_local_tokens_per_expert) + routed_output = routed_output * top_scores.unsqueeze(-1) + + # shared expert + if self.shared_expert is not None: + out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape( + bs * slen, dim + ) + else: + out = torch.zeros_like(x.reshape(bs * slen, dim)) + + out = out.scatter_add(dim=0, index=token_indices, src=routed_output) + out = out.reshape(bs, slen, dim) + return out + + def init_weights( + self, + init_std: float, + buffer_device: torch.device, + ): + self.experts.init_weights(init_std) + self.router.init_weights(init_std) + if self.shared_expert is not None: + self.shared_expert.init_weights(init_std) + + with torch.device(buffer_device): + self.expert_bias = torch.zeros( + self.experts.num_experts, dtype=torch.float32 + ) + self.tokens_per_expert = torch.zeros( + self.experts.num_experts, dtype=torch.float32 + ) From c76a079ba5e1ea0f847fc5c66b7fc408046ba615 Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Mon, 23 Jun 2025 10:55:32 -0700 Subject: [PATCH 3/6] [DSV3] Forward and backward pass for single GPU (#1320) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Command to run: `NGPU=1 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh` ## Context 1. Added model args for 4 model settings, and training config for debug model 2. Debugged the forward pass, and the backward pass works out of pocket. 3. 
Reused c4-test dataset, and tiktokenizer from llama3 model for current testing ![Screenshot 2025-06-20 at 11 52 49 AM](https://github.com/user-attachments/assets/81d938a2-9a85-4e8c-b8e1-7f9510d785c2) --- torchtitan/models/__init__.py | 1 + torchtitan/models/deepseek_v3/__init__.py | 125 ++++++++++++++++++ .../models/deepseek_v3/infra/parallelize.py | 23 ++++ .../model/args.py | 63 ++++++++- .../model/model.py | 36 ++--- .../{deepseek-v3 => deepseek_v3}/model/moe.py | 13 +- .../train_configs/debug_model.toml | 69 ++++++++++ 7 files changed, 303 insertions(+), 27 deletions(-) create mode 100644 torchtitan/models/deepseek_v3/__init__.py create mode 100644 torchtitan/models/deepseek_v3/infra/parallelize.py rename torchtitan/models/{deepseek-v3 => deepseek_v3}/model/args.py (62%) rename torchtitan/models/{deepseek-v3 => deepseek_v3}/model/model.py (92%) rename torchtitan/models/{deepseek-v3 => deepseek_v3}/model/moe.py (97%) create mode 100644 torchtitan/models/deepseek_v3/train_configs/debug_model.toml diff --git a/torchtitan/models/__init__.py b/torchtitan/models/__init__.py index fd5aa42c6..378f88665 100644 --- a/torchtitan/models/__init__.py +++ b/torchtitan/models/__init__.py @@ -7,4 +7,5 @@ # Import the built-in models here so that the corresponding register_model_spec() # will be called. +import torchtitan.models.deepseek_v3 # noqa: F401 import torchtitan.models.llama3 # noqa: F401 diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py new file mode 100644 index 000000000..8a21e53dd --- /dev/null +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. 
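Before the model registration below, a toy-sized sketch (not part of this patch) of the token-choice top-K selection and expert-wise regrouping that `TokenChoiceTopKRouter` and `MoE.forward` in moe.py above implement; the expert bias, score scaling, and grouped-mm padding are omitted.

```python
import torch

num_tokens, dim, num_experts, top_k = 6, 4, 4, 2
x = torch.randn(num_tokens, dim)
gate = torch.randn(num_experts, dim)  # stands in for the router's weight parameter

scores = torch.softmax(x @ gate.t(), dim=1)                  # (tokens, experts)
top_scores, expert_idx = torch.topk(scores, k=top_k, dim=1)  # (tokens, top_k) each

# Sort the (token, expert) pairs by expert id so each expert sees a contiguous slice.
order = torch.argsort(expert_idx.view(-1), stable=True)
tokens_per_expert = torch.histc(
    expert_idx.view(-1).float(), bins=num_experts, min=0, max=num_experts
)
top_scores = top_scores.view(-1)[order]  # gating values follow the same reordering
token_order = order // top_k             # map each routed slot back to its source token
routed_input = x[token_order]            # (tokens * top_k, dim), grouped by expert

print(tokens_per_expert, routed_input.shape)
```

`GroupedExperts` then consumes `routed_input` together with `tokens_per_expert`, either as per-expert splits or through `torch._grouped_mm`.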
+ +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec + +from .infra.parallelize import parallelize_deepseekv3 +from .model.args import DeepSeekV3ModelArgs +from .model.model import DeepSeekV3Model + +__all__ = [ + "parallelize_deepseekv3", + "DeepseekV3ModelArgs", + "DeepseekV3Model", + "deepseekv3_configs", +] + + +deepseekv3_configs = { + "debugmodel": DeepSeekV3ModelArgs( + vocab_size=102400, + dim=256, + inter_dim=10944, + moe_inter_dim=1408, + n_layers=3, + n_dense_layers=1, + n_heads=16, + n_routed_experts=8, + n_shared_experts=2, + n_activated_experts=3, + route_scale=1.0, + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + ), + "16B": DeepSeekV3ModelArgs( + vocab_size=102400, + dim=2048, + inter_dim=10944, + moe_inter_dim=1408, + n_layers=27, + n_dense_layers=1, + n_heads=16, + n_routed_experts=64, + n_shared_experts=2, + n_activated_experts=6, + route_scale=1.0, + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + ), + "236B": DeepSeekV3ModelArgs( + vocab_size=102400, + dim=5120, + inter_dim=12288, + moe_inter_dim=1536, + n_layers=60, + n_dense_layers=1, + n_heads=128, + n_routed_experts=160, + n_shared_experts=2, + n_activated_experts=6, + n_expert_groups=8, + n_limited_groups=3, + route_scale=16.0, + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ), + "671B": DeepSeekV3ModelArgs( + vocab_size=129280, + dim=7168, + inter_dim=18432, + moe_inter_dim=2048, + n_layers=61, + n_dense_layers=3, + n_heads=128, + n_routed_experts=256, + n_shared_experts=1, + n_activated_experts=8, + n_expert_groups=8, + n_limited_groups=4, + route_scale=2.5, + score_func="sigmoid", + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + dtype="fp8", + ), +} + + +register_train_spec( + TrainSpec( + name="deepseek_v3", + cls=DeepSeekV3Model, + config=deepseekv3_configs, + parallelize_fn=parallelize_deepseekv3, + pipelining_fn=None, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_tiktoken_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) +) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py new file mode 100644 index 000000000..f8090683c --- /dev/null +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
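A hedged usage sketch, assuming a checkout with this patch applied so that torchtitan is importable; it only reads the `deepseekv3_configs` registry defined above, and the flavor name and fields come from that dict.

```python
from torchtitan.models.deepseek_v3 import deepseekv3_configs

args = deepseekv3_configs["16B"]
print(args.n_routed_experts, args.n_activated_experts)  # 64 routed experts, 6 active per token
print(args.q_lora_rank, args.kv_lora_rank)              # 0, 512 -> full-rank queries, latent KV
```

The `flavor` field in the `[model]` section of the TOML configs below selects an entry from this dict.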
+ + +import torch.nn as nn + +from torch.distributed.device_mesh import DeviceMesh + +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims + + +def parallelize_deepseekv3( + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + # TODO: Add support for parallelizing the model, this is a placeholder function for now + return model diff --git a/torchtitan/models/deepseek-v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py similarity index 62% rename from torchtitan/models/deepseek-v3/model/args.py rename to torchtitan/models/deepseek_v3/model/args.py index 845d6b83e..c0134bf54 100644 --- a/torchtitan/models/deepseek-v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -15,11 +15,12 @@ from torchtitan.components.tokenizer import Tokenizer from torchtitan.config_manager import JobConfig from torchtitan.protocols.train_spec import BaseModelArgs +from torchtitan.tools.logging import logger # Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py @dataclass -class DeepseekV3ModelArgs(BaseModelArgs): +class DeepSeekV3ModelArgs(BaseModelArgs): """ Data class for defining model arguments and hyperparameters. @@ -53,7 +54,6 @@ class DeepseekV3ModelArgs(BaseModelArgs): rope_factor (float): Scaling factor for extended sequence lengths. beta_fast (int): Fast beta correction factor. beta_slow (int): Slow beta correction factor. - mscale (float): Scaling factor for extended attention. """ max_batch_size: int = 8 @@ -95,12 +95,63 @@ class DeepseekV3ModelArgs(BaseModelArgs): def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None: """ - TODO: Placeholder for now + Update the model_config config from the given job config. """ - pass + self.vocab_size = tokenizer.n_words + self.max_seq_len = job_config.training.seq_len def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: """ - TODO: Placeholder for now + Adopted from llama4 implementation. """ - return 0, 0 + nparams_embedding = 0 + nparams_moe_router = 0 + nparams_shared_expert = 0 + nparams_experts = 0 + nparams_dense = 0 + + for name, p in model.named_parameters(): + print(name) + if "embedding" in name: + nparams_embedding += p.numel() + nparams_dense += p.numel() + elif "moe.shared_expert" in name: + nparams_shared_expert += p.numel() + elif "moe.router" in name: + nparams_moe_router += p.numel() + elif "moe.experts" in name: + nparams_experts += p.numel() + else: + nparams_dense += p.numel() + + nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts + nparams = nparams_dense + nparams_sparse + nparams_sparse_active = ( + nparams_moe_router + + nparams_shared_expert + + nparams_experts * self.n_activated_experts // self.n_routed_experts + ) + + logger.info( + f"Total parameter count: dense {nparams_dense:,}, " + f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}" + ) + + l, h, q, t = ( + self.n_layers, + self.n_heads, + self.dim // self.n_heads, + seq_len, + ) + # Reasoning behind the factor of 12 for the self-attention part of the formula: + # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) + # 2. the flash attention does 1 more matmul recomputation in the backward + # but recomputation should not be counted in calculating MFU (+0) + # 3. each matmul performs 1 multiplication and 1 addition (*2) + # 4. 
we follow the convention and do not account for sparsity in causal attention + num_flops_per_token = ( + 6 * (nparams_dense - nparams_embedding + nparams_sparse_active) + + 12 * l * h * q * t + ) + + return nparams, num_flops_per_token diff --git a/torchtitan/models/deepseek-v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py similarity index 92% rename from torchtitan/models/deepseek-v3/model/model.py rename to torchtitan/models/deepseek_v3/model/model.py index dd6c44319..c5ee02327 100644 --- a/torchtitan/models/deepseek-v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -13,17 +13,17 @@ from torchtitan.models.attention import build_attention from torchtitan.protocols.train_spec import ModelProtocol -from .args import DeepseekV3ModelArgs +from .args import DeepSeekV3ModelArgs from .moe import MoE -# Adopted from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 -def precompute_freqs_cis(args: DeepseekV3ModelArgs) -> torch.Tensor: +# Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 +def precompute_freqs_cis(args: DeepSeekV3ModelArgs) -> torch.Tensor: """ Precomputes frequency-based complex exponential values for rotary positional embeddings. Args: - args (DeepseekV3ModelArgs): Model arguments containing positional embedding parameters. + args (DeepSeekV3ModelArgs): Model arguments containing positional embedding parameters. Returns: torch.Tensor: Precomputed complex exponential values for positional embeddings. @@ -98,7 +98,7 @@ def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor: # Basic RoPE frequency calculation freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - # YaRN scaling for extended context + # YaRN scaling for extended context. YaRN is used to extend the context length after pre-training. if seqlen > args.original_seq_len: low, high = find_correction_range( beta_fast, beta_slow, dim, base, args.original_seq_len @@ -106,8 +106,13 @@ def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor: smooth = 1 - linear_ramp_factor(low, high, dim // 2) freqs = freqs / factor * (1 - smooth) + freqs * smooth + # Create position indices t = torch.arange(seqlen) + + # Outer product: [positions] × [frequencies] freqs = torch.outer(t, freqs) + + # Convert to complex exponentials: e^(i*freq*pos) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) return freqs_cis @@ -135,7 +140,7 @@ class Attention(nn.Module): Multi-head attention (MLA) module. """ - def __init__(self, model_args: DeepseekV3ModelArgs): + def __init__(self, model_args: DeepSeekV3ModelArgs): super().__init__() self.dim = model_args.dim self.n_heads = model_args.n_heads @@ -264,13 +269,13 @@ class TransformerBlock(nn.Module): Transformer block with attention and feed-forward layers. 
""" - def __init__(self, layer_id: int, model_args: DeepseekV3ModelArgs): + def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): super().__init__() self.attention = Attention(model_args) self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - self.ffn = ( + self.moe_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.moe = ( FeedForward(model_args.dim, model_args.inter_dim) if layer_id < model_args.n_dense_layers else MoE(model_args) @@ -288,16 +293,16 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): torch.Tensor: Output tensor with the same shape as the input. """ x = x + self.attention(self.attention_norm(x), freqs_cis) - x = x + self.ffn(self.ffn_norm(x)) + x = x + self.moe(self.moe_norm(x)) return x -class Transformer(nn.Module, ModelProtocol): +class DeepSeekV3Model(nn.Module, ModelProtocol): """ - Deepseek-V3 Transformer model with attention and feed-forward layers. + DeepSeek-V3 Transformer model with attention and feed-forward layers. """ - def __init__(self, model_args: DeepseekV3ModelArgs): + def __init__(self, model_args: DeepSeekV3ModelArgs): super().__init__() self.max_seq_len = model_args.max_seq_len self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) @@ -327,10 +332,11 @@ def forward(self, tokens: torch.Tensor): torch.Tensor: Logits tensor of shape (batch_size, vocab_size). """ h = self.tok_embeddings(tokens) + for layer in self.layers: h = layer(h, self.freqs_cis) - h = self.norm(h)[:, -1] - output = self.output(h) + h = self.norm(h) + output = self.output(h) # (batch_size, seq_len, dim) return output def init_weights(self, buffer_device: torch.device | None = None) -> None: diff --git a/torchtitan/models/deepseek-v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py similarity index 97% rename from torchtitan/models/deepseek-v3/model/moe.py rename to torchtitan/models/deepseek_v3/model/moe.py index b224b1097..3e17968e1 100644 --- a/torchtitan/models/deepseek-v3/model/moe.py +++ b/torchtitan/models/deepseek_v3/model/moe.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from torch import nn -from .args import DeepseekV3ModelArgs +from .args import DeepSeekV3ModelArgs # Reference: torchtitan/experiments/llama4/model/ @@ -103,6 +103,7 @@ class TokenChoiceTopKRouter(nn.Module): def __init__( self, + dim: int, num_experts: int, top_k: int, use_sigmoid: bool = False, @@ -110,14 +111,13 @@ def __init__( ): super().__init__() + self.dim = dim self.num_experts = num_experts self.top_k = top_k self.use_sigmoid = use_sigmoid self.route_sclaing_factor = route_sclaing_factor - self.weight = nn.Parameter( - torch.empty((self.n_routed_experts, self.gating_dim)) - ) + self.weight = nn.Parameter(torch.empty((self.num_experts, self.dim))) def forward( self, x: torch.Tensor, expert_bias: torch.Tensor = None @@ -138,7 +138,7 @@ def forward( Number of tokens assigned to each expert with shape ``(num_experts,)``. 
""" # scores shape (bs*slen, num_experts) - scores = F.linear(x.type, self.weight, None) + scores = F.linear(x, self.weight, bias=None) # By default, sigmoid or softmax is performed in float32 to avoid loss explosion if self.use_sigmoid: @@ -176,7 +176,7 @@ def forward( class MoE(nn.Module): - def __init__(self, model_args: DeepseekV3ModelArgs): + def __init__(self, model_args: DeepSeekV3ModelArgs): super().__init__() dim = model_args.dim @@ -194,6 +194,7 @@ def __init__(self, model_args: DeepseekV3ModelArgs): use_grouped_mm=self.use_grouped_mm, ) self.router = TokenChoiceTopKRouter( + dim=dim, num_experts=num_experts, top_k=top_k, use_sigmoid=model_args.score_func == "sigmoid", diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml new file mode 100644 index 000000000..eddca8849 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -0,0 +1,69 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "DeepSeek-V3 debug training" +print_args = false +use_for_integration_test = true + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "deepseek_v3" +flavor = "debugmodel" +# test tokenizer.model, for debug purpose only +tokenizer_path = "./tests/assets/test_tiktoken.model" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +lr_min = 0.0 + +[training] +local_batch_size = 8 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 10 +compile = false +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 10 +last_save_model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "none" # ["none", "selective", "full"] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] From 50e923046839c5bbd663d9ec9f7f66aa720431b4 Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Tue, 24 Jun 2025 13:18:16 -0700 Subject: [PATCH 4/6] [DSV3] Adding 16B model training config, Enable FSDP and AC on DSV3-16B model (#1330) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Context 1. Introduced a basic DSV3-16B model training config 2. Enabled FSDP/HSDP on DSV3-16B model training ## Performance Current profiler looks like this: The `to_copy` takes to long and needs to be optimized. 
The copy comes from dtype conversion in class MoE(): ```routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(x.dtype)``` With FSDP only: Screenshot 2025-06-23 at 2 10 20 PM --- torchtitan/models/deepseek_v3/__init__.py | 1 + .../models/deepseek_v3/infra/parallelize.py | 35 ++++++++-- torchtitan/models/deepseek_v3/model/args.py | 1 - torchtitan/models/deepseek_v3/model/model.py | 13 ++-- torchtitan/models/deepseek_v3/model/moe.py | 4 +- .../train_configs/deepseek_v3_16b.toml | 67 +++++++++++++++++++ 6 files changed, 110 insertions(+), 11 deletions(-) create mode 100644 torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 8a21e53dd..7eb16a1f3 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -11,6 +11,7 @@ from torchtitan.components.optimizer import build_optimizers from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer + from torchtitan.protocols.train_spec import register_train_spec, TrainSpec from .infra.parallelize import parallelize_deepseekv3 diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index f8090683c..99338663f 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -4,13 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - import torch.nn as nn - from torch.distributed.device_mesh import DeviceMesh -from torchtitan.config_manager import JobConfig +from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims +from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_fsdp +from torchtitan.tools.logging import logger def parallelize_deepseekv3( @@ -19,5 +19,32 @@ def parallelize_deepseekv3( parallel_dims: ParallelDims, job_config: JobConfig, ): - # TODO: Add support for parallelizing the model, this is a placeholder function for now + if job_config.activation_checkpoint.mode != "none": + apply_ac(model, job_config.activation_checkpoint) + + dp_mesh: DeviceMesh | None = None + if ( + parallel_dims.dp_shard_enabled + ): # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard") + else: + dp_mesh_dim_names = ("dp_shard",) + dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + + apply_fsdp( + model, + dp_mesh, + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + return model diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index c0134bf54..09e882764 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -111,7 +111,6 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in nparams_dense = 0 for 
name, p in model.named_parameters(): - print(name) if "embedding" in name: nparams_embedding += p.numel() nparams_dense += p.numel() diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index c5ee02327..3eb0f2fbc 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -217,6 +217,10 @@ def forward( [k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1 ) # (bsz, seqlen, n_heads, qk_head_dim) + q = q.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) + k = k.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) + v = v.transpose(1, 2) # (bsz, n_heads, seqlen, v_head_dim) + # TODO: Need to pass softmax_scale to sdpa() interface. # For mask, DeepseekV3 uses causal mask, so we can use the default mask in sdpa # https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L17 @@ -310,11 +314,10 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): "freqs_cis", precompute_freqs_cis(model_args), persistent=False ) - self.layers = torch.nn.ModuleList() + self.layers = torch.nn.ModuleDict() for layer_id in range(model_args.n_layers): - self.layers.append( - TransformerBlock(layer_id=layer_id, model_args=model_args) - ) + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + self.norm = nn.RMSNorm(model_args.dim) self.output = nn.Linear( model_args.dim, model_args.vocab_size, dtype=torch.get_default_dtype() @@ -333,7 +336,7 @@ def forward(self, tokens: torch.Tensor): """ h = self.tok_embeddings(tokens) - for layer in self.layers: + for layer in self.layers.values(): h = layer(h, self.freqs_cis) h = self.norm(h) output = self.output(h) # (batch_size, seq_len, dim) diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py index 3e17968e1..c9217c8be 100644 --- a/torchtitan/models/deepseek_v3/model/moe.py +++ b/torchtitan/models/deepseek_v3/model/moe.py @@ -307,7 +307,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # shape (bs*slen*top_k, dim) routed_output = self.experts(routed_input, num_local_tokens_per_expert) - routed_output = routed_output * top_scores.unsqueeze(-1) + routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to( + x.dtype + ) # shared expert if self.shared_expert is not None: diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml new file mode 100644 index 000000000..4f08fb098 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -0,0 +1,67 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "DeepSeek-V3 16B model training" +print_args = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "deepseek_v3" +flavor = "16B" +# test tokenizer.model, for debug purpose only +tokenizer_path = "./tests/assets/test_tiktoken.model" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +lr_min = 0.0 + +[training] +local_batch_size = 32 +seq_len = 2048 
+max_norm = 1.0 # grad norm clipping +steps = 10 +compile = false +dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 10 +last_save_model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" + +[activation_checkpoint] +mode = "full" # ["none", "selective", "full"] + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] From f847582e20deda38d151d06dfdf0b75398738063 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 1 Jul 2025 21:10:51 -0700 Subject: [PATCH 5/6] compile support for single gpu --- .../models/deepseek_v3/infra/parallelize.py | 25 +++++++++++++ torchtitan/models/deepseek_v3/model/moe.py | 36 +++++++++++++------ 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 99338663f..5542b01df 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -4,15 +4,36 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import torch import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_fsdp +from torchtitan.models.deepseek_v3.model.moe import MoE from torchtitan.tools.logging import logger +def apply_compile(model: nn.Module): + """ + Apply torch.compile to each TransformerBlock, which makes compilation efficient due to + repeated structure. Alternatively one can compile the whole model (after applying DP). 
+ """ + # Fail loudly if we exceed the recompile limit (default 8) + torch._dynamo.config.fail_on_recompile_limit_hit = True + + for layer_id, transformer_block in model.layers.named_children(): + fullgraph=True + if isinstance(transformer_block.moe, MoE): + # Allow graph break for MoE layers + fullgraph = False + transformer_block = torch.compile(transformer_block, fullgraph=fullgraph) + model.layers.register_module(layer_id, transformer_block) + + logger.info("Compiling each TransformerBlock with torch.compile") + + def parallelize_deepseekv3( model: nn.Module, world_mesh: DeviceMesh, @@ -22,6 +43,10 @@ def parallelize_deepseekv3( if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) + # turn on per-TransformerBlock compile after AC wrapping and before FSDP + if job_config.training.compile: + apply_compile(model) + dp_mesh: DeviceMesh | None = None if ( parallel_dims.dp_shard_enabled diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py index c9217c8be..295f1b3b2 100644 --- a/torchtitan/models/deepseek_v3/model/moe.py +++ b/torchtitan/models/deepseek_v3/model/moe.py @@ -11,6 +11,30 @@ from .args import DeepSeekV3ModelArgs +# This function is manually hidden from the compiler to: +# - hide inactive experts to avoid 0/1 specialization caused by zero-shaped tokens +# - mark_dynamic split outputs to avoid specializing on each expert's token shape +@torch._dynamo.disable +def split_tokens(x, num_local_tokens_per_expert): + splits = num_local_tokens_per_expert.tolist() + tokens_by_expert = torch.split( + x, + split_size_or_sections=splits, + dim=0, + ) + expert_idxs = [] + expert_tokens = [] + for i, split in enumerate(splits): + if split == 0: + # inactive expert, hide it from the compiler + continue + + torch._dynamo.mark_dynamic(tokens_by_expert[i], 0) + expert_idxs.append(i) + expert_tokens.append(tokens_by_expert[i]) + + return expert_idxs, expert_tokens + # Reference: torchtitan/experiments/llama4/model/ class GroupedExperts(nn.Module): def __init__( @@ -38,13 +62,9 @@ def forward( if num_local_tokens_per_expert is not None: # a tuple of tensors indexed by experts # each with shape (tokens_per_expert(varying), dim) - x = torch.split( - x, - split_size_or_sections=num_local_tokens_per_expert, - dim=0, - ) + expert_idxs, expert_tokens = split_tokens(x, num_local_tokens_per_expert) out_experts_splits = [] - for expert_idx, x_expert in enumerate(x): + for expert_idx, x_expert in zip(expert_idxs, expert_tokens): w1, w2, w3 = ( self.w1[expert_idx], self.w2[expert_idx], @@ -301,10 +321,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: token_indices = token_indices[permuted_indices, :] routed_input = torch.vstack((routed_input, routed_input.new_zeros((dim)))) routed_input = routed_input[permuted_indices, :] - else: - # NOTE: this would incur a synchronization between device and host - num_local_tokens_per_expert = num_local_tokens_per_expert.tolist() - # shape (bs*slen*top_k, dim) routed_output = self.experts(routed_input, num_local_tokens_per_expert) routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to( From dd5b2472a91d0dda130831449079451ae4fc45c2 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Wed, 2 Jul 2025 09:22:27 -0700 Subject: [PATCH 6/6] lint --- torchtitan/models/deepseek_v3/infra/parallelize.py | 4 ++-- torchtitan/models/deepseek_v3/model/moe.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git 
a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 5542b01df..c9a81464b 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -10,8 +10,8 @@ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims -from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_fsdp from torchtitan.models.deepseek_v3.model.moe import MoE +from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_fsdp from torchtitan.tools.logging import logger @@ -24,7 +24,7 @@ def apply_compile(model: nn.Module): torch._dynamo.config.fail_on_recompile_limit_hit = True for layer_id, transformer_block in model.layers.named_children(): - fullgraph=True + fullgraph = True if isinstance(transformer_block.moe, MoE): # Allow graph break for MoE layers fullgraph = False diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py index 295f1b3b2..9b0d9536f 100644 --- a/torchtitan/models/deepseek_v3/model/moe.py +++ b/torchtitan/models/deepseek_v3/model/moe.py @@ -35,6 +35,7 @@ def split_tokens(x, num_local_tokens_per_expert): return expert_idxs, expert_tokens + # Reference: torchtitan/experiments/llama4/model/ class GroupedExperts(nn.Module): def __init__( @@ -62,7 +63,9 @@ def forward( if num_local_tokens_per_expert is not None: # a tuple of tensors indexed by experts # each with shape (tokens_per_expert(varying), dim) - expert_idxs, expert_tokens = split_tokens(x, num_local_tokens_per_expert) + expert_idxs, expert_tokens = split_tokens( + x, num_local_tokens_per_expert + ) out_experts_splits = [] for expert_idx, x_expert in zip(expert_idxs, expert_tokens): w1, w2, w3 = (