diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py
index f7f13462e..3f11f1ab4 100644
--- a/torchtitan/models/attention.py
+++ b/torchtitan/models/attention.py
@@ -10,6 +10,7 @@
 import torch
 import torch.nn.functional as F
+from torch.nested import nested_tensor_from_jagged
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch.nn.attention.flex_attention import (
     _mask_mod_signature,
@@ -20,10 +21,10 @@
 from torchtitan.tools.utils import has_cuda_capability
 
 
-# FlexAttention mask type. For each mask type, we initialize it at most once per
-# batch. To record what it is initialized, FLEX_ATTN_MASK_T is used as the key to
+# Attention mask type. For each mask type, we initialize it at most once per
+# batch. To record what it is initialized, ATTN_MASK_T is used as the key to
 # track the initialized mask.
-FLEX_ATTN_MASK_T = tuple[str, int | None]
+ATTN_MASK_T = tuple[str, int | None]
 
 
 class FlexAttention(torch.nn.Module):
@@ -50,13 +51,13 @@ class FlexAttention(torch.nn.Module):
         flex_attention, mode="max-autotune-no-cudagraphs"
     )
     compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
-    used_attn_mask_types: ClassVar[set[FLEX_ATTN_MASK_T]] = set()
+    used_attn_mask_types: ClassVar[set[ATTN_MASK_T]] = set()
     # Attention mask type to the created BlockMask.
     # This allows us to keep track the created block masks for each
     # new batch. We will use this to update the block mask when a
     # new batch is created. This also allows user to create different
     # block masks for different layers.
-    block_masks: ClassVar[dict[FLEX_ATTN_MASK_T, BlockMask]] = {}
+    block_masks: ClassVar[dict[ATTN_MASK_T, BlockMask]] = {}
 
     # Instance variables.
     attn_mask_type: str
