
Commit 4eeb287

wwwjn authored and H-Huang committed
Implement Deepseek-V3 model skeleton (#1315)
## Contents
1. Attention module
2. MoE module (note: I only implemented the naive routing, not the "node limit routing" strategy; a sketch of naive routing appears after this description)
3. Deepseek-V3 model

References:
1. Deepseek-ai: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
2. Huggingface: https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py
3. torchtitan/experiment/deepseek-v3
4. torchtitan/experiment/llama4

## TODO
- [ ] Further clean up the DeepseekV3ModelArgs class and remove unused model args
- [ ] Test the forward pass with torchtitan
1 parent 362ccd6 commit 4eeb287
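The MoE module referenced above implements only naive top-k routing. Below is a minimal, hypothetical sketch of what such a router can look like, written against the `score_func`, `route_scale`, and `n_activated_experts` fields of `DeepseekV3ModelArgs` defined further down; the class name `NaiveTopKRouter` and the sigmoid re-normalization detail are illustrative assumptions, not the code added by this commit.

```python
# Hypothetical sketch of naive top-k expert routing; not the module added in this commit.
import torch
import torch.nn as nn


class NaiveTopKRouter(nn.Module):
    """Score tokens against routed experts and keep the top-k ("naive" routing,
    i.e. no node-limit / expert-group constraints)."""

    def __init__(
        self,
        dim: int,
        n_routed_experts: int,
        n_activated_experts: int,
        score_func: str = "softmax",
        route_scale: float = 1.0,
    ):
        super().__init__()
        self.gate = nn.Linear(dim, n_routed_experts, bias=False)
        self.n_activated_experts = n_activated_experts
        self.score_func = score_func
        self.route_scale = route_scale

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # x: (num_tokens, dim) -> per-expert scores for each token
        logits = self.gate(x)
        scores = logits.softmax(dim=-1) if self.score_func == "softmax" else logits.sigmoid()
        # Keep the n_activated_experts highest-scoring experts for each token.
        weights, expert_ids = scores.topk(self.n_activated_experts, dim=-1)
        if self.score_func == "sigmoid":
            # Assumed detail: sigmoid scores are re-normalized over the selected experts.
            weights = weights / weights.sum(dim=-1, keepdim=True)
        return weights * self.route_scale, expert_ids
```

The tokens would then be dispatched to their selected experts and the expert outputs combined with these weights, while any shared experts process every token unconditionally.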

File tree

3 files changed, +782 -0 lines changed

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass
from typing import Literal

from torch import nn

from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.protocols.train_spec import BaseModelArgs


# Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
@dataclass
class DeepseekV3ModelArgs(BaseModelArgs):
    """
    Data class for defining model arguments and hyperparameters.

    Attributes:
        max_batch_size (int): Maximum batch size.
        max_seq_len (int): Maximum sequence length.
        dtype (Literal["bf16", "fp8"]): Data type for computations.
        vocab_size (int): Vocabulary size.
        dim (int): Model dimension.
        inter_dim (int): Intermediate dimension for MLP layers.
        moe_inter_dim (int): Intermediate dimension for MoE layers.
        n_layers (int): Number of transformer layers.
        n_dense_layers (int): Number of dense layers in the model.
        n_heads (int): Number of attention heads.
        n_routed_experts (int): Number of routed experts for MoE layers.
        n_shared_experts (int): Number of shared experts for MoE layers.
        n_activated_experts (int): Number of activated experts in MoE layers.
        n_expert_groups (int): Number of expert groups.
        n_limited_groups (int): Number of limited groups for MoE routing.
        score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
        route_scale (float): Scaling factor for routing scores.
        use_grouped_mm (bool): Whether to use grouped matrix multiplication for MoE layers.
        load_balance_coeff (float | None): Auxiliary-Loss-Free load balancing coefficient for MoE layers.
        q_lora_rank (int): LoRA rank for query projections.
        kv_lora_rank (int): LoRA rank for key-value projections.
        qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
        qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
        v_head_dim (int): Dimension for value projections.
        original_seq_len (int): Original sequence length.
        rope_theta (float): Base for rotary positional encoding.
        rope_factor (float): Scaling factor for extended sequence lengths.
        beta_fast (int): Fast beta correction factor.
        beta_slow (int): Slow beta correction factor.
        mscale (float): Scaling factor for extended attention.
    """

    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    dtype: Literal["bf16", "fp8"] = "bf16"
    vocab_size: int = 102400
    dim: int = 2048
    inter_dim: int = 10944
    moe_inter_dim: int = 1408
    n_layers: int = 27
    n_dense_layers: int = 1
    n_heads: int = 16
    norm_eps: float = 1e-5  # eps used for RMSNorm
    # MoE
    n_routed_experts: int = 64
    n_shared_experts: int = 2
    n_activated_experts: int = 6
    n_expert_groups: int = 1
    n_limited_groups: int = 1
    score_func: Literal["softmax", "sigmoid"] = "softmax"
    route_scale: float = 1.0
    use_grouped_mm: bool = False
    load_balance_coeff: float | None = 1e-3
    # Multi-Head Latent Attention (MLA)
    q_lora_rank: int = 0
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    use_flex_attn: bool = False
    attn_mask_type: str = "causal"
    # yarn
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.0

    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        """
        TODO: Placeholder for now
        """
        pass

    def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
        """
        TODO: Placeholder for now
        """
        return 0, 0
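For orientation, a small hypothetical usage example (not part of the commit) showing how the MLA and MoE defaults above compose:

```python
# Illustrative only; exercises the defaults of DeepseekV3ModelArgs defined above.
args = DeepseekV3ModelArgs()

# Per-head query/key width under Multi-Head Latent Attention:
# the non-positional part plus the rotary part.
qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
print(qk_head_dim)  # 192 = 128 + 64

# Keys/values are compressed through a 512-dim latent; q_lora_rank = 0 is read
# in the DeepSeek-V3 reference implementation as "no low-rank bottleneck for queries".
print(args.kv_lora_rank, args.q_lora_rank)  # 512 0

# Each token activates 6 of the 64 routed experts, plus 2 shared experts.
print(args.n_activated_experts, args.n_routed_experts, args.n_shared_experts)  # 6 64 2
```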

0 commit comments