 import math
 
 import torch
+import torch.nn.functional as F
 from torch import nn
-
 from torchtitan.models.attention import build_attention
+from torchtitan.models.llama3 import model
+from torchtitan.protocols.train_spec import ModelProtocol
 
 from .args import DeepseekV3ModelArgs
 
|
@@ -88,7 +90,10 @@ def linear_ramp_factor(min, max, dim):
         ramp_func = torch.clamp(linear_func, 0, 1)
         return ramp_func
 
+    # Basic RoPE frequency calculation
     freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+
+    # YaRN scaling for extended context
     if seqlen > args.original_seq_len:
         low, high = find_correction_range(
             beta_fast, beta_slow, dim, base, args.original_seq_len
@@ -222,3 +227,116 @@ def forward(
         ).contiguous()  # (bs, seqlen, n_heads, v_head_dim)
         output = output.view(bsz, seqlen, -1)  # (bs, seqlen, n_heads * v_head_dim)
         return self.wo(output)  # (bsz, seqlen, dim)
+
+
+class FeedForward(nn.Module):
+    """
+    Feed-forward module with SwiGLU-style gating.
+
+    Args:
+        dim (int): Input dimension.
+        hidden_dim (int): Hidden dimension of the feedforward layer.
+
+    Attributes:
+        w1 (Linear): Linear transformation for the first layer.
+        w2 (Linear): Linear transformation for the second layer.
+        w3 (Linear): Linear transformation for the third layer.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+    ):
+        super().__init__()
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+        for linear in (self.w2, self.w3):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
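
The forward pass is the SwiGLU pattern: a SiLU-gated projection into the hidden dimension followed by a projection back to the model dimension, w2(silu(w1(x)) * w3(x)). A quick shape check with illustrative sizes, not a real configuration:

import torch

ffn = FeedForward(dim=16, hidden_dim=64)
ffn.init_weights(init_std=0.02)

x = torch.randn(2, 8, 16)      # (batch_size, seq_len, dim)
y = ffn(x)                     # w2(silu(w1(x)) * w3(x))
assert y.shape == (2, 8, 16)   # the model dimension is preserved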
+
+
+class TransformerBlock(nn.Module):
+    """
+    Transformer block with attention and feed-forward layers.
+    """
+
+    def __init__(self, model_args: DeepseekV3ModelArgs):
+        super().__init__()
+        self.attention = Attention(model_args)
+        self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.ffn = FeedForward(model_args.dim, model_args.inter_dim)
+
+    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
+        """
+        Forward pass for the Transformer block.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
+            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
+
+        Returns:
+            torch.Tensor: Output tensor with the same shape as the input.
+        """
+        x = x + self.attention(self.attention_norm(x), freqs_cis)
+        x = x + self.ffn(self.ffn_norm(x))
+        return x
+
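
Both sublayers follow the pre-norm residual pattern: RMSNorm is applied to the sublayer input, and the sublayer output is added back to the un-normalized residual stream. A minimal functional restatement, with the sublayers passed in as callables:

import torch
from torch import nn

def pre_norm_block(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
    attn,        # callable: (normed_x, freqs_cis) -> tensor of the same shape as x
    ffn,         # callable: (normed_x) -> tensor of the same shape as x
    attn_norm: nn.Module,
    ffn_norm: nn.Module,
) -> torch.Tensor:
    x = x + attn(attn_norm(x), freqs_cis)   # attention sublayer with residual
    x = x + ffn(ffn_norm(x))                # feed-forward sublayer with residual
    return x

Only the sublayer inputs are normalized; the residual path itself is left untouched, which is the standard pre-norm choice for stability in deep stacks.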
+
+
+class Transformer(nn.Module, ModelProtocol):
+    """
+    DeepSeek-V3 Transformer model with attention and feed-forward layers.
+    """
+
+    def __init__(self, model_args: DeepseekV3ModelArgs):
+        super().__init__()
+        self.max_seq_len = model_args.max_seq_len
+        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
+        self.register_buffer(
+            "freqs_cis", precompute_freqs_cis(model_args), persistent=False
+        )
+
+        self.layers = torch.nn.ModuleList()
+        for layer_id in range(model_args.n_layers):
+            self.layers.append(TransformerBlock(model_args))
+        self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.output = nn.Linear(
+            model_args.dim, model_args.vocab_size, dtype=torch.get_default_dtype()
+        )
+        self.init_weights()
+
+    def forward(self, tokens: torch.Tensor):
+        """
+        Forward pass for the Transformer model.
+
+        Args:
+            tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
+
+        Returns:
+            torch.Tensor: Logits tensor of shape (batch_size, vocab_size) for the last position.
+        """
+        h = self.tok_embeddings(tokens)
+        # The causal mask is already handled inside SDPA, so no explicit mask is built here:
+        # if seqlen > 1:
+        #     mask = torch.full(
+        #         (seqlen, seqlen), float("-inf"), device=tokens.device
+        #     ).triu_(1)
+        for layer in self.layers:
+            h = layer(h, self.freqs_cis)
+        h = self.norm(h)[:, -1]  # keep only the last position
+        output = self.output(h)
+        return output
+
+    def init_weights(self, buffer_device: torch.device | None = None) -> None:
+        pass
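
A hypothetical smoke test for the new model class; `model_args` stands for a small DeepseekV3ModelArgs instance whose full field list lives in .args and is not shown in this diff:

import torch

def smoke_test(model_args: DeepseekV3ModelArgs) -> None:
    model = Transformer(model_args)
    tokens = torch.randint(0, model_args.vocab_size, (2, 16))  # (batch_size, seq_len)
    logits = model(tokens)
    # forward() keeps only the last position, so the logits are per sequence.
    assert logits.shape == (2, model_args.vocab_size)

Note that because forward() selects the last position, a per-token training loss would require the full (batch_size, seq_len, vocab_size) logits, i.e. dropping the [:, -1] selection.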