From ab25ee0ac1cbf3dcd9ee212a8b4e0f0c17c3f413 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 19 Apr 2025 17:50:42 +0200 Subject: [PATCH 01/10] initial attention dispatcher support supported: flash, flash_varlen, flex, native, sage, sage_varlen, xformers --- src/diffusers/models/attention_dispatch.py | 910 ++++++++++++++++++ .../transformers/transformer_lumina2.py | 5 +- src/diffusers/utils/__init__.py | 5 + src/diffusers/utils/constants.py | 2 + src/diffusers/utils/import_utils.py | 55 ++ 5 files changed, 974 insertions(+), 3 deletions(-) create mode 100644 src/diffusers/models/attention_dispatch.py diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py new file mode 100644 index 000000000000..be6c15509ce4 --- /dev/null +++ b/src/diffusers/models/attention_dispatch.py @@ -0,0 +1,910 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import inspect +from enum import Enum +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import torch + +from ..utils import ( + OptionalDependencyNotAvailable, + get_logger, + is_flash_attn_available, + is_flash_attn_version, + is_sageattention_available, + is_sageattention_version, + is_torch_version, + is_xformers_available, + is_xformers_version, +) +from ..utils.constants import DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ATTN_PROVIDER + + +if is_flash_attn_available(): + if is_flash_attn_version("<", "2.6.3"): + raise OptionalDependencyNotAvailable( + "The `flash-attn` library version is too old. Please update it to at least 2.6.3." + ) + + from flash_attn import flash_attn_func, flash_attn_varlen_func +else: + flash_attn_func = None + flash_attn_varlen_func = None + + +if is_sageattention_available(): + if is_sageattention_version("<", "2.1.1"): + raise OptionalDependencyNotAvailable( + "The `sageattention` library version is too old. Please update it to at least 2.1.1." + ) + + from sageattention import ( + sageattn, + sageattn_qk_int8_pv_fp8_cuda, + sageattn_qk_int8_pv_fp8_cuda_sm90, + sageattn_qk_int8_pv_fp16_cuda, + sageattn_qk_int8_pv_fp16_triton, + sageattn_varlen, + ) +else: + sageattn = None + sageattn_qk_int8_pv_fp16_cuda = None + sageattn_qk_int8_pv_fp16_triton = None + sageattn_qk_int8_pv_fp8_cuda = None + sageattn_qk_int8_pv_fp8_cuda_sm90 = None + sageattn_varlen = None + + +if is_torch_version(">=", "2.5.0"): + from torch.nn.attention.flex_attention import BlockMask, create_block_mask + from torch.nn.attention.flex_attention import flex_attention as torch_flex_attention +else: + create_block_mask = None + torch_flex_attention = None + + class BlockMask: + def __init__(self, *args, **kwargs): + raise OptionalDependencyNotAvailable( + "The `torch` library version is too old. Please update it to at least 2.5.0." + ) + + +if is_xformers_available(): + if is_xformers_version("<", "0.0.29"): + raise OptionalDependencyNotAvailable( + "The `xformers` library version is too old. 
Please update it to at least 0.0.29." + ) + + import xformers.ops as xops +else: + xops = None + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_IS_CREATE_BLOCK_MASK_COMPILED = False +_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] +_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] +_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] + + +class AttentionProvider(str, Enum): + # EAGER = "eager" + + # `flash-attn` + FLASH = "flash" + FLASH_VARLEN = "flash_varlen" + + # PyTorch native + FLEX = "flex" + NATIVE = "native" + _NATIVE_CUDNN = "_native_cudnn" + _NATIVE_EFFICIENT = "_native_efficient" + _NATIVE_FLASH = "_native_flash" + _NATIVE_MATH = "_native_math" + + # `sageattention` + SAGE = "sage" + SAGE_VARLEN = "sage_varlen" + _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda" + _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90" + _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda" + _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton" + # SPARGE = "sparge" + + # `xformers` + XFORMERS = "xformers" + + +class _AttentionProviderRegistry: + _providers = {} + _constraints = {} + _supported_arg_names = {} + _active_provider = AttentionProvider(DIFFUSERS_ATTN_PROVIDER) + _checks_enabled = DIFFUSERS_ATTN_CHECKS + + @classmethod + def register(cls, provider: AttentionProvider, constraints: Optional[List[Callable]] = None): + logger.debug(f"Registering attention provider: {provider} with constraints: {constraints}") + + def decorator(func): + cls._providers[provider] = func + cls._constraints[provider] = constraints or [] + cls._supported_arg_names[provider] = set(inspect.signature(func).parameters.keys()) + return func + + return decorator + + @classmethod + def get_active_provider(cls): + return cls._active_provider, cls._providers[cls._active_provider] + + @classmethod + def list_providers(cls): + return list(cls._providers.keys()) + + +def attention_dispatch( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + attention_kwargs: Optional[Dict[str, Any]] = None, +) -> torch.Tensor: + attention_kwargs = attention_kwargs or {} + provider_name, provider_fn = _AttentionProviderRegistry.get_active_provider() + kwargs = { + "query": query, + "key": key, + "value": value, + "attn_mask": attn_mask, + "dropout_p": dropout_p, + "is_causal": is_causal, + "scale": scale, + "enable_gqa": enable_gqa, + **attention_kwargs, + } + + if _AttentionProviderRegistry._checks_enabled: + removed_kwargs = set(kwargs) - set(_AttentionProviderRegistry._supported_arg_names[provider_name]) + if removed_kwargs: + logger.warning(f"Removing unsupported arguments for attention provider {provider_name}: {removed_kwargs}.") + for check in _AttentionProviderRegistry._constraints.get(provider_name): + check(**kwargs) + + kwargs = {k: v for k, v in kwargs.items() if k in _AttentionProviderRegistry._supported_arg_names[provider_name]} + return provider_fn(**kwargs) + + +def _check_attn_mask_is_none(attn_mask: Optional[torch.Tensor], **kwargs) -> None: + if attn_mask is not None: + raise ValueError("Attention mask must be None for this provider.") + + +def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None: + if attn_mask is not None and is_causal: + raise ValueError("`is_causal` cannot be True when `attn_mask` is not 
None.") + + +def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + if query.device != key.device or query.device != value.device: + raise ValueError("Query, key, and value must be on the same device.") + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError("Query, key, and value must have the same dtype.") + + +def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_device(query, key, value) + if query.device.type != "cuda": + raise ValueError("Query, key, and value must be on a CUDA device.") + + +def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable: + def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_device_cuda(query, key, value) + if torch.cuda.get_device_capability(query.device) < (major, minor): + raise ValueError( + f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}." + ) + + return check_device_cuda + + +def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + if query.dtype != key.dtype: + raise ValueError("Query and key must have the same dtype.") + if query.dtype != value.dtype: + raise ValueError("Query and value must have the same dtype.") + + +def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_qkv_dtype_match(query, key, value) + if query.dtype not in (torch.bfloat16, torch.float16): + raise ValueError("Query, key, and value must be either bfloat16 or float16.") + + +def _check_shape( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + **kwargs, +) -> None: + if query.shape[-1] != key.shape[-1]: + raise ValueError("Query and key must have the same last dimension.") + if query.shape[-2] != value.shape[-2]: + raise ValueError("Query and value must have the same second to last dimension.") + if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]: + raise ValueError("Attention mask must match the key's second to last dimension.") + + +def _prepare_for_flash_attn_or_sage_varlen( + batch_size: int, seq_len_q: int, attn_mask: Optional[torch.Tensor] = None, device: Optional[torch.device] = None +) -> None: + seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) + if attn_mask is None: + seqlens_k = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) + else: + seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32) + cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) + cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) + max_seqlen_q = seqlens_q.max().item() + max_seqlen_k = seqlens_k.max().item() + return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) + + +def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor: + """ + Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_k in + FlashAttention/Sage varlen. + + Supports 1D to 4D shapes and common broadcasting patterns. 
+ """ + if attn_mask.dtype != torch.bool: + raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.") + + if attn_mask.ndim == 1: + # [seq_len_k] -> broadcast across batch + attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 2: + # [batch_size, seq_len_k]. Maybe broadcast across batch + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask." + ) + attn_mask = attn_mask.expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 3: + # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask." + ) + attn_mask = attn_mask.any(dim=1) + attn_mask = attn_mask.expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 4: + # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask." + ) + attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K] + attn_mask = attn_mask.any(dim=(1, 2)) # [B, K] + + else: + raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}") + + if attn_mask.shape != (batch_size, seq_len_k): + raise ValueError( + f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})" + ) + + return attn_mask + + +# LRU cache for block mask creation since we don't want to recompile the same block mask multiple times +@functools.lru_cache +def _flex_attention_create_block_mask( + mask_mod, + batch_size: Optional[int] = None, + num_heads: Optional[int] = None, + seq_len_q: Optional[int] = None, + seq_len_kv: Optional[int] = None, + device: Optional[torch.device] = None, +) -> BlockMask: + global _IS_CREATE_BLOCK_MASK_COMPILED, create_block_mask + if is_torch_version(">=", "2.6.0"): + if not _IS_CREATE_BLOCK_MASK_COMPILED: + create_block_mask = torch.compile(create_block_mask) + _IS_CREATE_BLOCK_MASK_COMPILED = True + block_mask = create_block_mask(mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, device) + else: + block_mask = create_block_mask(mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, device, _compile=True) + return block_mask + + +def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): + return q_idx >= kv_idx + + +@_AttentionProviderRegistry.register( + AttentionProvider.FLASH, + constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + attn_mask: Optional[torch.Tensor] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + if enable_gqa: + # TODO + pass + return flash_attn_func( + q=query, + k=key, + v=value, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=return_attn_probs, + ) + + +@_AttentionProviderRegistry.register( + 
AttentionProvider.FLASH_VARLEN, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + attn_mask: Optional[torch.Tensor] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + batch_size, _, seq_len_q, _ = query.shape + _, _, seq_len_kv, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if enable_gqa: + # TODO + pass + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen(batch_size, seq_len_q, attn_mask=attn_mask, device=query.device) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out = flash_attn_varlen_func( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=return_attn_probs, + ) + out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3).contiguous() + + return out + + +@_AttentionProviderRegistry.register( + AttentionProvider.FLEX, + constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], +) +def _native_flex_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[Union[torch.Tensor, BlockMask]] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + kernel_options: Optional[Dict[str, Any]] = None, +) -> torch.Tensor: + # TODO: should we LRU cache the block mask creation? 
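+    # The branches below map the incoming mask onto flex attention's two hooks: None/BlockMask inputs and boolean or causal
+    # masks become a BlockMask (built with create_block_mask), while floating-point masks become a score_mod that adds the
+    # mask value as a bias to the raw attention scores.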
+ score_mod = None + block_mask = None + batch_size, num_heads, seq_len_q, _ = query.shape + _, _, seq_len_kv, _ = key.shape + + if attn_mask is None or isinstance(attn_mask, BlockMask): + block_mask = attn_mask + elif is_causal: + block_mask = _flex_attention_create_block_mask( + _flex_attention_causal_mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device + ) + elif torch.is_tensor(attn_mask): + if attn_mask.ndim == 2: + attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + + attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv) + + if attn_mask.dtype == torch.bool: + # TODO: this probably does not work but verify! + def mask_mod(batch_idx, head_idx, q_idx, kv_idx): + return attn_mask[batch_idx, head_idx, q_idx, kv_idx] + + block_mask = _flex_attention_create_block_mask( + mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device + ) + else: + + def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): + return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx] + else: + raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.") + + return torch_flex_attention( + query=query, + key=key, + value=value, + score_mod=score_mod, + block_mask=block_mask, + scale=scale, + enable_gqa=enable_gqa, + return_lse=return_lse, + kernel_options=None, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider.NATIVE, + constraints=[_check_device, _check_shape], +) +def _native_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider._NATIVE_CUDNN, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_cudnn_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION): + return torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider._NATIVE_EFFICIENT, + constraints=[_check_device, _check_shape], +) +def _native_efficient_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): + return torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider._NATIVE_FLASH, + constraints=[_check_attn_mask_is_none, _check_device, 
_check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): + return torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider._NATIVE_MATH, + constraints=[_check_device, _check_shape], +) +def _native_math_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + return torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider.SAGE, + constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _sage_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn( + q=query, + k=key, + v=value, + tensor_layout="HND", + is_causal=is_causal, + sm_scale=scale, + return_lse=return_lse, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider.SAGE_VARLEN, + constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _sage_varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + is_causal: bool = False, + scale: Optional[float] = None, + smooth_k: bool = True, + attn_mask: Optional[torch.Tensor] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + batch_size, _, seq_len_q, _ = query.shape + _, _, seq_len_kv, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if enable_gqa: + # TODO + pass + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen(batch_size, seq_len_q, attn_mask=attn_mask, device=query.device) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out = sageattn_varlen( 
+ q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + is_causal=is_causal, + sm_scale=scale, + smooth_k=smooth_k, + ) + out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3).contiguous() + + return out + + +@_AttentionProviderRegistry.register( + AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA, + constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], +) +def _sage_qk_int8_pv_fp8_cuda_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp8_cuda( + q=query, + k=key, + v=value, + tensor_layout="HND", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + smooth_v=smooth_v, + return_lse=return_lse, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA_SM90, + constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], +) +def _sage_qk_int8_pv_fp8_cuda_sm90_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q=query, + k=key, + v=value, + tensor_layout="HND", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + return_lse=return_lse, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA, + constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], +) +def _sage_qk_int8_pv_fp16_cuda_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp16_cuda( + q=query, + k=key, + v=value, + tensor_layout="HND", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + smooth_v=smooth_v, + return_lse=return_lse, + ) + + +@_AttentionProviderRegistry.register( + AttentionProvider._SAGE_QK_INT8_PV_FP16_TRITON, + constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], +) +def _sage_qk_int8_pv_fp16_triton_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton", + smooth_k: bool = True, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp16_triton( + q=query, + k=key, + v=value, + tensor_layout="HND", + quantization_backend=quantization_backend, + is_causal=is_causal, + sm_scale=scale, + smooth_k=smooth_k, + return_lse=return_lse, + ) + + +@_AttentionProviderRegistry.register( + 
AttentionProvider.XFORMERS, + constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], +) +def _xformers_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + batch_size, num_heads_q, seq_len_q, _ = query.shape + _, num_heads_kv, seq_len_kv, _ = key.shape + + # TODO: check if `contiguous` is really needed since it may cause unnecessary slowdowns + if is_causal: + attn_mask = xops.LowerTriangularMask() + elif attn_mask is not None: + if attn_mask.ndim == 2: + attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + elif attn_mask.ndim != 4: + raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") + attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) + + # QKV need to be in [batch, seq_len, num_heads, head_dim] format for xformers + query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value)) + + if enable_gqa: + if num_heads_q % num_heads_kv != 0: + raise ValueError("Number of heads in query must be divisible by number of heads in key/value.") + num_heads_per_group = num_heads_q // num_heads_kv + query = query.unflatten(2, (num_heads_kv, -1)) + key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) + value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) + + out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale) + if enable_gqa: + out = out.flatten(2, 3) + + out = out.permute(0, 2, 1, 3).contiguous() + return out diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index a873a6ec9444..ee8de6aad3b9 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -24,6 +24,7 @@ from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import LuminaFeedForward +from ..attention_dispatch import attention_dispatch from ..attention_processor import Attention from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput @@ -137,9 +138,7 @@ def __call__( key = key.transpose(1, 2) value = value.transpose(1, 2) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, scale=softmax_scale - ) + hidden_states = attention_dispatch(query, key, value, attn_mask=attention_mask, scale=softmax_scale) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.type_as(query) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index ed89955ba5a5..954a53af1e46 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -65,6 +65,8 @@ is_bitsandbytes_available, is_bitsandbytes_version, is_bs4_available, + is_flash_attn_available, + is_flash_attn_version, is_flax_available, is_ftfy_available, is_gguf_available, @@ -85,6 +87,8 @@ is_peft_available, is_peft_version, is_safetensors_available, + is_sageattention_available, + is_sageattention_version, is_scipy_available, is_sentencepiece_available, 
is_tensorboard_available, @@ -103,6 +107,7 @@ is_unidecode_available, is_wandb_available, is_xformers_available, + is_xformers_version, requires_backends, ) from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 7c04287d33ed..ac91390609c1 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -41,6 +41,8 @@ HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] DIFFUSERS_REQUEST_TIMEOUT = 60 +DIFFUSERS_ATTN_PROVIDER = os.getenv("DIFFUSERS_ATTN_PROVIDER", "native") +DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES # Below should be `True` if the current version of `peft` and `transformers` are compatible with # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index e8d9429f6204..d93aab030c34 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -190,6 +190,8 @@ def _is_package_available(pkg_name: str): _torchao_available, _torchao_version = _is_package_available("torchao") _bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes") _torchao_available, _torchao_version = _is_package_available("torchao") +_sageattention_available, _sageattention_version = _is_package_available("sageattention") +_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") _optimum_quanto_available = importlib.util.find_spec("optimum") is not None if _optimum_quanto_available: @@ -336,6 +338,14 @@ def is_timm_available(): return _timm_available +def is_sageattention_available(): + return _sageattention_available + + +def is_flash_attn_available(): + return _flash_attn_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -743,6 +753,51 @@ def is_optimum_quanto_version(operation: str, version: str): return compare_versions(parse(_optimum_quanto_version), operation, version) +def is_xformers_version(operation: str, version: str): + """ + Compares the current xformers version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _xformers_available: + return False + return compare_versions(parse(_xformers_version), operation, version) + + +def is_sageattention_version(operation: str, version: str): + """ + Compares the current sageattention version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _sageattention_available: + return False + return compare_versions(parse(_sageattention_version), operation, version) + + +def is_flash_attn_version(operation: str, version: str): + """ + Compares the current flash-attention version to a given reference with an operation. 
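+    Returns `False` when `flash_attn` is not installed.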
+ + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _flash_attn_available: + return False + return compare_versions(parse(_flash_attn_version), operation, version) + + def get_objects_from_module(module): """ Returns a dict of object names and values in a module, while skipping private/internal objects From bb0628a45bba8925f411fc24e1438af3e3961369 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 19 Apr 2025 23:58:21 +0200 Subject: [PATCH 02/10] fix wan vae dtype --- src/diffusers/pipelines/wan/pipeline_wan_i2v.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index 20ad84cb90d0..a7818f9f2276 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -401,7 +401,7 @@ def prepare_latents( video_condition = torch.cat( [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 ) - video_condition = video_condition.to(device=device, dtype=dtype) + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) latents_mean = ( torch.tensor(self.vae.config.latents_mean) @@ -421,6 +421,7 @@ def prepare_latents( latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + latent_condition = latent_condition.to(dtype=dtype) latent_condition = (latent_condition - latents_mean) * latents_std mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) From 03a7630c41160401621ab2dbc1af1350fc612258 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 20 Apr 2025 14:22:35 +0200 Subject: [PATCH 03/10] context manager for switching provider --- src/diffusers/__init__.py | 4 ++++ src/diffusers/models/__init__.py | 2 ++ src/diffusers/models/attention_dispatch.py | 20 ++++++++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 19 +++++++++++++++++++ 4 files changed, 45 insertions(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f51a4ef2b3f6..e454ae558123 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -143,6 +143,7 @@ [ "AllegroTransformer3DModel", "AsymmetricAutoencoderKL", + "AttentionProvider", "AuraFlowTransformer2DModel", "AutoencoderDC", "AutoencoderKL", @@ -212,6 +213,7 @@ "UVit2DModel", "VQModel", "WanTransformer3DModel", + "attention_provider", ] ) _import_structure["optimization"] = [ @@ -738,6 +740,7 @@ from .models import ( AllegroTransformer3DModel, AsymmetricAutoencoderKL, + AttentionProvider, AuraFlowTransformer2DModel, AutoencoderDC, AutoencoderKL, @@ -806,6 +809,7 @@ UVit2DModel, VQModel, WanTransformer3DModel, + attention_provider, ) from .optimization import ( get_constant_schedule, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 276b1836a797..dfed6e7f300d 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -26,6 +26,7 @@ if is_torch_available(): _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] + _import_structure["attention_dispatch"] = ["AttentionProvider", "attention_provider"] _import_structure["auto_model"] = ["AutoModel"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"] @@ -106,6 +107,7 @@ if TYPE_CHECKING or 
DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .adapter import MultiAdapter, T2IAdapter + from .attention_dispatch import AttentionProvider, attention_provider from .auto_model import AutoModel from .autoencoders import ( AsymmetricAutoencoderKL, diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index be6c15509ce4..97902cb3bb05 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import functools import inspect from enum import Enum @@ -123,6 +124,8 @@ class AttentionProvider(str, Enum): _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90" _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda" _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton" + # TODO: let's not add support for Sparge Attention now because it requires tuning per model + # We can look into supporting something "autotune"-ing in the future # SPARGE = "sparge" # `xformers` @@ -157,6 +160,23 @@ def list_providers(cls): return list(cls._providers.keys()) +@contextlib.contextmanager +def attention_provider(provider: AttentionProvider = AttentionProvider.NATIVE): + """ + Context manager to set the active attention provider. + """ + if provider not in _AttentionProviderRegistry._providers: + raise ValueError(f"Provider {provider} is not registered.") + + old_provider = _AttentionProviderRegistry._active_provider + _AttentionProviderRegistry._active_provider = provider + + try: + yield + finally: + _AttentionProviderRegistry._active_provider = old_provider + + def attention_dispatch( query: torch.Tensor, key: torch.Tensor, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index bf2f19ee2d26..1d5069e35f7b 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -85,6 +85,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AttentionProvider(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AuraFlowTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] @@ -1105,6 +1120,10 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +def attention_provider(*args, **kwargs): + requires_backends(attention_provider, ["torch"]) + + def get_constant_schedule(*args, **kwargs): requires_backends(get_constant_schedule, ["torch"]) From 4b201df72fea33f9620f2403090692c6c4ac6fef Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 21 Apr 2025 14:43:27 +0200 Subject: [PATCH 04/10] fix flash/sage seqlen preparation when kv len does not match q len (cross attention) --- src/diffusers/models/attention_dispatch.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 97902cb3bb05..c30ee6043b97 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -276,11 +276,15 @@ def _check_shape( def _prepare_for_flash_attn_or_sage_varlen( - batch_size: int, seq_len_q: int, attn_mask: 
Optional[torch.Tensor] = None, device: Optional[torch.device] = None + batch_size: int, + seq_len_q: int, + seq_len_kv: int, + attn_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, ) -> None: seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) if attn_mask is None: - seqlens_k = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) + seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device) else: seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32) cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) @@ -440,7 +444,9 @@ def _flash_varlen_attention( if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen(batch_size, seq_len_q, attn_mask=attn_mask, device=query.device) + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) ) else: seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) @@ -730,7 +736,9 @@ def _sage_varlen_attention( if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen(batch_size, seq_len_q, attn_mask=attn_mask, device=query.device) + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) ) else: seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) From 9c4d4aa44c58f96fd323f40d5099f0a1a39bef06 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 21 Apr 2025 22:10:40 +0200 Subject: [PATCH 05/10] fix flash-attn input shape bug; remove custom block mask code; update flux attention processors --- src/diffusers/models/attention_dispatch.py | 47 ++++++--------------- src/diffusers/models/attention_processor.py | 7 ++- 2 files changed, 16 insertions(+), 38 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index c30ee6043b97..42444166c0e8 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -13,7 +13,6 @@ # limitations under the License. 
import contextlib -import functools import inspect from enum import Enum from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union @@ -96,7 +95,6 @@ def __init__(self, *args, **kwargs): logger = get_logger(__name__) # pylint: disable=invalid-name -_IS_CREATE_BLOCK_MASK_COMPILED = False _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] @@ -347,27 +345,6 @@ def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: in return attn_mask -# LRU cache for block mask creation since we don't want to recompile the same block mask multiple times -@functools.lru_cache -def _flex_attention_create_block_mask( - mask_mod, - batch_size: Optional[int] = None, - num_heads: Optional[int] = None, - seq_len_q: Optional[int] = None, - seq_len_kv: Optional[int] = None, - device: Optional[torch.device] = None, -) -> BlockMask: - global _IS_CREATE_BLOCK_MASK_COMPILED, create_block_mask - if is_torch_version(">=", "2.6.0"): - if not _IS_CREATE_BLOCK_MASK_COMPILED: - create_block_mask = torch.compile(create_block_mask) - _IS_CREATE_BLOCK_MASK_COMPILED = True - block_mask = create_block_mask(mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, device) - else: - block_mask = create_block_mask(mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, device, _compile=True) - return block_mask - - def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): return q_idx >= kv_idx @@ -394,7 +371,10 @@ def _flash_attention( if enable_gqa: # TODO pass - return flash_attn_func( + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + + out = flash_attn_func( q=query, k=key, v=value, @@ -407,6 +387,8 @@ def _flash_attention( deterministic=deterministic, return_attn_probs=return_attn_probs, ) + out = out.permute(0, 2, 1, 3) + return out @_AttentionProviderRegistry.register( @@ -482,7 +464,7 @@ def _flash_varlen_attention( deterministic=deterministic, return_attn_probs=return_attn_probs, ) - out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3).contiguous() + out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) return out @@ -512,8 +494,8 @@ def _native_flex_attention( if attn_mask is None or isinstance(attn_mask, BlockMask): block_mask = attn_mask elif is_causal: - block_mask = _flex_attention_create_block_mask( - _flex_attention_causal_mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device + block_mask = create_block_mask( + _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device ) elif torch.is_tensor(attn_mask): if attn_mask.ndim == 2: @@ -526,9 +508,7 @@ def _native_flex_attention( def mask_mod(batch_idx, head_idx, q_idx, kv_idx): return attn_mask[batch_idx, head_idx, q_idx, kv_idx] - block_mask = _flex_attention_create_block_mask( - mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device - ) + block_mask = create_block_mask(mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device) else: def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): @@ -769,7 +749,7 @@ def _sage_varlen_attention( sm_scale=scale, smooth_k=smooth_k, ) - out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3).contiguous() + out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) return out @@ -920,7 +900,7 @@ def _xformers_attention( attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) # QKV need to be in 
[batch, seq_len, num_heads, head_dim] format for xformers - query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value)) + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) if enable_gqa: if num_heads_q % num_heads_kv != 0: @@ -933,6 +913,5 @@ def _xformers_attention( out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale) if enable_gqa: out = out.flatten(2, 3) - - out = out.permute(0, 2, 1, 3).contiguous() + out = out.permute(0, 2, 1, 3) return out diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 34276a544160..1e0387eadb57 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -23,6 +23,7 @@ from ..utils import deprecate, is_torch_xla_available, logging from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph +from .attention_dispatch import attention_dispatch logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -2339,9 +2340,7 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + hidden_states = attention_dispatch(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2543,7 +2542,7 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = attention_dispatch(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) From a289d301d4119514e50b044e9f1cc8a1c8c10004 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 14 May 2025 14:53:08 +0200 Subject: [PATCH 06/10] provider->backend --- src/diffusers/__init__.py | 8 +- src/diffusers/models/__init__.py | 4 +- src/diffusers/models/attention_dispatch.py | 120 ++++++++++----------- src/diffusers/utils/constants.py | 2 +- src/diffusers/utils/dummy_pt_objects.py | 6 +- 5 files changed, 70 insertions(+), 70 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e186c62315ac..b9218dc62eea 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -143,7 +143,7 @@ [ "AllegroTransformer3DModel", "AsymmetricAutoencoderKL", - "AttentionProvider", + "AttentionBackend", "AuraFlowTransformer2DModel", "AutoencoderDC", "AutoencoderKL", @@ -216,7 +216,7 @@ "UVit2DModel", "VQModel", "WanTransformer3DModel", - "attention_provider", + "attention_backend", ] ) _import_structure["optimization"] = [ @@ -750,7 +750,7 @@ from .models import ( AllegroTransformer3DModel, AsymmetricAutoencoderKL, - AttentionProvider, + AttentionBackendName, AuraFlowTransformer2DModel, AutoencoderDC, AutoencoderKL, @@ -822,7 +822,7 @@ UVit2DModel, VQModel, WanTransformer3DModel, - attention_provider, + attention_backend, ) from .optimization import ( get_constant_schedule, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 
c0b7bb08188c..937135aa0182 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -26,7 +26,7 @@ if is_torch_available(): _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] - _import_structure["attention_dispatch"] = ["AttentionProvider", "attention_provider"] + _import_structure["attention_dispatch"] = ["AttentionBackend", "attention_backend"] _import_structure["auto_model"] = ["AutoModel"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"] @@ -110,7 +110,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .adapter import MultiAdapter, T2IAdapter - from .attention_dispatch import AttentionProvider, attention_provider + from .attention_dispatch import AttentionBackendName, attention_backend from .auto_model import AutoModel from .autoencoders import ( AsymmetricAutoencoderKL, diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 42444166c0e8..8ab303c22b33 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -30,7 +30,7 @@ is_xformers_available, is_xformers_version, ) -from ..utils.constants import DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ATTN_PROVIDER +from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS if is_flash_attn_available(): @@ -78,7 +78,7 @@ class BlockMask: def __init__(self, *args, **kwargs): raise OptionalDependencyNotAvailable( - "The `torch` library version is too old. Please update it to at least 2.5.0." + "The `torch` library version is too old for using Flex Attention. Please update it to at least 2.5.0." ) @@ -100,7 +100,7 @@ def __init__(self, *args, **kwargs): _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] -class AttentionProvider(str, Enum): +class AttentionBackendName(str, Enum): # EAGER = "eager" # `flash-attn` @@ -130,49 +130,49 @@ class AttentionProvider(str, Enum): XFORMERS = "xformers" -class _AttentionProviderRegistry: - _providers = {} +class _AttentionBackendRegistry: + _backends = {} _constraints = {} _supported_arg_names = {} - _active_provider = AttentionProvider(DIFFUSERS_ATTN_PROVIDER) + _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND) _checks_enabled = DIFFUSERS_ATTN_CHECKS @classmethod - def register(cls, provider: AttentionProvider, constraints: Optional[List[Callable]] = None): - logger.debug(f"Registering attention provider: {provider} with constraints: {constraints}") + def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None): + logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}") def decorator(func): - cls._providers[provider] = func - cls._constraints[provider] = constraints or [] - cls._supported_arg_names[provider] = set(inspect.signature(func).parameters.keys()) + cls._backends[backend] = func + cls._constraints[backend] = constraints or [] + cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys()) return func return decorator @classmethod - def get_active_provider(cls): - return cls._active_provider, cls._providers[cls._active_provider] + def get_active_backend(cls): + return cls._active_backend, cls._backends[cls._active_backend] @classmethod - def list_providers(cls): - return list(cls._providers.keys()) + def list_backends(cls): + return list(cls._backends.keys()) @contextlib.contextmanager -def 
attention_provider(provider: AttentionProvider = AttentionProvider.NATIVE): +def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE): """ - Context manager to set the active attention provider. + Context manager to set the active attention backend. """ - if provider not in _AttentionProviderRegistry._providers: - raise ValueError(f"Provider {provider} is not registered.") + if backend not in _AttentionBackendRegistry._backends: + raise ValueError(f"Backend {backend} is not registered.") - old_provider = _AttentionProviderRegistry._active_provider - _AttentionProviderRegistry._active_provider = provider + old_backend = _AttentionBackendRegistry._active_backend + _AttentionBackendRegistry._active_backend = backend try: yield finally: - _AttentionProviderRegistry._active_provider = old_provider + _AttentionBackendRegistry._active_backend = old_backend def attention_dispatch( @@ -187,7 +187,7 @@ def attention_dispatch( attention_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: attention_kwargs = attention_kwargs or {} - provider_name, provider_fn = _AttentionProviderRegistry.get_active_provider() + backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend() kwargs = { "query": query, "key": key, @@ -200,20 +200,20 @@ def attention_dispatch( **attention_kwargs, } - if _AttentionProviderRegistry._checks_enabled: - removed_kwargs = set(kwargs) - set(_AttentionProviderRegistry._supported_arg_names[provider_name]) + if _AttentionBackendRegistry._checks_enabled: + removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name]) if removed_kwargs: - logger.warning(f"Removing unsupported arguments for attention provider {provider_name}: {removed_kwargs}.") - for check in _AttentionProviderRegistry._constraints.get(provider_name): + logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.") + for check in _AttentionBackendRegistry._constraints.get(backend_name): check(**kwargs) - kwargs = {k: v for k, v in kwargs.items() if k in _AttentionProviderRegistry._supported_arg_names[provider_name]} - return provider_fn(**kwargs) + kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]} + return backend_fn(**kwargs) def _check_attn_mask_is_none(attn_mask: Optional[torch.Tensor], **kwargs) -> None: if attn_mask is not None: - raise ValueError("Attention mask must be None for this provider.") + raise ValueError("Attention mask must be None for this backend.") def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None: @@ -349,8 +349,8 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): return q_idx >= kv_idx -@_AttentionProviderRegistry.register( - AttentionProvider.FLASH, +@_AttentionBackendRegistry.register( + AttentionBackendName.FLASH, constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], ) def _flash_attention( @@ -391,8 +391,8 @@ def _flash_attention( return out -@_AttentionProviderRegistry.register( - AttentionProvider.FLASH_VARLEN, +@_AttentionBackendRegistry.register( + AttentionBackendName.FLASH_VARLEN, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], ) def _flash_varlen_attention( @@ -469,8 +469,8 @@ def _flash_varlen_attention( return out -@_AttentionProviderRegistry.register( - AttentionProvider.FLEX, +@_AttentionBackendRegistry.register( + AttentionBackendName.FLEX, 
constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], ) def _native_flex_attention( @@ -529,8 +529,8 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): ) -@_AttentionProviderRegistry.register( - AttentionProvider.NATIVE, +@_AttentionBackendRegistry.register( + AttentionBackendName.NATIVE, constraints=[_check_device, _check_shape], ) def _native_attention( @@ -555,8 +555,8 @@ def _native_attention( ) -@_AttentionProviderRegistry.register( - AttentionProvider._NATIVE_CUDNN, +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_CUDNN, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], ) def _native_cudnn_attention( @@ -582,8 +582,8 @@ def _native_cudnn_attention( ) -@_AttentionProviderRegistry.register( - AttentionProvider._NATIVE_EFFICIENT, +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_EFFICIENT, constraints=[_check_device, _check_shape], ) def _native_efficient_attention( @@ -609,8 +609,8 @@ def _native_efficient_attention( ) -@_AttentionProviderRegistry.register( - AttentionProvider._NATIVE_FLASH, +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_FLASH, constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], ) def _native_flash_attention( @@ -636,8 +636,8 @@ def _native_flash_attention( ) -@_AttentionProviderRegistry.register( - AttentionProvider._NATIVE_MATH, +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_MATH, constraints=[_check_device, _check_shape], ) def _native_math_attention( @@ -663,8 +663,8 @@ def _native_math_attention( ) -@_AttentionProviderRegistry.register( - AttentionProvider.SAGE, +@_AttentionBackendRegistry.register( + AttentionBackendName.SAGE, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], ) def _sage_attention( @@ -686,8 +686,8 @@ def _sage_attention( ) -@_AttentionProviderRegistry.register( - AttentionProvider.SAGE_VARLEN, +@_AttentionBackendRegistry.register( + AttentionBackendName.SAGE_VARLEN, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], ) def _sage_varlen_attention( @@ -754,8 +754,8 @@ def _sage_varlen_attention( return out -@_AttentionProviderRegistry.register( - AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA, +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], ) def _sage_qk_int8_pv_fp8_cuda_attention( @@ -785,8 +785,8 @@ def _sage_qk_int8_pv_fp8_cuda_attention( ) -@_AttentionProviderRegistry.register( - AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA_SM90, +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], ) def _sage_qk_int8_pv_fp8_cuda_sm90_attention( @@ -814,8 +814,8 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention( ) -@_AttentionProviderRegistry.register( - AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA, +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], ) def _sage_qk_int8_pv_fp16_cuda_attention( @@ -845,8 +845,8 @@ def _sage_qk_int8_pv_fp16_cuda_attention( ) -@_AttentionProviderRegistry.register( - AttentionProvider._SAGE_QK_INT8_PV_FP16_TRITON, +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], ) def 
_sage_qk_int8_pv_fp16_triton_attention( @@ -872,8 +872,8 @@ def _sage_qk_int8_pv_fp16_triton_attention( ) -@_AttentionProviderRegistry.register( - AttentionProvider.XFORMERS, +@_AttentionBackendRegistry.register( + AttentionBackendName.XFORMERS, constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], ) def _xformers_attention( diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index ac91390609c1..f8f04cc03abd 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -41,7 +41,7 @@ HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] DIFFUSERS_REQUEST_TIMEOUT = 60 -DIFFUSERS_ATTN_PROVIDER = os.getenv("DIFFUSERS_ATTN_PROVIDER", "native") +DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native") DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES # Below should be `True` if the current version of `peft` and `transformers` are compatible with diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 1b0d544b82ab..3933fec29656 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -85,7 +85,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class AttentionProvider(metaclass=DummyObject): +class AttentionBackendName(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -1165,8 +1165,8 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -def attention_provider(*args, **kwargs): - requires_backends(attention_provider, ["torch"]) +def attention_backend(*args, **kwargs): + requires_backends(attention_backend, ["torch"]) def get_constant_schedule(*args, **kwargs): From 043096690ea26b8169d489c8a1cc942dd03b1586 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 15 May 2025 12:07:45 +0200 Subject: [PATCH 07/10] refactor; flex attention fixes for compile compatibility --- src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/attention_dispatch.py | 60 ++++++++++--------- src/diffusers/models/attention_processor.py | 8 ++- .../transformers/transformer_lumina2.py | 4 +- 5 files changed, 41 insertions(+), 35 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b9218dc62eea..1a886017dde4 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -143,7 +143,7 @@ [ "AllegroTransformer3DModel", "AsymmetricAutoencoderKL", - "AttentionBackend", + "AttentionBackendName", "AuraFlowTransformer2DModel", "AutoencoderDC", "AutoencoderKL", diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 937135aa0182..b75d4ae8a542 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -26,7 +26,7 @@ if is_torch_available(): _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] - _import_structure["attention_dispatch"] = ["AttentionBackend", "attention_backend"] + _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"] _import_structure["auto_model"] = ["AutoModel"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"] diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 8ab303c22b33..b83d242601b2 100644 --- 
a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -13,6 +13,7 @@ # limitations under the License. import contextlib +import functools import inspect from enum import Enum from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union @@ -69,17 +70,10 @@ if is_torch_version(">=", "2.5.0"): - from torch.nn.attention.flex_attention import BlockMask, create_block_mask - from torch.nn.attention.flex_attention import flex_attention as torch_flex_attention -else: - create_block_mask = None - torch_flex_attention = None - - class BlockMask: - def __init__(self, *args, **kwargs): - raise OptionalDependencyNotAvailable( - "The `torch` library version is too old for using Flex Attention. Please update it to at least 2.5.0." - ) + # We cannot import the flex_attention function from the package directly because it is expected (from the + # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the + # compiled function. + import torch.nn.attention.flex_attention as flex_attention if is_xformers_available(): @@ -175,7 +169,7 @@ def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIV _AttentionBackendRegistry._active_backend = old_backend -def attention_dispatch( +def dispatch_attention_fn( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -211,6 +205,10 @@ def attention_dispatch( return backend_fn(**kwargs) +# ===== Checks ===== +# A list of very simple functions to catch common errors quickly when debugging. + + def _check_attn_mask_is_none(attn_mask: Optional[torch.Tensor], **kwargs) -> None: if attn_mask is not None: raise ValueError("Attention mask must be None for this backend.") @@ -273,6 +271,10 @@ def _check_shape( raise ValueError("Attention mask must match the key's second to last dimension.") +# ===== Helper functions ===== + + +@functools.lru_cache(maxsize=1) def _prepare_for_flash_attn_or_sage_varlen( batch_size: int, seq_len_q: int, @@ -296,7 +298,7 @@ def _prepare_for_flash_attn_or_sage_varlen( def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor: """ - Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_k in + Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in FlashAttention/Sage varlen. Supports 1D to 4D shapes and common broadcasting patterns. @@ -318,6 +320,7 @@ def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: in elif attn_mask.ndim == 3: # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension + # We do this reduction because we know that arbitrary QK masks is not supported in Flash/Sage varlen. if attn_mask.size(0) not in [1, batch_size]: raise ValueError( f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask." 
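# A minimal sketch (assumed, not taken verbatim from this patch) of how a mask normalized to
# [batch_size, seq_len_k] is typically reduced to the seqlens / cu_seqlens values that the
# varlen kernels expect; the actual logic lives in `_prepare_for_flash_attn_or_sage_varlen`
# and may differ in detail.
import torch
import torch.nn.functional as F

attn_mask = torch.ones(2, 128, dtype=torch.bool)      # [batch_size, seq_len_k], True = valid key
seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)   # valid key count per batch element
cu_seqlens_k = F.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))  # prefix sums with a leading 0
max_seqlen_k = int(seqlens_k.max())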
@@ -349,6 +352,9 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): return q_idx >= kv_idx +# ===== Attention backends ===== + + @_AttentionBackendRegistry.register( AttentionBackendName.FLASH, constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], @@ -369,11 +375,9 @@ def _flash_attention( enable_gqa: bool = False, ) -> torch.Tensor: if enable_gqa: - # TODO - pass + raise NotImplementedError("GQA is not yet supported.") query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - out = flash_attn_func( q=query, k=key, @@ -388,6 +392,7 @@ def _flash_attention( return_attn_probs=return_attn_probs, ) out = out.permute(0, 2, 1, 3) + return out @@ -421,8 +426,7 @@ def _flash_varlen_attention( attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) if enable_gqa: - # TODO - pass + raise NotImplementedError("GQA is not yet supported.") if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( @@ -477,8 +481,7 @@ def _native_flex_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attn_mask: Optional[Union[torch.Tensor, BlockMask]] = None, - dropout_p: float = 0.0, + attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, @@ -491,10 +494,10 @@ def _native_flex_attention( batch_size, num_heads, seq_len_q, _ = query.shape _, _, seq_len_kv, _ = key.shape - if attn_mask is None or isinstance(attn_mask, BlockMask): + if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask): block_mask = attn_mask elif is_causal: - block_mask = create_block_mask( + block_mask = flex_attention.create_block_mask( _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device ) elif torch.is_tensor(attn_mask): @@ -508,7 +511,9 @@ def _native_flex_attention( def mask_mod(batch_idx, head_idx, q_idx, kv_idx): return attn_mask[batch_idx, head_idx, q_idx, kv_idx] - block_mask = create_block_mask(mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device) + block_mask = flex_attention.create_block_mask( + mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device + ) else: def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): @@ -516,7 +521,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): else: raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.") - return torch_flex_attention( + return flex_attention.flex_attention( query=query, key=key, value=value, @@ -525,7 +530,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): scale=scale, enable_gqa=enable_gqa, return_lse=return_lse, - kernel_options=None, + kernel_options=kernel_options, ) @@ -711,8 +716,7 @@ def _sage_varlen_attention( attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) if enable_gqa: - # TODO - pass + raise NotImplementedError("GQA is not yet supported.") if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( @@ -825,7 +829,7 @@ def _sage_qk_int8_pv_fp16_cuda_attention( is_causal: bool = False, scale: Optional[float] = None, qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", - pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32", 
smooth_k: bool = True, smooth_v: bool = False, return_lse: bool = False, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a9e2542ddc57..a7307810a301 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -23,7 +23,7 @@ from ..utils import deprecate, is_torch_xla_available, logging from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph -from .attention_dispatch import attention_dispatch +from .attention_dispatch import dispatch_attention_fn logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -2340,7 +2340,9 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = attention_dispatch(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) + hidden_states = dispatch_attention_fn( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2542,7 +2544,7 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = attention_dispatch(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index ee8de6aad3b9..ffa72294ade5 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -24,7 +24,7 @@ from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import LuminaFeedForward -from ..attention_dispatch import attention_dispatch +from ..attention_dispatch import dispatch_attention_fn from ..attention_processor import Attention from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput @@ -138,7 +138,7 @@ def __call__( key = key.transpose(1, 2) value = value.transpose(1, 2) - hidden_states = attention_dispatch(query, key, value, attn_mask=attention_mask, scale=softmax_scale) + hidden_states = dispatch_attention_fn(query, key, value, attn_mask=attention_mask, scale=softmax_scale) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.type_as(query) From 66deeac4b5f893c6431f58af4cc6f9c3e2d6b7b6 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 16 May 2025 14:42:15 +0200 Subject: [PATCH 08/10] support more models --- src/diffusers/models/attention_processor.py | 22 +++++++++---------- .../models/transformers/transformer_wan.py | 5 +++-- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a7307810a301..7a9f24814c3f 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1074,9 +1074,7 
@@ def apply_rotary_emb(x, freqs_cos, freqs_sin): valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2) valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2) - attn_output = F.scaled_dot_product_attention( - valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False - ) + attn_output = dispatch_attention_fn(valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False) valid_sequence_length = attn_output.size(2) attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length)) attn_outputs.append(attn_output) @@ -2450,7 +2448,7 @@ def __call__( inner_precise=0, )[0] else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2655,7 +2653,7 @@ def __call__( inner_precise=0, )[0] else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2777,7 +2775,7 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2807,7 +2805,7 @@ def __call__( ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 - current_ip_hidden_states = F.scaled_dot_product_attention( + current_ip_hidden_states = dispatch_attention_fn( ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( @@ -2873,7 +2871,7 @@ def __call__( if not attn.is_cross_attention: key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) - hidden_states = F.scaled_dot_product_attention( + hidden_states = dispatch_attention_fn( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) @@ -2944,7 +2942,7 @@ def __call__( if not attn.is_cross_attention: key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) - hidden_states = F.scaled_dot_product_attention( + hidden_states = dispatch_attention_fn( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) @@ -3217,7 +3215,7 @@ def __call__( )[0] else: # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( + hidden_states = dispatch_attention_fn( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) @@ -3311,7 +3309,7 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( + hidden_states = dispatch_attention_fn( query, key, value, 
attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) @@ -3615,7 +3613,7 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( + hidden_states = dispatch_attention_fn( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal ) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index c78d72dc4a2c..e099a48484bc 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -23,6 +23,7 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward +from ..attention_dispatch import dispatch_attention_fn from ..attention_processor import Attention from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed @@ -89,13 +90,13 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) - hidden_states_img = F.scaled_dot_product_attention( + hidden_states_img = dispatch_attention_fn( query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False ) hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) - hidden_states = F.scaled_dot_product_attention( + hidden_states = dispatch_attention_fn( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) From 45a809aa4b31b73895af357bec9c76e0e5518b3a Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 30 May 2025 07:01:14 +0200 Subject: [PATCH 09/10] add support for FA3, NPU, XLA --- src/diffusers/models/attention_dispatch.py | 235 ++++++++++++++++++--- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 5 + 3 files changed, 207 insertions(+), 34 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index b83d242601b2..fee1184876b4 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,43 +15,49 @@ import contextlib import functools import inspect +import math from enum import Enum from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import torch from ..utils import ( - OptionalDependencyNotAvailable, get_logger, + is_flash_attn_3_available, is_flash_attn_available, is_flash_attn_version, is_sageattention_available, is_sageattention_version, + is_torch_npu_available, is_torch_version, + is_torch_xla_available, + is_torch_xla_version, is_xformers_available, is_xformers_version, ) from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS -if is_flash_attn_available(): - if is_flash_attn_version("<", "2.6.3"): - raise OptionalDependencyNotAvailable( - "The `flash-attn` library version is too old. 
Please update it to at least 2.6.3." - ) +logger = get_logger(__name__) # pylint: disable=invalid-name + +if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"): from flash_attn import flash_attn_func, flash_attn_varlen_func else: + logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.") flash_attn_func = None flash_attn_varlen_func = None -if is_sageattention_available(): - if is_sageattention_version("<", "2.1.1"): - raise OptionalDependencyNotAvailable( - "The `sageattention` library version is too old. Please update it to at least 2.1.1." - ) +if is_flash_attn_3_available(): + from flash_attn_interface import flash_attn_func as flash_attn_3_func + from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func +else: + flash_attn_3_func = None + flash_attn_3_varlen_func = None + +if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"): from sageattention import ( sageattn, sageattn_qk_int8_pv_fp8_cuda, @@ -61,6 +67,9 @@ sageattn_varlen, ) else: + logger.warning( + "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`." + ) sageattn = None sageattn_qk_int8_pv_fp16_cuda = None sageattn_qk_int8_pv_fp16_triton = None @@ -76,19 +85,25 @@ import torch.nn.attention.flex_attention as flex_attention -if is_xformers_available(): - if is_xformers_version("<", "0.0.29"): - raise OptionalDependencyNotAvailable( - "The `xformers` library version is too old. Please update it to at least 0.0.29." - ) +if is_torch_npu_available(): + from torch_npu import npu_fusion_attention +else: + npu_fusion_attention = None + +if is_torch_xla_available() and is_torch_xla_version(">", "2.2"): + from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention +else: + xla_flash_attention = None + + +if is_xformers_available() and is_xformers_version(">=", "0.0.29"): import xformers.ops as xops else: + logger.warning("`xformers` is not available or the version is too old. 
Please install `xformers>=0.0.29`.") xops = None -logger = get_logger(__name__) # pylint: disable=invalid-name - _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] @@ -100,6 +115,8 @@ class AttentionBackendName(str, Enum): # `flash-attn` FLASH = "flash" FLASH_VARLEN = "flash_varlen" + _FLASH_3 = "_flash_3" + _FLASH_VARLEN_3 = "_flash_varlen_3" # PyTorch native FLEX = "flex" @@ -108,6 +125,8 @@ class AttentionBackendName(str, Enum): _NATIVE_EFFICIENT = "_native_efficient" _NATIVE_FLASH = "_native_flash" _NATIVE_MATH = "_native_math" + _NATIVE_NPU = "_native_npu" + _NATIVE_XLA = "_native_xla" # `sageattention` SAGE = "sage" @@ -274,7 +293,7 @@ def _check_shape( # ===== Helper functions ===== -@functools.lru_cache(maxsize=1) +@functools.lru_cache(maxsize=8) def _prepare_for_flash_attn_or_sage_varlen( batch_size: int, seq_len_q: int, @@ -371,12 +390,7 @@ def _flash_attention( alibi_slopes: Optional[torch.Tensor] = None, deterministic: bool = False, return_attn_probs: bool = False, - attn_mask: Optional[torch.Tensor] = None, - enable_gqa: bool = False, ) -> torch.Tensor: - if enable_gqa: - raise NotImplementedError("GQA is not yet supported.") - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) out = flash_attn_func( q=query, @@ -392,7 +406,6 @@ def _flash_attention( return_attn_probs=return_attn_probs, ) out = out.permute(0, 2, 1, 3) - return out @@ -417,7 +430,6 @@ def _flash_varlen_attention( deterministic: bool = False, return_attn_probs: bool = False, attn_mask: Optional[torch.Tensor] = None, - enable_gqa: bool = False, ) -> torch.Tensor: batch_size, _, seq_len_q, _ = query.shape _, _, seq_len_kv, _ = key.shape @@ -425,9 +437,6 @@ def _flash_varlen_attention( if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - if enable_gqa: - raise NotImplementedError("GQA is not yet supported.") - if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( _prepare_for_flash_attn_or_sage_varlen( @@ -473,6 +482,121 @@ def _flash_varlen_attention( return out +@_AttentionBackendRegistry.register( + AttentionBackendName._FLASH_3, + constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_attention_3( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + deterministic: bool = False, + return_attn_probs: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out, lse, *_ = flash_attn_3_func( + q=query, + k=key, + v=value, + softmax_scale=scale, + causal=is_causal, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=window_size, + attention_chunk=0, + softcap=softcap, + num_splits=1, + pack_gqa=None, + deterministic=deterministic, + sm_margin=0, + ) + out = out.permute(0, 2, 1, 3) + return (out, lse) if return_attn_probs else out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._FLASH_VARLEN_3, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_varlen_attention_3( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = 
None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + deterministic: bool = False, + return_attn_probs: bool = False, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + batch_size, _, seq_len_q, _ = query.shape + _, _, seq_len_kv, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out, lse, *_ = flash_attn_3_varlen_func( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=None, + seqused_k=None, + softmax_scale=scale, + causal=is_causal, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=window_size, + softcap=softcap, + num_splits=1, + pack_gqa=None, + deterministic=deterministic, + sm_margin=0, + ) + out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) + + return (out, lse) if return_attn_probs else out + + @_AttentionBackendRegistry.register( AttentionBackendName.FLEX, constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], @@ -668,6 +792,53 @@ def _native_math_attention( ) +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_NPU, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_npu_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, +) -> torch.Tensor: + return npu_fusion_attention( + query, + key, + value, + query.size(1), # num_heads + input_layout="BNSD", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, + pre_tockens=65536, + next_tokens=65536, + keep_prob=1.0 - dropout_p, + sync=False, + inner_precise=0, + )[0] + + +# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853 +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_XLA, + constraints=[_check_device, _check_shape], +) +def _native_xla_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, +) -> torch.Tensor: + query = query / math.sqrt(query.shape[-1]) + return xla_flash_attention( + q=query, + k=key, + v=value, + causal=is_causal, + ) + + @_AttentionBackendRegistry.register( AttentionBackendName.SAGE, constraints=[_check_device_cuda, 
_check_qkv_dtype_bf16_or_fp16, _check_shape], @@ -707,7 +878,6 @@ def _sage_varlen_attention( scale: Optional[float] = None, smooth_k: bool = True, attn_mask: Optional[torch.Tensor] = None, - enable_gqa: bool = False, ) -> torch.Tensor: batch_size, _, seq_len_q, _ = query.shape _, _, seq_len_kv, _ = key.shape @@ -715,9 +885,6 @@ def _sage_varlen_attention( if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - if enable_gqa: - raise NotImplementedError("GQA is not yet supported.") - if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( _prepare_for_flash_attn_or_sage_varlen( diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index abb7efcf3b85..cadcedb98a14 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -67,6 +67,7 @@ is_bitsandbytes_version, is_bs4_available, is_cosmos_guardrail_available, + is_flash_attn_3_available, is_flash_attn_available, is_flash_attn_version, is_flax_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 811a87eb3db5..06fa85862c38 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -221,6 +221,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail") _sageattention_available, _sageattention_version = _is_package_available("sageattention") _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") +_flash_attn_3_available, _flash_attn_version = _is_package_available("flash_attn_3") def is_torch_available(): @@ -387,6 +388,10 @@ def is_flash_attn_available(): return _flash_attn_available +def is_flash_attn_3_available(): + return _flash_attn_3_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. 
Checkout the instructions on the From 29d8fccd7127a94e3c595c6ac1b457e0c648c8f3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 4 Jun 2025 12:27:28 +0200 Subject: [PATCH 10/10] update --- src/diffusers/models/attention_dispatch.py | 12 ++- src/diffusers/models/attention_processor.py | 99 ++++++++++++++++--- src/diffusers/models/modeling_utils.py | 46 +++++++++ .../transformers/transformer_lumina2.py | 6 +- .../models/transformers/transformer_wan.py | 18 +++- src/diffusers/utils/import_utils.py | 2 +- 6 files changed, 165 insertions(+), 18 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index fee1184876b4..c6c78a44a632 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -198,9 +198,19 @@ def dispatch_attention_fn( scale: Optional[float] = None, enable_gqa: bool = False, attention_kwargs: Optional[Dict[str, Any]] = None, + *, + backend: Optional[AttentionBackendName] = None, ) -> torch.Tensor: attention_kwargs = attention_kwargs or {} - backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend() + + if backend is None: + # If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND environment + # variable), or we use a custom backend based on whether user is using the `attention_backend` context manager + backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend() + else: + backend_name = AttentionBackendName(backend) + backend_fn = _AttentionBackendRegistry._backends.get(backend_name) + kwargs = { "query": query, "key": key, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 7b13bc136125..802a31e101fb 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -997,6 +997,8 @@ def forward( class MochiAttnProcessor2_0: """Attention processor used in Mochi.""" + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. 
To use it, please upgrade PyTorch to 2.0.") @@ -1074,7 +1076,9 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2) valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2) - attn_output = dispatch_attention_fn(valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False) + attn_output = dispatch_attention_fn( + valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False, backend=self._attention_backend + ) valid_sequence_length = attn_output.size(2) attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length)) attn_outputs.append(attn_output) @@ -2274,6 +2278,8 @@ def __call__( class FluxAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") @@ -2339,7 +2345,13 @@ def __call__( key = apply_rotary_emb(key, image_rotary_emb) hidden_states = dispatch_attention_fn( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -2366,6 +2378,8 @@ def __call__( class FluxAttnProcessor2_0_NPU: """Attention processor used typically in processing the SD3-like self-attention projections.""" + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( @@ -2448,7 +2462,9 @@ def __call__( inner_precise=0, )[0] else: - hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = dispatch_attention_fn( + query, key, value, dropout_p=0.0, is_causal=False, backend=self._attention_backend + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2472,6 +2488,8 @@ def __call__( class FusedFluxAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( @@ -2542,7 +2560,9 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = dispatch_attention_fn( + query, key, value, dropout_p=0.0, is_causal=False, backend=self._attention_backend + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2567,6 +2587,8 @@ def __call__( class FusedFluxAttnProcessor2_0_NPU: """Attention processor used typically in processing the SD3-like self-attention projections.""" + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( @@ -2653,7 +2675,9 @@ def __call__( inner_precise=0, )[0] else: - hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = dispatch_attention_fn( + query, key, value, dropout_p=0.0, is_causal=False, backend=self._attention_backend + ) hidden_states = 
hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2678,6 +2702,8 @@ def __call__( class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module): """Flux Attention processor for IP-Adapter.""" + _attention_backend = None + def __init__( self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None ): @@ -2775,7 +2801,9 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = dispatch_attention_fn( + query, key, value, dropout_p=0.0, is_causal=False, backend=self._attention_backend + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2806,7 +2834,13 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 current_ip_hidden_states = dispatch_attention_fn( - ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ip_query, + ip_key, + ip_value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( batch_size, -1, attn.heads * head_dim @@ -2825,6 +2859,8 @@ class CogVideoXAttnProcessor2_0: query and key vectors, but does not include spatial normalization. """ + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") @@ -2872,7 +2908,13 @@ def __call__( key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) hidden_states = dispatch_attention_fn( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -2894,6 +2936,8 @@ class FusedCogVideoXAttnProcessor2_0: query and key vectors, but does not include spatial normalization. """ + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") @@ -2943,7 +2987,13 @@ def __call__( key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) hidden_states = dispatch_attention_fn( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -3129,9 +3179,10 @@ class AttnProcessorNPU: Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is not significant. 
- """ + _attention_backend = None + def __init__(self): if not is_torch_npu_available(): raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.") @@ -3216,7 +3267,13 @@ def __call__( else: # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = dispatch_attention_fn( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -3243,6 +3300,8 @@ class AttnProcessor2_0: Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") @@ -3310,7 +3369,13 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = dispatch_attention_fn( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -3553,6 +3618,8 @@ class MochiVaeAttnProcessor2_0: Attention processor used in Mochi VAE. """ + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") @@ -3614,7 +3681,13 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = dispatch_attention_fn( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=attn.is_causal, + backend=self._attention_backend, ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 55ce0cf79fb9..cfd2495a0863 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -599,6 +599,52 @@ def enable_group_offload( low_cpu_mem_usage=low_cpu_mem_usage, ) + def set_attention_backend(self, backend: str) -> None: + """ + Set the attention backend for the model. + + Args: + backend (`str`): + The name of the backend to set. Must be one of the available backends defined in + `AttentionBackendName`. Available backends can be found in + `diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product + attention as backend. 
+ """ + from .attention_dispatch import AttentionBackendName + from .attention_processor import Attention, MochiAttention + + backend = backend.lower() + available_backends = {x.value for x in AttentionBackendName.__members__.values()} + if backend not in available_backends: + raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) + + backend = AttentionBackendName(backend) + attention_classes = (Attention, MochiAttention) + + for module in self.modules(): + if not isinstance(module, attention_classes): + continue + processor = module.processor + if processor is None or not hasattr(processor, "_attention_backend"): + continue + processor._attention_backend = backend + + def reset_attention_backend(self) -> None: + """ + Resets the attention backend for the model. Following calls to `forward` will use the environment default or + the torch native scaled dot product attention. + """ + from .attention_processor import Attention, MochiAttention + + attention_classes = (Attention, MochiAttention) + for module in self.modules(): + if not isinstance(module, attention_classes): + continue + processor = module.processor + if processor is None or not hasattr(processor, "_attention_backend"): + continue + processor._attention_backend = None + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index ffa72294ade5..78a26b89907c 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -72,6 +72,8 @@ class Lumina2AttnProcessor2_0: used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors. """ + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") @@ -138,7 +140,9 @@ def __call__( key = key.transpose(1, 2) value = value.transpose(1, 2) - hidden_states = dispatch_attention_fn(query, key, value, attn_mask=attention_mask, scale=softmax_scale) + hidden_states = dispatch_attention_fn( + query, key, value, attn_mask=attention_mask, scale=softmax_scale, backend=self._attention_backend + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.type_as(query) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index c73b2b9d1c17..84a5e732bc0e 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -36,6 +36,8 @@ class WanAttnProcessor2_0: + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. 
To use it, please upgrade PyTorch to 2.0.") @@ -92,13 +94,25 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) hidden_states_img = dispatch_attention_fn( - query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False + query, + key_img, + value_img, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) hidden_states = dispatch_attention_fn( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = hidden_states.type_as(query) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 06fa85862c38..4fe71801e8f9 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -221,7 +221,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail") _sageattention_available, _sageattention_version = _is_package_available("sageattention") _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") -_flash_attn_3_available, _flash_attn_version = _is_package_available("flash_attn_3") +_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3") def is_torch_available():
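Taken together, the series exposes three ways to pick an attention implementation: the scoped `attention_backend` context manager, the per-model `set_attention_backend` / `reset_attention_backend` override from patch 10/10, and the `DIFFUSERS_ATTN_BACKEND` environment default. The sketch below only illustrates those entry points; the tensor shapes and the commented-out checkpoint id are placeholders, not part of the patches.

import torch

from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    attention_backend,
    dispatch_attention_fn,
)

# Low-level entry point; tensors are laid out as [batch, heads, seq_len, head_dim], which is
# what the refactored processors pass in before any backend-specific permutation.
q, k, v = (torch.randn(2, 8, 128, 64) for _ in range(3))
out = dispatch_attention_fn(q, k, v, backend=AttentionBackendName.NATIVE)

# Scoped override: dispatch_attention_fn calls inside the context use the chosen backend
# (a PyTorch-native one here so the sketch stays CPU-friendly). FLASH, SAGE, XFORMERS, etc.
# additionally require the corresponding library and, for most of them, fp16/bf16 CUDA tensors.
with attention_backend(AttentionBackendName._NATIVE_MATH):
    out = dispatch_attention_fn(q, k, v, is_causal=True)

# Per-model override from patch 10/10; "some/checkpoint" is a placeholder id.
#   model = AutoModel.from_pretrained("some/checkpoint", torch_dtype=torch.bfloat16)
#   model.set_attention_backend("_flash_3")   # pin FA3 on every supported attention processor
#   model.reset_attention_backend()           # back to the env default / native SDPA

# Process-wide default, read once from the environment at import time:
#   DIFFUSERS_ATTN_BACKEND=sage DIFFUSERS_ATTN_CHECKS=1 python train.py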