Commit 2a4bc36

Commit message: models

1 parent d6bd731 commit 2a4bc36

File tree

  • torchtitan/models/deepseek-v3/model

1 file changed: +119 −1 lines

torchtitan/models/deepseek-v3/model/model.py

Lines changed: 119 additions & 1 deletion
@@ -7,9 +7,11 @@
 import math
 
 import torch
+import torch.nn.functional as F
 from torch import nn
-
 from torchtitan.models.attention import build_attention
+from torchtitan.models.llama3 import model
+from torchtitan.protocols.train_spec import ModelProtocol
 
 from .args import DeepseekV3ModelArgs

@@ -88,7 +90,10 @@ def linear_ramp_factor(min, max, dim):
         ramp_func = torch.clamp(linear_func, 0, 1)
         return ramp_func
 
+    # Basic RoPE frequency calculation
     freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+
+    # YaRN scaling for extended context
     if seqlen > args.original_seq_len:
         low, high = find_correction_range(
             beta_fast, beta_slow, dim, base, args.original_seq_len
@@ -222,3 +227,116 @@ def forward(
         ).contiguous()  # (bs, seqlen, n_heads, v_head_dim)
         output = output.view(bsz, seqlen, -1)  # (bs, seqlen, n_heads * v_head_dim)
         return self.wo(output)  # (bsz, seqlen, dim)
+
+
+class FeedForward(nn.Module):
+    """
+    FeedForward module
+
+    Args:
+        dim (int): Input dimension.
+        hidden_dim (int): Hidden dimension of the feedforward layer.
+
+    Attributes:
+        w1 (Linear): Linear transformation for the first layer.
+        w2 (Linear): Linear transformation for the second layer.
+        w3 (Linear): Linear transformation for the third layer.
+
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+    ):
+        super().__init__()
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+        for linear in (self.w2, self.w3):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
+class TransformerBlock(nn.Module):
+    """
+    Transformer block with attention and feed-forward layers.
+    """
+
+    def __init__(self, model_args: DeepseekV3ModelArgs):
+        super().__init__()
+        self.attention = Attention(model_args)
+        self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.ffn = FeedForward(model_args.dim, model_args.inter_dim)
+
+    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
+        """
+        Forward pass for the Transformer block.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
+            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
+
+        Returns:
+            torch.Tensor: Output tensor with the same shape as the input.
+        """
+        x = x + self.attention(self.attention_norm(x), freqs_cis)
+        x = x + self.ffn(self.ffn_norm(x))
+        return x
+
+
+class Transformer(nn.Module, ModelProtocol):
+    """
+    Deepseek-V3 Transformer model with attention and feed-forward layers.
+    """
+
+    def __init__(self, model_args: DeepseekV3ModelArgs):
+        super().__init__()
+        self.max_seq_len = model_args.max_seq_len
+        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
+        self.register_buffer(
+            "freqs_cis", precompute_freqs_cis(model_args), persistent=False
+        )
+
+        self.layers = torch.nn.ModuleList()
+        for layer_id in range(model_args.n_layers):
+            self.layers.append(TransformerBlock(model_args))
+        self.norm = nn.RMSNorm(model_args.dim)
+        self.output = nn.Linear(
+            model_args.dim, model_args.vocab_size, dtype=torch.get_default_dtype()
+        )
+        self.init_weights()
+
+    def forward(self, tokens: torch.Tensor):
+        """
+        Forward pass for the Transformer model.
+
+        Args:
+            tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
+
+        Returns:
+            torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
+        """
+        h = self.tok_embeddings(tokens)
+        # The causal mask is already handled inside sdpa:
+        # if seqlen > 1:
+        #     mask = torch.full(
+        #         (seqlen, seqlen), float("-inf"), device=tokens.device
+        #     ).triu_(1)
+        for layer in self.layers:
+            h = layer(h, self.freqs_cis)
+        h = self.norm(h)[:, -1]
+        output = self.output(h)
+        return output
+
+    def init_weights(self, buffer_device: torch.device | None = None) -> None:
+        pass
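For a quick sanity check of the new FeedForward block outside the repo, the sketch below reuses its SwiGLU-style forward pass on a toy tensor. The layer sizes and batch shape are illustrative stand-ins, not values taken from DeepseekV3ModelArgs.

import torch
import torch.nn.functional as F
from torch import nn

class FeedForward(nn.Module):
    # Same three projections as in the diff: w1/w3 expand to hidden_dim, w2 maps back to dim.
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: silu(w1(x)) gates w3(x), then w2 projects back to the model dimension
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

ffn = FeedForward(dim=256, hidden_dim=1024)  # toy sizes, not from DeepseekV3ModelArgs
x = torch.randn(2, 8, 256)                   # (batch_size, seq_len, dim)
print(ffn(x).shape)                          # torch.Size([2, 8, 256])

In the model itself, dim and hidden_dim come from model_args.dim and model_args.inter_dim, as wired up in TransformerBlock above.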

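The comments added to precompute_freqs_cis mark the two stages: the base RoPE frequencies, then a YaRN-style correction applied only when the configured sequence length exceeds original_seq_len. Below is a minimal standalone sketch of that blend; the ramp body is reconstructed from the clamp shown in the hunk, the blend line follows the usual YaRN interpolation, and factor, low, and high are made-up placeholders for what the model args and find_correction_range would supply.

import torch

def linear_ramp_factor(min_idx: float, max_idx: float, dim: int) -> torch.Tensor:
    # 0 below min_idx, 1 above max_idx, linear in between, clamped as in the diff
    if min_idx == max_idx:
        max_idx += 0.001  # avoid division by zero
    linear_func = (torch.arange(dim, dtype=torch.float32) - min_idx) / (max_idx - min_idx)
    return torch.clamp(linear_func, 0, 1)

dim, base = 64, 10000.0  # illustrative rotary dim and RoPE base
# Basic RoPE frequency calculation (same formula as in the diff)
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# YaRN scaling for extended context: low-frequency dimensions (large indices) are
# divided by `factor` to stretch their wavelength, high-frequency dimensions are kept,
# and the ramp blends smoothly between the two regimes.
factor, low, high = 4.0, 8, 24  # placeholder values; the real ones come from the model args and find_correction_range
smooth = 1 - linear_ramp_factor(low, high, dim // 2)
scaled_freqs = freqs / factor * (1 - smooth) + freqs * smooth
print(scaled_freqs.shape)  # torch.Size([32])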