@@ -73,7 +74,7 @@ def __init__(
         FlexAttention.used_attn_mask_types.add(self.mask_key)
 
     @property
-    def mask_key(self) -> FLEX_ATTN_MASK_T:
+    def mask_key(self) -> ATTN_MASK_T:
         return (self.attn_mask_type, self.fixed_block_size)
 
     def forward(
@@ -183,15 +184,20 @@ def init_attention_mask(batch: torch.Tensor, eos_id: int | None = None) -> None:
 
 class ScaledDotProductAttention(torch.nn.Module):
     backends: ClassVar[list[SDPBackend]] = []
 
+    # Offsets between the packed sequences in the batch used to create nested tensors.
+    offsets: ClassVar[torch.Tensor | None] = None
+
+    used_attn_mask_types: ClassVar[set[ATTN_MASK_T]] = set()
+
     def __init__(self, attn_mask_type: str) -> None:
         super().__init__()
-        if attn_mask_type != "causal":
-            raise ValueError(
-                "TorchTitan with SDPA currently only supports causal mask."
-            )
         ScaledDotProductAttention._init_backend()
 
+        self.attn_mask_type = attn_mask_type
+
+        ScaledDotProductAttention.used_attn_mask_types.add(self.mask_key)
+
     @classmethod
     def _init_backend(cls) -> None:
         if cls.backends:
@@ -206,12 +212,145 @@ def _init_backend(cls) -> None:
         if has_cuda_capability(10, 0):
             cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION)
 
+    @property
+    def mask_key(self) -> ATTN_MASK_T:
+        return (self.attn_mask_type, None)
+
     def forward(
         self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
     ) -> torch.Tensor:
         assert self.backends, "SDPA Backends should not be empty."
         with sdpa_kernel(self.backends, set_priority=True):
-            return F.scaled_dot_product_attention(q, k, v, is_causal=True)
+
+            if ScaledDotProductAttention.offsets is None:
+                return F.scaled_dot_product_attention(q, k, v, is_causal=True)
+            else:
+                original_shape = q.shape
+                if q.size(0) == 1:
+                    # Create nested tensor: [1, h, s, d] -> [num_samples, h, j1, d]
+                    q_nested = nested_tensor_from_jagged(
+                        q.view(q.shape[1:]),
+                        ScaledDotProductAttention.offsets,
+                        jagged_dim=2,
+                    )
+                    k_nested = nested_tensor_from_jagged(
+                        k.view(k.shape[1:]),
+                        ScaledDotProductAttention.offsets,
+                        jagged_dim=2,
+                    )
+                    v_nested = nested_tensor_from_jagged(
+                        v.view(v.shape[1:]),
+                        ScaledDotProductAttention.offsets,
+                        jagged_dim=2,
+                    )
+
+                    act_nested = F.scaled_dot_product_attention(
+                        q_nested, k_nested, v_nested, is_causal=True
+                    )
+
+                    return act_nested.values().view(original_shape)
+                else:
+                    # Flatten the packed samples along dim 2: [bs, h, s, d] -> [1, h, bs*s, d]
+                    q_packed = (
+                        q.permute(0, 2, 1, 3)
+                        .reshape(1, -1, q.shape[1], q.shape[3])
+                        .permute(0, 2, 1, 3)
+                    )
+                    del q
+                    # Create nested tensor: [1, h, bs*s, d] -> [num_samples, h, j1, d]
+                    q_nested = nested_tensor_from_jagged(
+                        q_packed.view(q_packed.shape[1:]),
+                        ScaledDotProductAttention.offsets,
+                        jagged_dim=2,
+                    )
+
+                    k_packed = (
+                        k.permute(0, 2, 1, 3)
+                        .reshape(1, -1, k.shape[1], k.shape[3])
+                        .permute(0, 2, 1, 3)
+                    )
+                    del k
+                    k_nested = nested_tensor_from_jagged(
+                        k_packed.view(k_packed.shape[1:]),
+                        ScaledDotProductAttention.offsets,
+                        jagged_dim=2,
+                    )
+
+                    v_packed = (
+                        v.permute(0, 2, 1, 3)
+                        .reshape(1, -1, v.shape[1], v.shape[3])
+                        .permute(0, 2, 1, 3)
+                    )
+                    del v
+                    v_nested = nested_tensor_from_jagged(
+                        v_packed.view(v_packed.shape[1:]),
+                        ScaledDotProductAttention.offsets,
+                        jagged_dim=2,
+                    )
+
+                    act_nested = F.scaled_dot_product_attention(
+                        q_nested, k_nested, v_nested, is_causal=True
+                    )
+
+                    # Repack samples along dim 2 and restore original shape: [num_samples, h, j1, d] -> [bs, h, s, d]
+                    return (
+                        act_nested.values()
+                        .unsqueeze(0)
+                        .permute(0, 2, 1, 3)
+                        .reshape(
+                            original_shape[0],
+                            -1,
+                            act_nested.shape[1],
+                            act_nested.shape[3],
+                        )
+                        .permute(0, 2, 1, 3)
+                    )
+
+    @staticmethod
+    @torch.no_grad()
+    def _get_offsets(batch: torch.Tensor, eos_id: int) -> torch.Tensor:
+        # Determine packed sequence boundaries.
+        mask = batch == eos_id
+
+        indices = mask.flatten().nonzero().flatten()
+
+        # In case the last token is not EOS, we need to add an extra element to the indices.
+        if indices.numel() == 0 or indices[-1] != batch.numel() - 1:
+            addition_elements = 2
+        else:
+            addition_elements = 1
+
+        # Store the offsets between the packed sequences in the batch.
+        offsets = torch.empty(
+            (indices.size(0) + addition_elements),
+            dtype=indices.dtype,
+            device=batch.device,
+        )
+        offsets[0] = 0
+        offsets[1 : indices.size(0) + 1] = indices.flatten() + 1
+        offsets[-1] = batch.numel()
+
+        return offsets
+
+    @staticmethod
+    @torch.no_grad()
+    def init_attention_mask(batch: torch.Tensor, eos_id: int | None = None) -> None:
+
+        for mask_key in ScaledDotProductAttention.used_attn_mask_types:
+            attn_mask_type, _ = mask_key
+            match attn_mask_type:
+                case "causal":
+                    return
+                case "block_causal":
+                    if eos_id is None:
+                        raise RuntimeError(
+                            "eos_id must be provided for block_causal mask."
+                        )
+                    ScaledDotProductAttention.offsets = (
+                        ScaledDotProductAttention._get_offsets(batch, eos_id)
+                    )
+                case _:
+                    raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}")
 
 
 def build_attention(
@@ -224,12 +363,13 @@ def build_attention(
             raise ValueError(
                 "TorchTitan with SDPA currently does not support fixed_block_size."
            )
-        if attn_mask_type != "causal":
-            raise ValueError(
-                "TorchTitan with SDPA currently only supports causal mask."
-            )
         return ScaledDotProductAttention(attn_mask_type)
 
 
-def init_attention_mask(batch: torch.Tensor, eos_id: int | None = None) -> None:
-    FlexAttention.init_attention_mask(batch, eos_id)
+def init_attention_mask(
+    batch: torch.Tensor, eos_id: int | None = None, use_flex_attn: bool = True
+) -> None:
+    if use_flex_attn:
+        FlexAttention.init_attention_mask(batch, eos_id)
+    else:
+        ScaledDotProductAttention.init_attention_mask(batch, eos_id)
diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py
index 7cfc03b4a..1d92eea40 100644
--- a/torchtitan/models/llama3/__init__.py
+++ b/torchtitan/models/llama3/__init__.py
@@ -39,6 +39,13 @@
         use_flex_attn=True,
         attn_mask_type="block_causal",
     ),
+    "debugmodel_sdpa_block_causal": TransformerModelArgs(
+        dim=256,
+        n_layers=6,
+        n_heads=16,
+        rope_theta=500000,
+        attn_mask_type="block_causal",
+    ),
     "8B": TransformerModelArgs(
         dim=4096,
         n_layers=32,
diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py
index a52820939..6bad6888c 100644
--- a/torchtitan/models/llama3/model/model.py
+++ b/torchtitan/models/llama3/model/model.py
@@ -416,10 +416,11 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None)
             torch.Tensor: Output logits after applying the Transformer model.
 
         """
-        if self.model_args.use_flex_attn:
-            init_attention_mask(
-                input_batch if input_batch is not None else tokens, eos_id=self.eos_id
-            )
+        init_attention_mask(
+            input_batch if input_batch is not None else tokens,
+            eos_id=self.eos_id,
+            use_flex_attn=self.model_args.use_flex_attn,
+        )
 
         # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens