# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass
from typing import Literal

from torch import nn

from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.protocols.train_spec import BaseModelArgs


# Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
@dataclass
class DeepseekV3ModelArgs(BaseModelArgs):
    """
    Data class defining DeepSeek-V3 model arguments and hyperparameters.

    Attributes:
        max_batch_size (int): Maximum batch size.
        max_seq_len (int): Maximum sequence length.
        dtype (Literal["bf16", "fp8"]): Data type for computations.
        vocab_size (int): Vocabulary size.
        dim (int): Model dimension.
        inter_dim (int): Intermediate dimension for dense MLP layers.
        moe_inter_dim (int): Intermediate dimension for MoE expert layers.
        n_layers (int): Number of transformer layers.
        n_dense_layers (int): Number of dense (non-MoE) layers in the model.
        n_heads (int): Number of attention heads.
        norm_eps (float): Epsilon used for RMSNorm.
        n_routed_experts (int): Number of routed experts in MoE layers.
        n_shared_experts (int): Number of shared experts in MoE layers.
        n_activated_experts (int): Number of experts activated per token in MoE layers.
        n_expert_groups (int): Number of expert groups.
        n_limited_groups (int): Number of groups a token is limited to in MoE routing.
        score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
        route_scale (float): Scaling factor applied to routing scores.
        use_grouped_mm (bool): Whether to use grouped matrix multiplication for MoE layers.
        load_balance_coeff (float | None): Coefficient for auxiliary-loss-free load balancing in MoE layers.
        q_lora_rank (int): Low-rank dimension for query projections (0 disables the low-rank path).
        kv_lora_rank (int): Low-rank dimension for key-value projections.
        qk_nope_head_dim (int): Per-head dimension of the query-key component without positional embeddings.
        qk_rope_head_dim (int): Per-head dimension of the query-key component with rotary embeddings.
        v_head_dim (int): Per-head dimension for value projections.
        use_flex_attn (bool): Whether to use FlexAttention.
        attn_mask_type (str): Attention mask type, e.g. "causal".
        original_seq_len (int): Original sequence length the model was trained with.
        rope_theta (float): Base for rotary positional encoding.
        rope_factor (float): Scaling factor for extended sequence lengths.
        beta_fast (int): Fast beta correction factor for YaRN.
        beta_slow (int): Slow beta correction factor for YaRN.
        mscale (float): Scaling factor for attention at extended sequence lengths.
    """

    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    dtype: Literal["bf16", "fp8"] = "bf16"
    vocab_size: int = 102400
    dim: int = 2048
    inter_dim: int = 10944
    moe_inter_dim: int = 1408
    n_layers: int = 27
    n_dense_layers: int = 1
    n_heads: int = 16
    norm_eps: float = 1e-5  # epsilon used for RMSNorm
    # MoE
    n_routed_experts: int = 64
    n_shared_experts: int = 2
    n_activated_experts: int = 6
    n_expert_groups: int = 1
    n_limited_groups: int = 1
    score_func: Literal["softmax", "sigmoid"] = "softmax"
    route_scale: float = 1.0
    use_grouped_mm: bool = False
    load_balance_coeff: float | None = 1e-3
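    # With the defaults above, each token is routed to n_activated_experts (6)
    # of the n_routed_experts (64), and the n_shared_experts (2) always run,
    # so 8 expert MLPs of hidden size moe_inter_dim process every token.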
    # Multi-Head Latent Attention (MLA)
    q_lora_rank: int = 0
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    use_flex_attn: bool = False
    attn_mask_type: str = "causal"
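    # The effective per-head query-key dimension is the sum of the two
    # components: qk_nope_head_dim + qk_rope_head_dim = 128 + 64 = 192;
    # keys/values are first compressed into a kv_lora_rank (512) latent vector,
    # which is what makes the attention "latent".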
    # YaRN RoPE extension (https://arxiv.org/abs/2309.00071)
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.0
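    # YaRN sketch (an assumption based on the referenced DeepSeek-V3 inference
    # code, not defined in this file): when max_seq_len exceeds
    # original_seq_len, the attention softmax scale is boosted to compensate
    # for RoPE interpolation, roughly:
    #   mscale = 0.1 * self.mscale * math.log(self.rope_factor) + 1.0
    #   softmax_scale = qk_head_dim**-0.5 * mscale * mscale
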
    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        """
        TODO: Placeholder for now
        """
        pass

    def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
        """
        TODO: Placeholder for now
        """
        return 0, 0
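

# Minimal usage sketch (an illustrative example, not part of the upstream
# module): construct the default arguments, override a couple of fields, and
# derive quantities directly from the dataclass fields.
if __name__ == "__main__":
    args = DeepseekV3ModelArgs(max_seq_len=8192, use_grouped_mm=True)
    qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim  # 128 + 64 = 192
    print(f"layers={args.n_layers}, heads={args.n_heads}, qk_head_dim={qk_head_dim}")
    print(
        f"experts per token: {args.n_activated_experts}/{args.n_routed_experts} "
        f"routed + {args.n_shared_experts} shared"
    )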