[WIP] Enable causal block mask for sdpa #1348

Draft: wants to merge 2 commits into main

174 changes: 157 additions & 17 deletions torchtitan/models/attention.py
@@ -10,6 +10,7 @@

import torch
import torch.nn.functional as F
from torch.nested import nested_tensor_from_jagged
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.attention.flex_attention import (
_mask_mod_signature,
@@ -20,10 +21,10 @@

from torchtitan.tools.utils import has_cuda_capability

# FlexAttention mask type. For each mask type, we initialize it at most once per
# batch. To record what it is initialized, FLEX_ATTN_MASK_T is used as the key to
# Attention mask type. For each mask type, we initialize it at most once per
# batch. To record what has been initialized, ATTN_MASK_T is used as the key to
# track the initialized mask.
FLEX_ATTN_MASK_T = tuple[str, int | None]
ATTN_MASK_T = tuple[str, int | None]
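
For reference, a mask key under this alias is just the `(attn_mask_type, fixed_block_size)` pair that both attention classes below use to deduplicate mask setup. A minimal illustration with made-up values:

```python
# Illustration only: possible mask keys under the renamed ATTN_MASK_T alias.
ATTN_MASK_T = tuple[str, int | None]

causal_key: ATTN_MASK_T = ("causal", None)             # SDPA or FlexAttention, no fixed block size
block_causal_key: ATTN_MASK_T = ("block_causal", 128)  # FlexAttention with fixed_block_size=128
```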


class FlexAttention(torch.nn.Module):
@@ -50,13 +51,13 @@ class FlexAttention(torch.nn.Module):
flex_attention, mode="max-autotune-no-cudagraphs"
)
compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
used_attn_mask_types: ClassVar[set[FLEX_ATTN_MASK_T]] = set()
used_attn_mask_types: ClassVar[set[ATTN_MASK_T]] = set()
# Attention mask type to the created BlockMask.
# This allows us to keep track of the created block masks for each
# new batch. We will use this to update the block mask when a
# new batch is created. This also allows users to create different
# block masks for different layers.
block_masks: ClassVar[dict[FLEX_ATTN_MASK_T, BlockMask]] = {}
block_masks: ClassVar[dict[ATTN_MASK_T, BlockMask]] = {}

# Instance variables.
attn_mask_type: str
@@ -73,7 +74,7 @@ def __init__(
FlexAttention.used_attn_mask_types.add(self.mask_key)

@property
def mask_key(self) -> FLEX_ATTN_MASK_T:
def mask_key(self) -> ATTN_MASK_T:
return (self.attn_mask_type, self.fixed_block_size)

def forward(
@@ -183,15 +184,20 @@ def init_attention_mask(batch: torch.Tensor, eos_id: int | None = None) -> None:
class ScaledDotProductAttention(torch.nn.Module):
backends: ClassVar[list[SDPBackend]] = []

# Offsets between the packed sequences in the batch used to create nested tensors.
offsets: ClassVar[torch.Tensor | None] = None

used_attn_mask_types: ClassVar[set[ATTN_MASK_T]] = set()

def __init__(self, attn_mask_type: str) -> None:
super().__init__()
if attn_mask_type != "causal":
raise ValueError(
"TorchTitan with SDPA currently only supports causal mask."
)

ScaledDotProductAttention._init_backend()

self.attn_mask_type = attn_mask_type

ScaledDotProductAttention.used_attn_mask_types.add(self.mask_key)

@classmethod
def _init_backend(cls) -> None:
if cls.backends:
@@ -206,12 +212,145 @@ def _init_backend(cls) -> None:
if has_cuda_capability(10, 0):
cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION)

@property
def mask_key(self) -> ATTN_MASK_T:
return (self.attn_mask_type, None)

def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
assert self.backends, "SDPA Backends should not be empty."
with sdpa_kernel(self.backends, set_priority=True):
return F.scaled_dot_product_attention(q, k, v, is_causal=True)

if ScaledDotProductAttention.offsets is None:
return F.scaled_dot_product_attention(q, k, v, is_causal=True)
else:
original_shape = q.shape
if q.size(0) == 1:
# Create nested tensor: [1, h, s, d] -> [num_samples, h, j1, d]
q_nested = nested_tensor_from_jagged(
q.view(q.shape[1:]),
ScaledDotProductAttention.offsets,
jagged_dim=2,
)
k_nested = nested_tensor_from_jagged(
k.view(k.shape[1:]),
ScaledDotProductAttention.offsets,
jagged_dim=2,
)
v_nested = nested_tensor_from_jagged(
v.view(v.shape[1:]),
ScaledDotProductAttention.offsets,
jagged_dim=2,
)

act_nested = F.scaled_dot_product_attention(
q_nested, k_nested, v_nested, is_causal=True
)

return act_nested.values().view(original_shape)
else:
# Flatten the packed samples along dim 2: [bs, h, s, d] -> [1, h, bs*s, d]
q_packed = (
q.permute(0, 2, 1, 3)
.reshape(1, -1, q.shape[1], q.shape[3])
.permute(0, 2, 1, 3)
)
del q
# Create nested tensor: [1, h, bs*s, d] -> [num_samples, h, j1, d]
q_nested = nested_tensor_from_jagged(
q_packed.view(q_packed.shape[1:]),
ScaledDotProductAttention.offsets,
jagged_dim=2,
)

k_packed = (
k.permute(0, 2, 1, 3)
.reshape(1, -1, k.shape[1], k.shape[3])
.permute(0, 2, 1, 3)
)
del k
k_nested = nested_tensor_from_jagged(
k_packed.view(k_packed.shape[1:]),
ScaledDotProductAttention.offsets,
jagged_dim=2,
)

v_packed = (
v.permute(0, 2, 1, 3)
.reshape(1, -1, v.shape[1], v.shape[3])
.permute(0, 2, 1, 3)
)
del v
v_nested = nested_tensor_from_jagged(
v_packed.view(v_packed.shape[1:]),
ScaledDotProductAttention.offsets,
jagged_dim=2,
)

act_nested = F.scaled_dot_product_attention(
q_nested, k_nested, v_nested, is_causal=True
)

# Repack samples along dim 2 and restore original shape: [num_samples, h, j1, d] -> [bs, h, s, d]
return (
act_nested.values()
.unsqueeze(0)
.permute(0, 2, 1, 3)
.reshape(
original_shape[0],
-1,
act_nested.shape[1],
act_nested.shape[3],
)
.permute(0, 2, 1, 3)
)
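
The shape bookkeeping above is easiest to see on a tiny packed batch. The sketch below is an illustration with made-up sizes, assuming a recent PyTorch with jagged nested-tensor support: one batch row packs two samples of lengths 3 and 5, and `nested_tensor_from_jagged` with `jagged_dim=2` turns the `[h, s, d]` values plus the offsets into a `[num_samples, h, j, d]` nested tensor whose ragged dimension is the per-sample sequence length.

```python
# Standalone sketch of the jagged packing used above (sizes are made up).
import torch
from torch.nested import nested_tensor_from_jagged

h, d = 2, 4
offsets = torch.tensor([0, 3, 8])   # one packed row holding samples of length 3 and 5
q = torch.randn(1, h, 8, d)         # [1, h, s, d] with s = 3 + 5

q_nested = nested_tensor_from_jagged(
    q.view(q.shape[1:]),            # drop the batch dim -> [h, s, d]
    offsets,
    jagged_dim=2,                   # dim 2 of the resulting [num_samples, h, j, d] tensor is ragged
)
print(q_nested.shape[0])            # 2 samples
print(q_nested.values().shape)      # torch.Size([2, 8, 4]); .view(1, h, 8, d) restores [1, h, s, d]
```

Because each packed sample becomes its own nested sequence, `is_causal=True` in the SDPA call never attends across sample boundaries, which is what makes this behave as a block-causal mask without materializing one.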

@staticmethod
@torch.no_grad()
def _get_offsets(batch: torch.Tensor, eos_id: int) -> torch.Tensor:
# Determine packed sequence boundaries.
mask = batch == eos_id

indices = mask.flatten().nonzero().flatten()

# If the last token is not EOS, we need two extra elements (the leading zero and the
# trailing total length); otherwise just the leading zero.
if indices.numel() == 0 or indices[-1] != batch.numel() - 1:
addition_elements = 2
else:
addition_elements = 1

# Store the offsets between the packed sequences in the batch.
offsets = torch.empty(
(indices.size(0) + addition_elements),
dtype=indices.dtype,
device=batch.device,
)
offsets[0] = 0
offsets[1 : indices.size(0) + 1] = indices.flatten() + 1
offsets[-1] = batch.numel()

return offsets
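
A worked trace of this boundary computation with made-up tokens: the two EOS tokens sit at flat positions 2 and 4, the last token is not EOS, so the shifted EOS positions are wrapped with a leading zero and the trailing total length.

```python
# Trace of the offset computation above with made-up tokens (eos_id = 1).
import torch

eos_id = 1
batch = torch.tensor([[5, 7, 1, 9, 1, 4]])    # two EOS-terminated samples plus a partial tail

mask = batch == eos_id
indices = mask.flatten().nonzero().flatten()  # tensor([2, 4])

# Last token is not EOS -> two extra slots: the leading 0 and the trailing total length.
offsets = torch.empty(indices.size(0) + 2, dtype=indices.dtype)
offsets[0] = 0
offsets[1:-1] = indices + 1
offsets[-1] = batch.numel()
print(offsets)                                # tensor([0, 3, 5, 6])
```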

@staticmethod
@torch.no_grad()
def init_attention_mask(batch: torch.Tensor, eos_id: int | None = None) -> None:

for mask_key in ScaledDotProductAttention.used_attn_mask_types:
attn_mask_type, _ = mask_key
match attn_mask_type:
case "causal":
return
case "block_causal":
if eos_id is None:
raise RuntimeError(
"eos_id must be provided for block_causal mask."
)
ScaledDotProductAttention.offsets = (
ScaledDotProductAttention._get_offsets(batch, eos_id)
)
case _:
raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}")


def build_attention(
@@ -224,12 +363,13 @@ def build_attention(
raise ValueError(
"TorchTitan with SDPA currently does not support fixed_block_size."
)
if attn_mask_type != "causal":
raise ValueError(
"TorchTitan with SDPA currently only supports causal mask."
)
return ScaledDotProductAttention(attn_mask_type)


def init_attention_mask(batch: torch.Tensor, eos_id: int | None = None) -> None:
FlexAttention.init_attention_mask(batch, eos_id)
def init_attention_mask(
batch: torch.Tensor, eos_id: int | None = None, use_flex_attn: bool = True
) -> None:
if use_flex_attn:
FlexAttention.init_attention_mask(batch, eos_id)
else:
ScaledDotProductAttention.init_attention_mask(batch, eos_id)
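
A hedged end-to-end sketch of how these pieces fit together on this branch for SDPA with packed batches; the positional signature `build_attention(use_flex_attn, attn_mask_type)` and the token values are assumptions on my part, and the per-batch `init_attention_mask` call mirrors the model.py change further down.

```python
# Sketch only: wiring build_attention + init_attention_mask for the SDPA block-causal path.
import torch

from torchtitan.models.attention import (
    ScaledDotProductAttention,
    build_attention,
    init_attention_mask,
)

attn = build_attention(False, "block_causal")              # SDPA module with block-causal masking

batch = torch.tensor([[3, 4, 1, 5, 6, 7, 1, 8]])           # packed samples separated by EOS (id 1)
init_attention_mask(batch, eos_id=1, use_flex_attn=False)  # records packed-sequence boundaries
print(ScaledDotProductAttention.offsets)                   # tensor([0, 3, 7, 8])

# A later attn(q, k, v) call with q/k/v of shape [1, n_heads, 8, head_dim] then takes the
# nested-tensor branch in forward(), so causal attention stays within each packed sample.
```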
7 changes: 7 additions & 0 deletions torchtitan/models/llama3/__init__.py
@@ -39,6 +39,13 @@
use_flex_attn=True,
attn_mask_type="block_causal",
),
"debugmodel_sdpa_block_causal": TransformerModelArgs(
dim=256,
n_layers=6,
n_heads=16,
rope_theta=500000,
attn_mask_type="block_causal",
),
"8B": TransformerModelArgs(
dim=4096,
n_layers=32,
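
If the flavor registry keeps its current layout, the new debug flavor should be selectable by name; a hypothetical lookup, where the `llama3_configs` name and the `use_flex_attn` default are assumptions:

```python
# Hypothetical: look up the new debug flavor from the llama3 registry by name.
from torchtitan.models.llama3 import llama3_configs  # registry name assumed

args = llama3_configs["debugmodel_sdpa_block_causal"]
print(args.attn_mask_type)  # "block_causal"
print(args.use_flex_attn)   # expected False, so ScaledDotProductAttention handles the mask
```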
9 changes: 5 additions & 4 deletions torchtitan/models/llama3/model/model.py
@@ -416,10 +416,11 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None)
torch.Tensor: Output logits after applying the Transformer model.

"""
if self.model_args.use_flex_attn:
init_attention_mask(
input_batch if input_batch is not None else tokens, eos_id=self.eos_id
)
init_attention_mask(
input_batch if input_batch is not None else tokens,
eos_id=self.eos_id,
use_flex_attn=self.model_args.use_flex_attn,
)

# passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens