<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Attention Interface

This page describes how to use the `AttentionInterface` to register custom attention functions to use with
supported models.

## Customizing attention function

Most recent models can now switch from one attention function used in the Attention layer to another, thanks to a simple mapping.
By default, we provide implementations for [`sdpa`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
[`flash_attention_2`](https://github.com/Dao-AILab/flash-attention) and [`flex_attention`](https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention),
as well as `eager`, which is a simple matrix multiplication without any optimization on top.
This is the setting you can usually choose when instantiating a model:

```python
from transformers import AutoModelForCausalLM

model_id = "meta-llama/Llama-3.2-1B"

# Here, using flash attention as an example
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")
```
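
You can check which implementation a loaded model ended up with by reading the same config field that the dynamic-switching section below relies on:

```python
# The chosen implementation is recorded on the model config
print(model.config._attn_implementation)  # "flash_attention_2"
```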

But what if you wanted to create your own attention function? Or simply play around with existing ones, adding
a few statements here and there? You can now do so with the `AttentionInterface`! Here is an example:

```python
from transformers import AutoModelForCausalLM, AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward
import torch

model_id = "meta-llama/Llama-3.2-1B"

def my_new_sdpa(*args, **kwargs):
    print("I just entered the attention computation")
    return sdpa_attention_forward(*args, **kwargs)

AttentionInterface.register("my_new_sdpa", my_new_sdpa)

model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="my_new_sdpa")
# Try running the forward with the new attention function
model(torch.ones(1, 5, dtype=int))
```

You will see it prints "I just entered the attention computation" as many times as there are layers in the model (with this example, 16 times).
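
If you would rather not count the print statements, the number of layers is also available on the config (a quick check, assuming the model exposes `num_hidden_layers`, as Llama does):

```python
# One attention call per decoder layer
print(model.config.num_hidden_layers)  # 16 for Llama-3.2-1B
```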

## Dynamically switching attention function

You could dynamically change the model's attention function as well, by overriding the `config._attn_implementation` field:

```python
# Back to using the original sdpa implementation
model.config._attn_implementation = "sdpa"

model(torch.ones(1, 5, dtype=int))
```

and it will stop printing the statements, as it now uses the `sdpa` attention.
This allows you to quickly change the attention function without needing to reload the model!
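
The same field lets you switch back to the custom function registered earlier, again without reloading anything (this assumes the `my_new_sdpa` registration from the previous section is still in place):

```python
# Switch back to the custom function registered above
model.config._attn_implementation = "my_new_sdpa"

model(torch.ones(1, 5, dtype=int))  # prints again, once per layer
```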

## What about new args needed in my custom function?

But what if the new function requires a new arg to be properly used? It's no issue! Models supporting the
`AttentionInterface` propagate kwargs all the way to the Attention layers, and to the attention function used. That way,
you can simply pass the arg as a kwarg (i.e., you need to qualify its name) in the model's forward, and it will be correctly used in the attention. However, custom attention functions have some limitations. In particular, they must follow the signature and return format of other attention functions, i.e.

```python
from typing import Optional, Tuple

from transformers import AutoModelForCausalLM, AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward
import torch

model_id = "meta-llama/Llama-3.2-1B"

def custom_attention(
    module: torch.nn.Module,  # required arg
    query: torch.Tensor,  # required arg
    key: torch.Tensor,  # required arg
    value: torch.Tensor,  # required arg
    attention_mask: Optional[torch.Tensor],  # required arg
    a_new_kwargs=None,  # You can now add as many kwargs as you need
    another_new_kwargs=None,  # You can now add as many kwargs as you need
    **kwargs,  # You need to accept **kwargs as models will pass other args
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    ...  # do your magic!
    return attn_output, attn_weights  # attn_weights are optional here

AttentionInterface.register("custom", custom_attention)

model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="custom")
# Forward pass with the new kwargs
model(torch.ones(1, 5, dtype=int), a_new_kwargs=..., another_new_kwargs=...)
```

If in doubt about what args/kwargs a given model sends to the attention function, simply check that model's modeling code on [GitHub](https://github.com/huggingface/transformers/tree/main/src/transformers/models)!
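
To see at runtime which kwargs actually reach the attention function, you can also register a small passthrough wrapper that prints what it receives before delegating to `sdpa_attention_forward`. This is just a debugging sketch; the `debug_attention` name and the `"debug"` key are illustrative:

```python
from transformers import AutoModelForCausalLM, AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward
import torch

def debug_attention(module, query, key, value, attention_mask, **kwargs):
    # Print the extra kwargs the model forwards to the attention function
    print(sorted(kwargs.keys()))
    return sdpa_attention_forward(module, query, key, value, attention_mask, **kwargs)

AttentionInterface.register("debug", debug_attention)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="debug")
model(torch.ones(1, 5, dtype=int))
```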