From ab25ee0ac1cbf3dcd9ee212a8b4e0f0c17c3f413 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 19 Apr 2025 17:50:42 +0200 Subject: [PATCH 1/5] initial attention dispatcher support supported: flash, flash_varlen, flex, native, sage, sage_varlen, xformers --- src/diffusers/models/attention_dispatch.py | 910 ++++++++++++++++++ .../transformers/transformer_lumina2.py | 5 +- src/diffusers/utils/__init__.py | 5 + src/diffusers/utils/constants.py | 2 + src/diffusers/utils/import_utils.py | 55 ++ 5 files changed, 974 insertions(+), 3 deletions(-) create mode 100644 src/diffusers/models/attention_dispatch.py diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py new file mode 100644 index 000000000000..be6c15509ce4 --- /dev/null +++ b/src/diffusers/models/attention_dispatch.py @@ -0,0 +1,910 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import inspect +from enum import Enum +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import torch + +from ..utils import ( + OptionalDependencyNotAvailable, + get_logger, + is_flash_attn_available, + is_flash_attn_version, + is_sageattention_available, + is_sageattention_version, + is_torch_version, + is_xformers_available, + is_xformers_version, +) +from ..utils.constants import DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ATTN_PROVIDER + + +if is_flash_attn_available(): + if is_flash_attn_version("<", "2.6.3"): + raise OptionalDependencyNotAvailable( + "The `flash-attn` library version is too old. Please update it to at least 2.6.3." + ) + + from flash_attn import flash_attn_func, flash_attn_varlen_func +else: + flash_attn_func = None + flash_attn_varlen_func = None + + +if is_sageattention_available(): + if is_sageattention_version("<", "2.1.1"): + raise OptionalDependencyNotAvailable( + "The `sageattention` library version is too old. Please update it to at least 2.1.1." + ) + + from sageattention import ( + sageattn, + sageattn_qk_int8_pv_fp8_cuda, + sageattn_qk_int8_pv_fp8_cuda_sm90, + sageattn_qk_int8_pv_fp16_cuda, + sageattn_qk_int8_pv_fp16_triton, + sageattn_varlen, + ) +else: + sageattn = None + sageattn_qk_int8_pv_fp16_cuda = None + sageattn_qk_int8_pv_fp16_triton = None + sageattn_qk_int8_pv_fp8_cuda = None + sageattn_qk_int8_pv_fp8_cuda_sm90 = None + sageattn_varlen = None + + +if is_torch_version(">=", "2.5.0"): + from torch.nn.attention.flex_attention import BlockMask, create_block_mask + from torch.nn.attention.flex_attention import flex_attention as torch_flex_attention +else: + create_block_mask = None + torch_flex_attention = None + + class BlockMask: + def __init__(self, *args, **kwargs): + raise OptionalDependencyNotAvailable( + "The `torch` library version is too old. Please update it to at least 2.5.0." + ) + + +if is_xformers_available(): + if is_xformers_version("<", "0.0.29"): + raise OptionalDependencyNotAvailable( + "The `xformers` library version is too old. Please update it to at least 0.0.29." + ) + + import xformers.ops as xops +else: + xops = None + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_IS_CREATE_BLOCK_MASK_COMPILED = False +_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] +_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] +_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] + + +class AttentionProvider(str, Enum): + # EAGER = "eager" + + # `flash-attn` + FLASH = "flash" + FLASH_VARLEN = "flash_varlen" + + # PyTorch native + FLEX = "flex" + NATIVE = "native" + _NATIVE_CUDNN = "_native_cudnn" + _NATIVE_EFFICIENT = "_native_efficient" + _NATIVE_FLASH = "_native_flash" + _NATIVE_MATH = "_native_math" + + # `sageattention` + SAGE = "sage" + SAGE_VARLEN = "sage_varlen" + _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda" + _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90" + _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda" + _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton" + # SPARGE = "sparge" + + # `xformers` + XFORMERS = "xformers" + + +class _AttentionProviderRegistry: + _providers = {} + _constraints = {} + _supported_arg_names = {} + _active_provider = AttentionProvider(DIFFUSERS_ATTN_PROVIDER) + _checks_enabled = DIFFUSERS_ATTN_CHECKS + + @classmethod + def register(cls, provider: AttentionProvider, constraints: Optional[List[Callable]] = None): + logger.debug(f"Registering attention provider: {provider} with constraints: {constraints}") + + def decorator(func): + cls._providers[provider] = func + cls._constraints[provider] = constraints or [] + cls._supported_arg_names[provider] = set(inspect.signature(func).parameters.keys()) + return func + + return decorator + + @classmethod + def get_active_provider(cls): + return cls._active_provider, cls._providers[cls._active_provider] + + @classmethod + def list_providers(cls): + return list(cls._providers.keys()) + + +def attention_dispatch( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + attention_kwargs: Optional[Dict[str, Any]] = None, +) -> torch.Tensor: + attention_kwargs = attention_kwargs or {} + provider_name, provider_fn = _AttentionProviderRegistry.get_active_provider() + kwargs = { + "query": query, + "key": key, + "value": value, + "attn_mask": attn_mask, + "dropout_p": dropout_p, + "is_causal": is_causal, + "scale": scale, + "enable_gqa": enable_gqa, + **attention_kwargs, + } + + if _AttentionProviderRegistry._checks_enabled: + removed_kwargs = set(kwargs) - set(_AttentionProviderRegistry._supported_arg_names[provider_name]) + if removed_kwargs: + logger.warning(f"Removing unsupported arguments for attention provider {provider_name}: {removed_kwargs}.") + for check in _AttentionProviderRegistry._constraints.get(provider_name): + check(**kwargs) + + kwargs = {k: v for k, v in kwargs.items() if k in _AttentionProviderRegistry._supported_arg_names[provider_name]} + return provider_fn(**kwargs) + + +def _check_attn_mask_is_none(attn_mask: Optional[torch.Tensor], **kwargs) -> None: + if attn_mask is not None: + raise ValueError("Attention mask must be None for this provider.") + + +def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None: + if attn_mask is not None and is_causal: + raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.") + + +def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + if query.device != key.device or query.device != value.device: + raise ValueError("Query, key, and value must be on the same device.") + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError("Query, key, and value must have the same dtype.") + + +def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_device(query, key, value) + if query.device.type != "cuda": + raise ValueError("Query, key, and value must be on a CUDA device.") + + +def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable: + def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_device_cuda(query, key, value) + if torch.cuda.get_device_capability(query.device) < (major, minor): + raise ValueError( + f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}." + ) + + return check_device_cuda + + +def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + if query.dtype != key.dtype: + raise ValueError("Query and key must have the same dtype.") + if query.dtype != value.dtype: + raise ValueError("Query and value must have the same dtype.") + + +def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_qkv_dtype_match(query, key, value) + if query.dtype not in (torch.bfloat16, torch.float16): + raise ValueError("Query, key, and value must be either bfloat16 or float16.") + + +def _check_shape( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + **kwargs, +) -> None: + if query.shape[-1] != key.shape[-1]: + raise ValueError("Query and key must have the same last dimension.") + if query.shape[-2] != value.shape[-2]: + raise ValueError("Query and value must have the same second to last dimension.") + if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]: + raise ValueError("Attention mask must match the key's second to last dimension.") + + +def _prepare_for_flash_attn_or_sage_varlen( + batch_size: int, seq_len_q: int, attn_mask: Optional[torch.Tensor] = None, device: Optional[torch.device] = None +) -> None: + seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) + if attn_mask is None: + seqlens_k = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) + else: + seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32) + cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) + cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) + max_seqlen_q = seqlens_q.max().item() + max_seqlen_k = seqlens_k.max().item() + return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) + + +def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor: + """ + Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_k in + FlashAttention/Sage varlen. + + Supports 1D to 4D shapes and common broadcasting patterns. + """ + if attn_mask.dtype != torch.bool: + raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.") + + if attn_mask.ndim == 1: + # [seq_len_k] -> broadcast across batch + attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 2: + # [batch_size, seq_len_k]. Maybe broadcast across batch + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask." + ) + attn_mask = attn_mask.expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 3: + # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask." + ) + attn_mask = attn_mask.any(dim=1) + attn_mask = attn_mask.expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 4: + # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask." + ) + attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K] + attn_mask = attn_mask.any(dim=(1, 2)) # [B, K] + + else: + raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}") + + if attn_mask.shape != (batch_size, seq_len_k): + raise ValueError( + f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})" + ) + + return attn_mask + + +# LRU cache for block mask creation since we don't want to recompile the same block mask multiple times +@functools.lru_cache +def _flex_attention_create_block_mask( + mask_mod, + batch_size: Optional[int] = None, + num_heads: Optional[int] = None, + seq_len_q: Optional[int] = None, + seq_len_kv: Optional[int] = None, + device: Optional[torch.device] = None, +) -> BlockMask: + global _IS_CREATE_BLOCK_MASK_COMPILED, create_block_mask + if is_torch_version(">=", "2.6.0"): + if not _IS_CREATE_BLOCK_MASK_COMPILED: + create_block_mask = torch.compile(create_block_mask) + _IS_CREATE_BLOCK_MASK_COMPILED = True + block_mask = create_block_mask(mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, device) + else: + block_mask = create_block_mask(mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, device, _compile=True) + return block_mask + + +def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): + return q_idx >= kv_idx + + +@_AttentionProviderRegistry.register( + AttentionProvider.FLASH, + constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + attn_mask: Optional[torch.Tensor] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + if enable_gqa: + # TODO + pass + return flash_attn_func( + q=query, + k=key, + v=value, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=return_attn_probs, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider.FLASH_VARLEN, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + attn_mask: Optional[torch.Tensor] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + batch_size, _, seq_len_q, _ = query.shape + _, _, seq_len_kv, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if enable_gqa: + # TODO + pass + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen(batch_size, seq_len_q, attn_mask=attn_mask, device=query.device) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out = flash_attn_varlen_func( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=return_attn_probs, + ) + out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3).contiguous() + + return out + + +@_AttentionProviderRegistry.register( + AttentionProvider.FLEX, + constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], +) +def _native_flex_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[Union[torch.Tensor, BlockMask]] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + kernel_options: Optional[Dict[str, Any]] = None, +) -> torch.Tensor: + # TODO: should we LRU cache the block mask creation? + score_mod = None + block_mask = None + batch_size, num_heads, seq_len_q, _ = query.shape + _, _, seq_len_kv, _ = key.shape + + if attn_mask is None or isinstance(attn_mask, BlockMask): + block_mask = attn_mask + elif is_causal: + block_mask = _flex_attention_create_block_mask( + _flex_attention_causal_mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device + ) + elif torch.is_tensor(attn_mask): + if attn_mask.ndim == 2: + attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + + attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv) + + if attn_mask.dtype == torch.bool: + # TODO: this probably does not work but verify! + def mask_mod(batch_idx, head_idx, q_idx, kv_idx): + return attn_mask[batch_idx, head_idx, q_idx, kv_idx] + + block_mask = _flex_attention_create_block_mask( + mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device + ) + else: + + def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): + return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx] + else: + raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.") + + return torch_flex_attention( + query=query, + key=key, + value=value, + score_mod=score_mod, + block_mask=block_mask, + scale=scale, + enable_gqa=enable_gqa, + return_lse=return_lse, + kernel_options=None, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider.NATIVE, + constraints=[_check_device, _check_shape], +) +def _native_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider._NATIVE_CUDNN, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_cudnn_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION): + return torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider._NATIVE_EFFICIENT, + constraints=[_check_device, _check_shape], +) +def _native_efficient_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): + return torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider._NATIVE_FLASH, + constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): + return torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider._NATIVE_MATH, + constraints=[_check_device, _check_shape], +) +def _native_math_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + return torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider.SAGE, + constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _sage_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn( + q=query, + k=key, + v=value, + tensor_layout="HND", + is_causal=is_causal, + sm_scale=scale, + return_lse=return_lse, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider.SAGE_VARLEN, + constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _sage_varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + is_causal: bool = False, + scale: Optional[float] = None, + smooth_k: bool = True, + attn_mask: Optional[torch.Tensor] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + batch_size, _, seq_len_q, _ = query.shape + _, _, seq_len_kv, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if enable_gqa: + # TODO + pass + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen(batch_size, seq_len_q, attn_mask=attn_mask, device=query.device) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out = sageattn_varlen( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + is_causal=is_causal, + sm_scale=scale, + smooth_k=smooth_k, + ) + out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3).contiguous() + + return out + + +@_AttentionProviderRegistry.register( + AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA, + constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], +) +def _sage_qk_int8_pv_fp8_cuda_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp8_cuda( + q=query, + k=key, + v=value, + tensor_layout="HND", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + smooth_v=smooth_v, + return_lse=return_lse, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA_SM90, + constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], +) +def _sage_qk_int8_pv_fp8_cuda_sm90_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q=query, + k=key, + v=value, + tensor_layout="HND", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + return_lse=return_lse, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA, + constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], +) +def _sage_qk_int8_pv_fp16_cuda_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp16_cuda( + q=query, + k=key, + v=value, + tensor_layout="HND", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + smooth_v=smooth_v, + return_lse=return_lse, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider._SAGE_QK_INT8_PV_FP16_TRITON, + constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], +) +def _sage_qk_int8_pv_fp16_triton_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton", + smooth_k: bool = True, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp16_triton( + q=query, + k=key, + v=value, + tensor_layout="HND", + quantization_backend=quantization_backend, + is_causal=is_causal, + sm_scale=scale, + smooth_k=smooth_k, + return_lse=return_lse, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider.XFORMERS, + constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], +) +def _xformers_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + batch_size, num_heads_q, seq_len_q, _ = query.shape + _, num_heads_kv, seq_len_kv, _ = key.shape + + # TODO: check if `contiguous` is really needed since it may cause unnecessary slowdowns + if is_causal: + attn_mask = xops.LowerTriangularMask() + elif attn_mask is not None: + if attn_mask.ndim == 2: + attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + elif attn_mask.ndim != 4: + raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") + attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) + + # QKV need to be in [batch, seq_len, num_heads, head_dim] format for xformers + query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value)) + + if enable_gqa: + if num_heads_q % num_heads_kv != 0: + raise ValueError("Number of heads in query must be divisible by number of heads in key/value.") + num_heads_per_group = num_heads_q // num_heads_kv + query = query.unflatten(2, (num_heads_kv, -1)) + key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) + value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) + + out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale) + if enable_gqa: + out = out.flatten(2, 3) + + out = out.permute(0, 2, 1, 3).contiguous() + return out diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index a873a6ec9444..ee8de6aad3b9 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -24,6 +24,7 @@ from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import LuminaFeedForward +from ..attention_dispatch import attention_dispatch from ..attention_processor import Attention from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput @@ -137,9 +138,7 @@ def __call__( key = key.transpose(1, 2) value = value.transpose(1, 2) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, scale=softmax_scale - ) + hidden_states = attention_dispatch(query, key, value, attn_mask=attention_mask, scale=softmax_scale) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.type_as(query) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index ed89955ba5a5..954a53af1e46 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -65,6 +65,8 @@ is_bitsandbytes_available, is_bitsandbytes_version, is_bs4_available, + is_flash_attn_available, + is_flash_attn_version, is_flax_available, is_ftfy_available, is_gguf_available, @@ -85,6 +87,8 @@ is_peft_available, is_peft_version, is_safetensors_available, + is_sageattention_available, + is_sageattention_version, is_scipy_available, is_sentencepiece_available, is_tensorboard_available, @@ -103,6 +107,7 @@ is_unidecode_available, is_wandb_available, is_xformers_available, + is_xformers_version, requires_backends, ) from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 7c04287d33ed..ac91390609c1 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -41,6 +41,8 @@ HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] DIFFUSERS_REQUEST_TIMEOUT = 60 +DIFFUSERS_ATTN_PROVIDER = os.getenv("DIFFUSERS_ATTN_PROVIDER", "native") +DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES # Below should be `True` if the current version of `peft` and `transformers` are compatible with # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index e8d9429f6204..d93aab030c34 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -190,6 +190,8 @@ def _is_package_available(pkg_name: str): _torchao_available, _torchao_version = _is_package_available("torchao") _bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes") _torchao_available, _torchao_version = _is_package_available("torchao") +_sageattention_available, _sageattention_version = _is_package_available("sageattention") +_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") _optimum_quanto_available = importlib.util.find_spec("optimum") is not None if _optimum_quanto_available: @@ -336,6 +338,14 @@ def is_timm_available(): return _timm_available +def is_sageattention_available(): + return _sageattention_available + + +def is_flash_attn_available(): + return _flash_attn_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -743,6 +753,51 @@ def is_optimum_quanto_version(operation: str, version: str): return compare_versions(parse(_optimum_quanto_version), operation, version) +def is_xformers_version(operation: str, version: str): + """ + Compares the current xformers version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _xformers_available: + return False + return compare_versions(parse(_xformers_version), operation, version) + + +def is_sageattention_version(operation: str, version: str): + """ + Compares the current sageattention version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _sageattention_available: + return False + return compare_versions(parse(_sageattention_version), operation, version) + + +def is_flash_attn_version(operation: str, version: str): + """ + Compares the current flash-attention version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _flash_attn_available: + return False + return compare_versions(parse(_flash_attn_version), operation, version) + + def get_objects_from_module(module): """ Returns a dict of object names and values in a module, while skipping private/internal objects From bb0628a45bba8925f411fc24e1438af3e3961369 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 19 Apr 2025 23:58:21 +0200 Subject: [PATCH 2/5] fix wan vae dtype --- src/diffusers/pipelines/wan/pipeline_wan_i2v.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index 20ad84cb90d0..a7818f9f2276 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -401,7 +401,7 @@ def prepare_latents( video_condition = torch.cat( [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 ) - video_condition = video_condition.to(device=device, dtype=dtype) + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) latents_mean = ( torch.tensor(self.vae.config.latents_mean) @@ -421,6 +421,7 @@ def prepare_latents( latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + latent_condition = latent_condition.to(dtype=dtype) latent_condition = (latent_condition - latents_mean) * latents_std mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) From 03a7630c41160401621ab2dbc1af1350fc612258 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 20 Apr 2025 14:22:35 +0200 Subject: [PATCH 3/5] context manager for switching provider --- src/diffusers/__init__.py | 4 ++++ src/diffusers/models/__init__.py | 2 ++ src/diffusers/models/attention_dispatch.py | 20 ++++++++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 19 +++++++++++++++++++ 4 files changed, 45 insertions(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f51a4ef2b3f6..e454ae558123 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -143,6 +143,7 @@ [ "AllegroTransformer3DModel", "AsymmetricAutoencoderKL", + "AttentionProvider", "AuraFlowTransformer2DModel", "AutoencoderDC", "AutoencoderKL", @@ -212,6 +213,7 @@ "UVit2DModel", "VQModel", "WanTransformer3DModel", + "attention_provider", ] ) _import_structure["optimization"] = [ @@ -738,6 +740,7 @@ from .models import ( AllegroTransformer3DModel, AsymmetricAutoencoderKL, + AttentionProvider, AuraFlowTransformer2DModel, AutoencoderDC, AutoencoderKL, @@ -806,6 +809,7 @@ UVit2DModel, VQModel, WanTransformer3DModel, + attention_provider, ) from .optimization import ( get_constant_schedule, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 276b1836a797..dfed6e7f300d 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -26,6 +26,7 @@ if is_torch_available(): _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] + _import_structure["attention_dispatch"] = ["AttentionProvider", "attention_provider"] _import_structure["auto_model"] = ["AutoModel"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"] @@ -106,6 +107,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .adapter import MultiAdapter, T2IAdapter + from .attention_dispatch import AttentionProvider, attention_provider from .auto_model import AutoModel from .autoencoders import ( AsymmetricAutoencoderKL, diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index be6c15509ce4..97902cb3bb05 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import functools import inspect from enum import Enum @@ -123,6 +124,8 @@ class AttentionProvider(str, Enum): _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90" _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda" _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton" + # TODO: let's not add support for Sparge Attention now because it requires tuning per model + # We can look into supporting something "autotune"-ing in the future # SPARGE = "sparge" # `xformers` @@ -157,6 +160,23 @@ def list_providers(cls): return list(cls._providers.keys()) +@contextlib.contextmanager +def attention_provider(provider: AttentionProvider = AttentionProvider.NATIVE): + """ + Context manager to set the active attention provider. + """ + if provider not in _AttentionProviderRegistry._providers: + raise ValueError(f"Provider {provider} is not registered.") + + old_provider = _AttentionProviderRegistry._active_provider + _AttentionProviderRegistry._active_provider = provider + + try: + yield + finally: + _AttentionProviderRegistry._active_provider = old_provider + + def attention_dispatch( query: torch.Tensor, key: torch.Tensor, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index bf2f19ee2d26..1d5069e35f7b 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -85,6 +85,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AttentionProvider(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AuraFlowTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] @@ -1105,6 +1120,10 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +def attention_provider(*args, **kwargs): + requires_backends(attention_provider, ["torch"]) + + def get_constant_schedule(*args, **kwargs): requires_backends(get_constant_schedule, ["torch"]) From 4b201df72fea33f9620f2403090692c6c4ac6fef Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 21 Apr 2025 14:43:27 +0200 Subject: [PATCH 4/5] fix flash/sage seqlen preparation when kv len does not match q len (cross attention) --- src/diffusers/models/attention_dispatch.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 97902cb3bb05..c30ee6043b97 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -276,11 +276,15 @@ def _check_shape( def _prepare_for_flash_attn_or_sage_varlen( - batch_size: int, seq_len_q: int, attn_mask: Optional[torch.Tensor] = None, device: Optional[torch.device] = None + batch_size: int, + seq_len_q: int, + seq_len_kv: int, + attn_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, ) -> None: seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) if attn_mask is None: - seqlens_k = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) + seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device) else: seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32) cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) @@ -440,7 +444,9 @@ def _flash_varlen_attention( if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen(batch_size, seq_len_q, attn_mask=attn_mask, device=query.device) + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) ) else: seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) @@ -730,7 +736,9 @@ def _sage_varlen_attention( if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen(batch_size, seq_len_q, attn_mask=attn_mask, device=query.device) + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) ) else: seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) From 9c4d4aa44c58f96fd323f40d5099f0a1a39bef06 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 21 Apr 2025 22:10:40 +0200 Subject: [PATCH 5/5] fix flash-attn input shape bug; remove custom block mask code; update flux attention processors --- src/diffusers/models/attention_dispatch.py | 47 ++++++--------------- src/diffusers/models/attention_processor.py | 7 ++- 2 files changed, 16 insertions(+), 38 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index c30ee6043b97..42444166c0e8 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -13,7 +13,6 @@ # limitations under the License. import contextlib -import functools import inspect from enum import Enum from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union @@ -96,7 +95,6 @@ def __init__(self, *args, **kwargs): logger = get_logger(__name__) # pylint: disable=invalid-name -_IS_CREATE_BLOCK_MASK_COMPILED = False _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] @@ -347,27 +345,6 @@ def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: in return attn_mask -# LRU cache for block mask creation since we don't want to recompile the same block mask multiple times -@functools.lru_cache -def _flex_attention_create_block_mask( - mask_mod, - batch_size: Optional[int] = None, - num_heads: Optional[int] = None, - seq_len_q: Optional[int] = None, - seq_len_kv: Optional[int] = None, - device: Optional[torch.device] = None, -) -> BlockMask: - global _IS_CREATE_BLOCK_MASK_COMPILED, create_block_mask - if is_torch_version(">=", "2.6.0"): - if not _IS_CREATE_BLOCK_MASK_COMPILED: - create_block_mask = torch.compile(create_block_mask) - _IS_CREATE_BLOCK_MASK_COMPILED = True - block_mask = create_block_mask(mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, device) - else: - block_mask = create_block_mask(mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, device, _compile=True) - return block_mask - - def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): return q_idx >= kv_idx @@ -394,7 +371,10 @@ def _flash_attention( if enable_gqa: # TODO pass - return flash_attn_func( + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + + out = flash_attn_func( q=query, k=key, v=value, @@ -407,6 +387,8 @@ def _flash_attention( deterministic=deterministic, return_attn_probs=return_attn_probs, ) + out = out.permute(0, 2, 1, 3) + return out @_AttentionProviderRegistry.register( @@ -482,7 +464,7 @@ def _flash_varlen_attention( deterministic=deterministic, return_attn_probs=return_attn_probs, ) - out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3).contiguous() + out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) return out @@ -512,8 +494,8 @@ def _native_flex_attention( if attn_mask is None or isinstance(attn_mask, BlockMask): block_mask = attn_mask elif is_causal: - block_mask = _flex_attention_create_block_mask( - _flex_attention_causal_mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device + block_mask = create_block_mask( + _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device ) elif torch.is_tensor(attn_mask): if attn_mask.ndim == 2: @@ -526,9 +508,7 @@ def _native_flex_attention( def mask_mod(batch_idx, head_idx, q_idx, kv_idx): return attn_mask[batch_idx, head_idx, q_idx, kv_idx] - block_mask = _flex_attention_create_block_mask( - mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device - ) + block_mask = create_block_mask(mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device) else: def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): @@ -769,7 +749,7 @@ def _sage_varlen_attention( sm_scale=scale, smooth_k=smooth_k, ) - out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3).contiguous() + out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) return out @@ -920,7 +900,7 @@ def _xformers_attention( attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) # QKV need to be in [batch, seq_len, num_heads, head_dim] format for xformers - query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value)) + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) if enable_gqa: if num_heads_q % num_heads_kv != 0: @@ -933,6 +913,5 @@ def _xformers_attention( out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale) if enable_gqa: out = out.flatten(2, 3) - - out = out.permute(0, 2, 1, 3).contiguous() + out = out.permute(0, 2, 1, 3) return out diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 34276a544160..1e0387eadb57 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -23,6 +23,7 @@ from ..utils import deprecate, is_torch_xla_available, logging from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph +from .attention_dispatch import attention_dispatch logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -2339,9 +2340,7 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + hidden_states = attention_dispatch(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2543,7 +2542,7 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = attention_dispatch(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype)