
Commit ccbe465

Cyrilvallez authored and zucchini-nlp committed
Allow easy registration of custom attention functions (huggingface#36889)
* Update modeling_utils.py
* style
* Update modeling_utils.py
* Update modeling_utils.py
* Update modeling_utils.py
* Update modeling_utils.py
* Update modeling_utils.py
* Update modeling_utils.py
* add to init
* Update modeling_utils.py
* style
* update
* Update modeling_utils.py
* Update modeling_utils.py
* style
* Add some doc
* Update _toctree.yml
* readd it for tgi/vllm compat
* CIs
* CIs
1 parent 40c7808 commit ccbe465

File tree (6 files changed: +171 -11 lines)

docs/source/en/_toctree.yml
docs/source/en/attention_interface.md
docs/source/en/internal/modeling_utils.md
src/transformers/__init__.py
src/transformers/modeling_utils.py
src/transformers/utils/dummy_pt_objects.py


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions

@@ -29,6 +29,8 @@
       title: The Transformer model family
     - local: attention
       title: Attention mechanisms
+    - local: attention_interface
+      title: Customizing attention function
     title: Models
   - sections:
     - local: fast_tokenizers

docs/source/en/attention_interface.md

Lines changed: 106 additions & 0 deletions

@@ -0,0 +1,106 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Attention Interface

This page describes how to use the `AttentionInterface` to register custom attention functions to use with
supported models.

## Customizing attention function

Most recent models can now switch from one attention function used in the Attention layer to another, thanks to a simple mapping.
By default, we provide the implementation for [`sdpa`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
[`flash_attention_2`](https://github.com/Dao-AILab/flash-attention) and [`flex_attention`](https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention),
as well as `eager`, which is simple matrix multiplication without any optimization on top.
This is the setting you can usually choose when instantiating a model:

```python
from transformers import AutoModelForCausalLM

model_id = "meta-llama/Llama-3.2-1B"

# Here, using flash attention as an example
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")
```
But what if you wanted to create your own attention function? Or simply play around with existing ones, adding
a few statements here and there? You can now do so with the `AttentionInterface`! Here is an example:

```python
from transformers import AutoModelForCausalLM, AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward
import torch

model_id = "meta-llama/Llama-3.2-1B"

def my_new_sdpa(*args, **kwargs):
    print("I just entered the attention computation")
    return sdpa_attention_forward(*args, **kwargs)

AttentionInterface.register("my_new_sdpa", my_new_sdpa)

model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="my_new_sdpa")
# Try running the forward with the new attention function
model(torch.ones(1, 5, dtype=int))
```

You will see it print "I just entered the attention computation" as many times as there are layers in the model (16 times with this example).
## Dynamically switching attention function

You can also dynamically change the model's attention function by overriding the `config._attn_implementation` field:

```python
# Back to using the original sdpa implementation
model.config._attn_implementation = "sdpa"

model(torch.ones(1, 5, dtype=int))
```

and it will stop printing the statements, as it now uses the `sdpa` attention.
This lets you quickly switch attention functions without needing to reload the model!
## What about new args needed in my custom function?

But indeed, what if the new function requires a new arg to be properly used? It's no issue! Models supporting the
`AttentionInterface` propagate kwargs all the way to the Attention layers, and to the attention function used. That way,
you can simply pass the arg as a kwarg (i.e. you need to qualify the name of the arg) in the model's forward, and it will be correctly used in the attention. However, custom attention functions have some limitations. In particular, they must follow the signature and return format of the other attention functions, i.e.

```python
from typing import Optional, Tuple

import torch

from transformers import AutoModelForCausalLM, AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward

def custom_attention(
    module: torch.nn.Module,  # required arg
    query: torch.Tensor,  # required arg
    key: torch.Tensor,  # required arg
    value: torch.Tensor,  # required arg
    attention_mask: Optional[torch.Tensor],  # required arg
    a_new_kwargs=None,  # You can now add as many kwargs as you need
    another_new_kwargs=None,  # You can now add as many kwargs as you need
    **kwargs,  # You need to accept **kwargs as models will pass other args
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    ...  # do your magic!
    return attn_output, attn_weights  # attn_weights are optional here

AttentionInterface.register("custom", custom_attention)

model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="custom")
# Forward pass with the new kwargs
model(torch.ones(1, 5, dtype=int), a_new_kwargs=..., another_new_kwargs=...)
```

If in doubt about what args/kwargs a given model sends to the attention function, simply check that model's modeling code on [GitHub](https://github.com/huggingface/transformers/tree/main/src/transformers/models)!
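A quick empirical check is also possible with the interface itself. The sketch below (the `probe_attention` name and the checkpoint choice are just illustrative) registers a pass-through function that prints whichever kwargs the model forwards before deferring to the stock `sdpa` kernel:

```python
import torch

from transformers import AttentionInterface, AutoModelForCausalLM
from transformers.integrations.sdpa_attention import sdpa_attention_forward

# Hypothetical probe: print the kwargs this model forwards to its attention function,
# then fall back to the regular sdpa implementation so the forward pass still runs.
def probe_attention(module, query, key, value, attention_mask, **kwargs):
    print(sorted(kwargs.keys()))
    return sdpa_attention_forward(module, query, key, value, attention_mask, **kwargs)

AttentionInterface.register("probe", probe_attention)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="probe")
model(torch.ones(1, 5, dtype=int))  # prints the forwarded kwarg names once per attention layer
```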

docs/source/en/internal/modeling_utils.md

Lines changed: 5 additions & 1 deletion

@@ -16,10 +16,14 @@ rendered properly in your Markdown viewer.
 
 # Custom Layers and Utilities
 
-This page lists all the custom layers used by the library, as well as the utility functions it provides for modeling.
+This page lists all the custom layers used by the library, as well as the utility functions and classes it provides for modeling.
 
 Most of those are only useful if you are studying the code of the models in the library.
 
+## Attention Functions
+
+[[autodoc]] AttentionInterface
+    - register
 
 ## Pytorch custom modules

src/transformers/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1482,7 +1482,7 @@
     _import_structure["modeling_flash_attention_utils"] = []
     _import_structure["modeling_outputs"] = []
     _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"]
-    _import_structure["modeling_utils"] = ["PreTrainedModel"]
+    _import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
 
     # PyTorch models structure

@@ -6727,7 +6727,7 @@
         model_addition_debugger_context,
     )
     from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
-    from .modeling_utils import PreTrainedModel
+    from .modeling_utils import AttentionInterface, PreTrainedModel
     from .models.albert import (
         AlbertForMaskedLM,
         AlbertForMultipleChoice,

src/transformers/modeling_utils.py

Lines changed: 49 additions & 8 deletions

@@ -28,6 +28,7 @@
 import tempfile
 import warnings
 from collections import defaultdict
+from collections.abc import MutableMapping
 from contextlib import contextmanager
 from dataclasses import dataclass
 from enum import Enum

@@ -2081,9 +2082,10 @@ def _autoset_attn_implementation(
                 ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
             )
 
-        if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [
-            "eager"
-        ] + list(ALL_ATTENTION_FUNCTIONS.keys()):
+        if (
+            not isinstance(config._attn_implementation, dict)
+            and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
+        ):
             message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
             if cls._supports_flash_attn_2:
                 message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'

@@ -2148,7 +2150,7 @@ def _autoset_attn_implementation(
                 "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
             )
             torch.backends.cuda.enable_flash_sdp(False)
-        elif requested_attn_implementation in list(ALL_ATTENTION_FUNCTIONS.keys()):
+        elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
             config._attn_implementation = requested_attn_implementation
         elif isinstance(requested_attn_implementation, dict):
             config._attn_implementation = None
@@ -5891,12 +5893,51 @@ def get_disk_only_shard_files(device_map, weight_map):
     return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
 
 
-ALL_ATTENTION_FUNCTIONS: Dict[str, Callable] = {}
+class AttentionInterface(MutableMapping):
+    """
+    Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
+    with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
+    it needs to declare a new instance of this class inside its `modeling.py` and register the override on that instance.
+    """
 
-ALL_ATTENTION_FUNCTIONS.update(
-    {
+    # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
+    # a new instance is created (in order to locally override a given function)
+    _global_mapping = {
         "flash_attention_2": flash_attention_forward,
         "flex_attention": flex_attention_forward,
         "sdpa": sdpa_attention_forward,
     }
-)
+
+    def __init__(self):
+        self._local_mapping = {}
+
+    def __getitem__(self, key):
+        # First check if instance has a local override
+        if key in self._local_mapping:
+            return self._local_mapping[key]
+        return self._global_mapping[key]
+
+    def __setitem__(self, key, value):
+        # Allow local update of the default functions without impacting other instances
+        self._local_mapping.update({key: value})
+
+    def __delitem__(self, key):
+        del self._local_mapping[key]
+
+    def __iter__(self):
+        # Ensure we use all keys, with the overwritten ones on top
+        return iter({**self._global_mapping, **self._local_mapping})
+
+    def __len__(self):
+        return len(self._global_mapping.keys() | self._local_mapping.keys())
+
+    @classmethod
+    def register(cls, key: str, value: Callable):
+        cls._global_mapping.update({key: value})
+
+    def valid_keys(self) -> List[str]:
+        return list(self._global_mapping.keys() | self._local_mapping.keys())
+
+
+# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
+ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()
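The docstring above describes a local-override path; the following is a minimal sketch of how it would be used, assuming the class as committed (the `my_model_specific_sdpa` function is hypothetical). Setting a key on a fresh instance only touches its `_local_mapping`, while `register` changes the shared `_global_mapping` behind `ALL_ATTENTION_FUNCTIONS`.

```python
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface

# Hypothetical model-specific variant of sdpa, meant to be used by a single modeling file.
def my_model_specific_sdpa(*args, **kwargs):
    # ... model-specific tweaks would go here ...
    return ALL_ATTENTION_FUNCTIONS["sdpa"](*args, **kwargs)

local_interface = AttentionInterface()
local_interface["sdpa"] = my_model_specific_sdpa  # stored in the instance's _local_mapping only

assert local_interface["sdpa"] is my_model_specific_sdpa
assert ALL_ATTENTION_FUNCTIONS["sdpa"] is not my_model_specific_sdpa  # global mapping is untouched
```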

src/transformers/utils/dummy_pt_objects.py

Lines changed: 7 additions & 0 deletions

@@ -549,6 +549,13 @@ def model_addition_debugger_context(*args, **kwargs):
 ROPE_INIT_FUNCTIONS = None
 
 
+class AttentionInterface(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class PreTrainedModel(metaclass=DummyObject):
     _backends = ["torch"]
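For context, this dummy keeps `AttentionInterface` importable from `transformers` even when PyTorch is absent; a small sketch of the resulting behavior, assuming an environment without torch installed:

```python
# Without torch, the import still succeeds, but instantiation raises an ImportError
# (via requires_backends) that points at the missing "torch" backend.
from transformers import AttentionInterface

try:
    AttentionInterface()
except ImportError as err:
    print(err)
```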
