38 | 38 | from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
39 | 39 |
40 | 40 |
41 | | -logger = get_logger(__name__) # pylint: disable=invalid-name
42 | | -
43 | | -
44 | | -if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"):
| 41 | +_REQUIRED_FLASH_VERSION = "2.6.3"
| 42 | +_REQUIRED_SAGE_VERSION = "2.1.1"
| 43 | +_REQUIRED_FLEX_VERSION = "2.5.0"
| 44 | +_REQUIRED_XLA_VERSION = "2.2"
| 45 | +_REQUIRED_XFORMERS_VERSION = "0.0.29"
| 46 | +
| 47 | +_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
| 48 | +_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
| 49 | +_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
| 50 | +_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
| 51 | +_CAN_USE_NPU_ATTN = is_torch_npu_available()
| 52 | +_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
| 53 | +_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
| 54 | +
| 55 | +
| 56 | +if _CAN_USE_FLASH_ATTN:
45 | 57 |     from flash_attn import flash_attn_func, flash_attn_varlen_func
46 | 58 | else:
47 | | -    logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.")
48 | 59 |     flash_attn_func = None
49 | 60 |     flash_attn_varlen_func = None
50 | 61 |
51 | 62 |
52 | | -if is_flash_attn_3_available():
| 63 | +if _CAN_USE_FLASH_ATTN_3:
53 | 64 |     from flash_attn_interface import flash_attn_func as flash_attn_3_func
54 | 65 |     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
55 | 66 | else:
56 | 67 |     flash_attn_3_func = None
57 | 68 |     flash_attn_3_varlen_func = None
58 | 69 |
59 | 70 |
60 | | -if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"):
| 71 | +if _CAN_USE_SAGE_ATTN:
61 | 72 |     from sageattention import (
62 | 73 |         sageattn,
63 | 74 |         sageattn_qk_int8_pv_fp8_cuda,
64 | 75 |         sageattn_qk_int8_pv_fp8_cuda_sm90,
65 | 76 |         sageattn_qk_int8_pv_fp16_cuda,
66 | 77 |         sageattn_qk_int8_pv_fp16_triton,
67 | 78 |         sageattn_varlen,
68 | 79 |     )
69 | 80 | else:
70 | | -    logger.warning(
71 | | -        "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`."
72 | | -    )
73 | 81 |     sageattn = None
74 | 82 |     sageattn_qk_int8_pv_fp16_cuda = None
75 | 83 |     sageattn_qk_int8_pv_fp16_triton = None
76 | 84 |     sageattn_qk_int8_pv_fp8_cuda = None
77 | 85 |     sageattn_qk_int8_pv_fp8_cuda_sm90 = None
78 | 86 |     sageattn_varlen = None
79 | 87 |
80 | 88 |
81 | | -if is_torch_version(">=", "2.5.0"):
| 89 | +if _CAN_USE_FLEX_ATTN:
82 | 90 |     # We cannot import the flex_attention function from the package directly because it is expected (from the
83 | 91 |     # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
84 | 92 |     # compiled function.
85 | 93 |     import torch.nn.attention.flex_attention as flex_attention
86 | 94 |
87 | 95 |
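
The late-binding pattern that the comment above relies on can be sketched in isolation. This is an illustrative sketch only: it assumes torch>=2.5 (so that torch.nn.attention.flex_attention exists) and runs flex_attention in eager mode, which works but is slower than the compiled path.

import torch
import torch.nn.attention.flex_attention as flex_attention  # bind the module, not the function

query = key = value = torch.randn(1, 1, 128, 64)

# A user may later swap in a compiled kernel, e.g.
#     flex_attention.flex_attention = torch.compile(flex_attention.flex_attention)
# The call below resolves the attribute at call time, so it picks up that compiled
# version automatically; `from ... import flex_attention` would have frozen the
# original function at import time.
out = flex_attention.flex_attention(query, key, value)
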
88 | | -if is_torch_npu_available():
| 96 | +if _CAN_USE_NPU_ATTN:
89 | 97 |     from torch_npu import npu_fusion_attention
90 | 98 | else:
91 | 99 |     npu_fusion_attention = None
92 | 100 |
93 | 101 |
94 | | -if is_torch_xla_available() and is_torch_xla_version(">", "2.2"):
| 102 | +if _CAN_USE_XLA_ATTN:
95 | 103 |     from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
96 | 104 | else:
97 | 105 |     xla_flash_attention = None
98 | 106 |
99 | 107 |
100 | | -if is_xformers_available() and is_xformers_version(">=", "0.0.29"):
| 108 | +if _CAN_USE_XFORMERS_ATTN:
101 | 109 |     import xformers.ops as xops
102 | 110 | else:
103 | | -    logger.warning("`xformers` is not available or the version is too old. Please install `xformers>=0.0.29`.")
104 | 111 |     xops = None
105 | 112 |
106 | 113 |
| 114 | +logger = get_logger(__name__) # pylint: disable=invalid-name
| 115 | +
107 | 116 | # TODO(aryan): Add support for the following:
108 | 117 | # - Sage Attention++
109 | 118 | # - block sparse, radial and other attention methods
110 | 119 | # - CP with sage attention, flex, xformers, other missing backends
111 | 120 | # - Add support for normal and CP training with backends that don't support it yet
112 | 121 |
113 | | -
114 | 122 | _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
115 | 123 | _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
116 | 124 | _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
@@ -179,13 +187,16 @@ def list_backends(cls):
179 | 187 |
180 | 188 |
181 | 189 | @contextlib.contextmanager
182 | | -def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE):
| 190 | +def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
183 | 191 |     """
184 | 192 |     Context manager to set the active attention backend.
185 | 193 |     """
186 | 194 |     if backend not in _AttentionBackendRegistry._backends:
187 | 195 |         raise ValueError(f"Backend {backend} is not registered.")
188 | 196 |
| 197 | +    backend = AttentionBackendName(backend)
| 198 | +    _check_attention_backend_requirements(backend)
| 199 | +
189 | 200 |     old_backend = _AttentionBackendRegistry._active_backend
190 | 201 |     _AttentionBackendRegistry._active_backend = backend
191 | 202 |
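
A minimal usage sketch of the updated context manager, assuming it is importable from diffusers.models.attention_dispatch:

from diffusers.models.attention_dispatch import AttentionBackendName, attention_backend

# The backend may be passed as an enum member or as its string value; strings are
# coerced via AttentionBackendName(backend), and the requirements check runs before
# the backend is activated, so an unusable backend raises RuntimeError up front.
with attention_backend(AttentionBackendName.NATIVE):
    ...  # attention dispatched inside this block uses the native torch SDPA path
# The previously active backend is restored when the context manager exits.
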
@@ -226,9 +237,10 @@ def dispatch_attention_fn(
226 | 237 |         "dropout_p": dropout_p,
227 | 238 |         "is_causal": is_causal,
228 | 239 |         "scale": scale,
229 | | -        "enable_gqa": enable_gqa,
230 | 240 |         **attention_kwargs,
231 | 241 |     }
| 242 | +    if is_torch_version(">=", "2.5.0"):
| 243 | +        kwargs["enable_gqa"] = enable_gqa
232 | 244 |
233 | 245 |     if _AttentionBackendRegistry._checks_enabled:
234 | 246 |         removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name])
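
The reason for gating `enable_gqa` can be shown with a standalone sketch: torch.nn.functional.scaled_dot_product_attention only accepts the argument from torch 2.5 onwards, so it has to be added to the kwargs conditionally. The shapes below are arbitrary, and packaging is used only to keep the sketch self-contained (the module itself uses is_torch_version):

import torch
import torch.nn.functional as F
from packaging import version

query = key = value = torch.randn(2, 4, 16, 64)  # (batch, heads, seq_len, head_dim)

kwargs = {"dropout_p": 0.0, "is_causal": False}
if version.parse(torch.__version__).release >= (2, 5):
    # enable_gqa is only accepted from torch 2.5; passing it on older
    # versions raises a TypeError.
    kwargs["enable_gqa"] = False

out = F.scaled_dot_product_attention(query, key, value, **kwargs)
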
@@ -305,6 +317,57 @@ def _check_shape(
305 | 317 | # ===== Helper functions =====
306 | 318 |
307 | 319 |
| 320 | +def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
| 321 | +    if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
| 322 | +        if not _CAN_USE_FLASH_ATTN:
| 323 | +            raise RuntimeError(
| 324 | +                f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
| 325 | +            )
| 326 | +
| 327 | +    elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
| 328 | +        if not _CAN_USE_FLASH_ATTN_3:
| 329 | +            raise RuntimeError(
| 330 | +                f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
| 331 | +            )
| 332 | +
| 333 | +    elif backend in [
| 334 | +        AttentionBackendName.SAGE,
| 335 | +        AttentionBackendName.SAGE_VARLEN,
| 336 | +        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
| 337 | +        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
| 338 | +        AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
| 339 | +        AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
| 340 | +    ]:
| 341 | +        if not _CAN_USE_SAGE_ATTN:
| 342 | +            raise RuntimeError(
| 343 | +                f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
| 344 | +            )
| 345 | +
| 346 | +    elif backend == AttentionBackendName.FLEX:
| 347 | +        if not _CAN_USE_FLEX_ATTN:
| 348 | +            raise RuntimeError(
| 349 | +                f"Flex Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch>=2.5.0`."
| 350 | +            )
| 351 | +
| 352 | +    elif backend == AttentionBackendName._NATIVE_NPU:
| 353 | +        if not _CAN_USE_NPU_ATTN:
| 354 | +            raise RuntimeError(
| 355 | +                f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`."
| 356 | +            )
| 357 | +
| 358 | +    elif backend == AttentionBackendName._NATIVE_XLA:
| 359 | +        if not _CAN_USE_XLA_ATTN:
| 360 | +            raise RuntimeError(
| 361 | +                f"XLA Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`."
| 362 | +            )
| 363 | +
| 364 | +    elif backend == AttentionBackendName.XFORMERS:
| 365 | +        if not _CAN_USE_XFORMERS_ATTN:
| 366 | +            raise RuntimeError(
| 367 | +                f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
| 368 | +            )
| 369 | +
| 370 | +
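
How the new check surfaces to a caller can be sketched with a hypothetical session that assumes the installed torch is older than 2.5, so the FLEX requirement is unmet ('flex' in the comment stands in for the enum's string value):

from diffusers.models.attention_dispatch import AttentionBackendName, attention_backend

try:
    with attention_backend(AttentionBackendName.FLEX):
        pass
except RuntimeError as err:
    print(err)  # "Flex Attention backend 'flex' is not usable ... Please install `torch>=2.5.0`."
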
308 | 371 | @functools.lru_cache(maxsize=128)
309 | 372 | def _prepare_for_flash_attn_or_sage_varlen_without_mask(
310 | 373 |     batch_size: int,