
Commit c76a079

[DSV3] Forward and backward pass for single GPU (#1320)
Command to run: `NGPU=1 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh`

## Context

1. Added model args for 4 model settings, and a training config for the debug model.
2. Debugged the forward pass; the backward pass works out of the box.
3. Reused the c4-test dataset and the tiktokenizer from the llama3 model for current testing.

![Screenshot 2025-06-20 at 11 52 49 AM](https://github.com/user-attachments/assets/81d938a2-9a85-4e8c-b8e1-7f9510d785c2)
1 parent a5d720e commit c76a079

File tree: 7 files changed (+303, -27 lines)


torchtitan/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,4 +7,5 @@

 # Import the built-in models here so that the corresponding register_model_spec()
 # will be called.
+import torchtitan.models.deepseek_v3  # noqa: F401
 import torchtitan.models.llama3  # noqa: F401
Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+from torchtitan.components.loss import build_cross_entropy_loss
+from torchtitan.components.lr_scheduler import build_lr_schedulers
+from torchtitan.components.optimizer import build_optimizers
+from torchtitan.datasets.hf_datasets import build_hf_dataloader
+from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
+from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
+
+from .infra.parallelize import parallelize_deepseekv3
+from .model.args import DeepSeekV3ModelArgs
+from .model.model import DeepSeekV3Model
+
+__all__ = [
+    "parallelize_deepseekv3",
+    "DeepSeekV3ModelArgs",
+    "DeepSeekV3Model",
+    "deepseekv3_configs",
+]
+
+
+deepseekv3_configs = {
+    "debugmodel": DeepSeekV3ModelArgs(
+        vocab_size=102400,
+        dim=256,
+        inter_dim=10944,
+        moe_inter_dim=1408,
+        n_layers=3,
+        n_dense_layers=1,
+        n_heads=16,
+        n_routed_experts=8,
+        n_shared_experts=2,
+        n_activated_experts=3,
+        route_scale=1.0,
+        q_lora_rank=0,
+        kv_lora_rank=512,
+        qk_nope_head_dim=128,
+        qk_rope_head_dim=64,
+        v_head_dim=128,
+        mscale=0.70,
+    ),
+    "16B": DeepSeekV3ModelArgs(
+        vocab_size=102400,
+        dim=2048,
+        inter_dim=10944,
+        moe_inter_dim=1408,
+        n_layers=27,
+        n_dense_layers=1,
+        n_heads=16,
+        n_routed_experts=64,
+        n_shared_experts=2,
+        n_activated_experts=6,
+        route_scale=1.0,
+        q_lora_rank=0,
+        kv_lora_rank=512,
+        qk_nope_head_dim=128,
+        qk_rope_head_dim=64,
+        v_head_dim=128,
+        mscale=0.70,
+    ),
+    "236B": DeepSeekV3ModelArgs(
+        vocab_size=102400,
+        dim=5120,
+        inter_dim=12288,
+        moe_inter_dim=1536,
+        n_layers=60,
+        n_dense_layers=1,
+        n_heads=128,
+        n_routed_experts=160,
+        n_shared_experts=2,
+        n_activated_experts=6,
+        n_expert_groups=8,
+        n_limited_groups=3,
+        route_scale=16.0,
+        q_lora_rank=1536,
+        kv_lora_rank=512,
+        qk_nope_head_dim=128,
+        qk_rope_head_dim=64,
+        v_head_dim=128,
+    ),
+    "671B": DeepSeekV3ModelArgs(
+        vocab_size=129280,
+        dim=7168,
+        inter_dim=18432,
+        moe_inter_dim=2048,
+        n_layers=61,
+        n_dense_layers=3,
+        n_heads=128,
+        n_routed_experts=256,
+        n_shared_experts=1,
+        n_activated_experts=8,
+        n_expert_groups=8,
+        n_limited_groups=4,
+        route_scale=2.5,
+        score_func="sigmoid",
+        q_lora_rank=1536,
+        kv_lora_rank=512,
+        qk_nope_head_dim=128,
+        qk_rope_head_dim=64,
+        v_head_dim=128,
+        dtype="fp8",
+    ),
+}


+register_train_spec(
+    TrainSpec(
+        name="deepseek_v3",
+        cls=DeepSeekV3Model,
+        config=deepseekv3_configs,
+        parallelize_fn=parallelize_deepseekv3,
+        pipelining_fn=None,
+        build_optimizers_fn=build_optimizers,
+        build_lr_schedulers_fn=build_lr_schedulers,
+        build_dataloader_fn=build_hf_dataloader,
+        build_tokenizer_fn=build_tiktoken_tokenizer,
+        build_loss_fn=build_cross_entropy_loss,
+    )
+)
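
For orientation, here is a minimal sketch of how the spec registered above might be looked up and used to build the debug model. It assumes `torchtitan.protocols.train_spec` exposes a `get_train_spec` helper; the `cls` and `config` fields come from the `TrainSpec` registered above.

```python
# Minimal sketch, assuming a get_train_spec lookup exists alongside register_train_spec.
from torchtitan.protocols.train_spec import get_train_spec

spec = get_train_spec("deepseek_v3")

# Pick one of the flavors defined above, e.g. the debug model used in this PR.
model_args = spec.config["debugmodel"]

# spec.cls is DeepSeekV3Model; direct construction here is purely illustrative.
model = spec.cls(model_args)
```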
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch.nn as nn
+
+from torch.distributed.device_mesh import DeviceMesh
+
+from torchtitan.config_manager import JobConfig
+from torchtitan.distributed import ParallelDims
+
+
+def parallelize_deepseekv3(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+):
+    # TODO: Add support for parallelizing the model; this is a placeholder for now.
+    return model
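
Since the function above is a placeholder, single-GPU runs simply get the unmodified module back. A rough sketch of the expected call-site shape, assuming the training loop passes the spec's `parallelize_fn` the same arguments as in the signature (the variable names here are illustrative, not taken from the trainer):

```python
# Illustrative only: world_mesh, parallel_dims, and job_config stand in for objects
# the training loop has already built; the placeholder returns the model unchanged
# until parallelism support is added for DeepSeek-V3.
model = parallelize_deepseekv3(
    model=model,
    world_mesh=world_mesh,
    parallel_dims=parallel_dims,
    job_config=job_config,
)
```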

torchtitan/models/deepseek-v3/model/args.py renamed to torchtitan/models/deepseek_v3/model/args.py

Lines changed: 57 additions & 6 deletions
@@ -15,11 +15,12 @@
 from torchtitan.components.tokenizer import Tokenizer
 from torchtitan.config_manager import JobConfig
 from torchtitan.protocols.train_spec import BaseModelArgs
+from torchtitan.tools.logging import logger


 # Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
 @dataclass
-class DeepseekV3ModelArgs(BaseModelArgs):
+class DeepSeekV3ModelArgs(BaseModelArgs):
     """
     Data class for defining model arguments and hyperparameters.

@@ -53,7 +54,6 @@ class DeepseekV3ModelArgs(BaseModelArgs):
         rope_factor (float): Scaling factor for extended sequence lengths.
         beta_fast (int): Fast beta correction factor.
         beta_slow (int): Slow beta correction factor.
-        mscale (float): Scaling factor for extended attention.
     """

     max_batch_size: int = 8
@@ -95,12 +95,63 @@

     def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
         """
-        TODO: Placeholder for now
+        Update the model config from the given job config.
         """
-        pass
+        self.vocab_size = tokenizer.n_words
+        self.max_seq_len = job_config.training.seq_len

     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
         """
-        TODO: Placeholder for now
+        Adapted from the llama4 implementation.
         """
-        return 0, 0
+        nparams_embedding = 0
+        nparams_moe_router = 0
+        nparams_shared_expert = 0
+        nparams_experts = 0
+        nparams_dense = 0
+
+        for name, p in model.named_parameters():
+            if "embedding" in name:
+                nparams_embedding += p.numel()
+                nparams_dense += p.numel()
+            elif "moe.shared_expert" in name:
+                nparams_shared_expert += p.numel()
+            elif "moe.router" in name:
+                nparams_moe_router += p.numel()
+            elif "moe.experts" in name:
+                nparams_experts += p.numel()
+            else:
+                nparams_dense += p.numel()
+
+        nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
+        nparams = nparams_dense + nparams_sparse
+        nparams_sparse_active = (
+            nparams_moe_router
+            + nparams_shared_expert
+            + nparams_experts * self.n_activated_experts // self.n_routed_experts
+        )
+
+        logger.info(
+            f"Total parameter count: dense {nparams_dense:,}, "
+            f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
+        )
+
+        l, h, q, t = (
+            self.n_layers,
+            self.n_heads,
+            self.dim // self.n_heads,
+            seq_len,
+        )
+        # Reasoning behind the factor of 12 for the self-attention part of the formula:
+        # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
+        # 2. the flash attention does 1 more matmul recomputation in the backward
+        #    but recomputation should not be counted in calculating MFU (+0)
+        # 3. each matmul performs 1 multiplication and 1 addition (*2)
+        # 4. we follow the convention and do not account for sparsity in causal attention
+        num_flops_per_token = (
+            6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
+            + 12 * l * h * q * t
+        )
+
+        return nparams, num_flops_per_token
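
To make the MFU accounting concrete, here is a worked example of `num_flops_per_token` using the `debugmodel` hyperparameters (`n_layers=3`, `n_heads=16`, `dim=256`) and an assumed `seq_len=2048`. The parameter counts are made-up placeholders, since the real values come from iterating `model.named_parameters()`:

```python
# Worked example of the formula in get_nparams_and_flops; only the attention term
# uses real debugmodel hyperparameters, the parameter counts are placeholders.
l, h, q, t = 3, 16, 256 // 16, 2048   # layers, heads, head_dim, seq_len
attn_flops = 12 * l * h * q * t       # 12 * 3 * 16 * 16 * 2048 = 18,874,368

nparams_dense = 50_000_000            # placeholder
nparams_embedding = 26_000_000        # placeholder
nparams_sparse_active = 10_000_000    # placeholder

num_flops_per_token = (
    6 * (nparams_dense - nparams_embedding + nparams_sparse_active) + attn_flops
)
print(f"{num_flops_per_token:,}")     # 6 * 34,000,000 + 18,874,368 = 222,874,368
```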

torchtitan/models/deepseek-v3/model/model.py renamed to torchtitan/models/deepseek_v3/model/model.py

Lines changed: 21 additions & 15 deletions
@@ -13,17 +13,17 @@
 from torchtitan.models.attention import build_attention
 from torchtitan.protocols.train_spec import ModelProtocol

-from .args import DeepseekV3ModelArgs
+from .args import DeepSeekV3ModelArgs
 from .moe import MoE


-# Adopted from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L294
-def precompute_freqs_cis(args: DeepseekV3ModelArgs) -> torch.Tensor:
+# Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294
+def precompute_freqs_cis(args: DeepSeekV3ModelArgs) -> torch.Tensor:
     """
     Precomputes frequency-based complex exponential values for rotary positional embeddings.

     Args:
-        args (DeepseekV3ModelArgs): Model arguments containing positional embedding parameters.
+        args (DeepSeekV3ModelArgs): Model arguments containing positional embedding parameters.

     Returns:
         torch.Tensor: Precomputed complex exponential values for positional embeddings.
@@ -98,16 +98,21 @@ def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor:
     # Basic RoPE frequency calculation
     freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

-    # YaRN scaling for extended context
+    # YaRN scaling for extended context. YaRN is used to extend the context length after pre-training.
     if seqlen > args.original_seq_len:
         low, high = find_correction_range(
             beta_fast, beta_slow, dim, base, args.original_seq_len
         )
         smooth = 1 - linear_ramp_factor(low, high, dim // 2)
         freqs = freqs / factor * (1 - smooth) + freqs * smooth

+    # Create position indices
     t = torch.arange(seqlen)
+
+    # Outer product: [positions] × [frequencies]
     freqs = torch.outer(t, freqs)
+
+    # Convert to complex exponentials: e^(i*freq*pos)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
     return freqs_cis

@@ -135,7 +140,7 @@ class Attention(nn.Module):
     Multi-head attention (MLA) module.
     """

-    def __init__(self, model_args: DeepseekV3ModelArgs):
+    def __init__(self, model_args: DeepSeekV3ModelArgs):
         super().__init__()
         self.dim = model_args.dim
         self.n_heads = model_args.n_heads
@@ -264,13 +269,13 @@ class TransformerBlock(nn.Module):
     Transformer block with attention and feed-forward layers.
     """

-    def __init__(self, layer_id: int, model_args: DeepseekV3ModelArgs):
+    def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs):

         super().__init__()
         self.attention = Attention(model_args)
         self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
-        self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
-        self.ffn = (
+        self.moe_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.moe = (
             FeedForward(model_args.dim, model_args.inter_dim)
             if layer_id < model_args.n_dense_layers
             else MoE(model_args)
@@ -288,16 +293,16 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
             torch.Tensor: Output tensor with the same shape as the input.
         """
         x = x + self.attention(self.attention_norm(x), freqs_cis)
-        x = x + self.ffn(self.ffn_norm(x))
+        x = x + self.moe(self.moe_norm(x))
         return x


-class Transformer(nn.Module, ModelProtocol):
+class DeepSeekV3Model(nn.Module, ModelProtocol):
     """
-    Deepseek-V3 Transformer model with attention and feed-forward layers.
+    DeepSeek-V3 Transformer model with attention and feed-forward layers.
     """

-    def __init__(self, model_args: DeepseekV3ModelArgs):
+    def __init__(self, model_args: DeepSeekV3ModelArgs):
         super().__init__()
         self.max_seq_len = model_args.max_seq_len
         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
@@ -327,10 +332,11 @@ def forward(self, tokens: torch.Tensor):
             torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
         """
         h = self.tok_embeddings(tokens)
+
         for layer in self.layers:
             h = layer(h, self.freqs_cis)
-        h = self.norm(h)[:, -1]
-        output = self.output(h)
+        h = self.norm(h)
+        output = self.output(h)  # (batch_size, seq_len, vocab_size)
         return output

     def init_weights(self, buffer_device: torch.device | None = None) -> None:
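
The new comments in `precompute_freqs_cis` describe the standard RoPE precompute step: take the outer product of position indices with per-dimension frequencies, then convert the resulting angles to complex exponentials. A self-contained sketch of just that step, with illustrative sizes and without the YaRN rescaling branch:

```python
import torch

# Illustrative sizes; the real function derives these from DeepSeekV3ModelArgs.
dim, seqlen, base = 64, 16, 10000.0

# Per-dimension inverse frequencies for rotary embeddings.
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

t = torch.arange(seqlen)                                  # position indices
angles = torch.outer(t, freqs)                            # [positions] x [frequencies]
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # e^(i * pos * freq)

print(freqs_cis.shape)  # torch.Size([16, 32])
```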
