
Commit 41fcab6

Authored by theissenhelen, JesperDramsch, Rilwan-Adewoyin, gmertes, and JPXKQX
feat: make flash attention configurable (ecmwf#60)
* feat: FlashMultiHeadSelfAttention
* Chore/multiple fixes ci precommit (ecmwf#41)
* fix: change pre-cmmit autoupdate schedule to monthly
* fix: change the merge strategy for Changelog to Union
* fix: add .envrc to .gitignore
* ci: ignore pre-commit-config and readthedocs for changelog updates
* ci: fix to correct hpc workflow call
* fix: update precommit config
* chore: update pre-commits
* feat: add codeowners file
* chore: update dependencies
* ci: add hpc-config
* docs: changelog
* fix: respond to review comments
* 11 add configurability to dropout in multiheadselfattention module (ecmwf#12)
* feat: add configurability to dropout in MultiHeadSelfAttention
* test: adjust to dropout_p
* doc: update changelog
* Feature/integrate reusable workflows (ecmwf#16)
* ci: add public pr label
* ci: add readthedocs update check
* ci: add downstream ci
* ci: add ci-config
* chore(deps): remove unused dependency
* docs: update changelog
* ci: switch to main
* chore: changelog 0.2.1
* Update error messages from invalid sub_graph in model instantiation (ecmwf#20)
* ci: inherit pypi publish flow (ecmwf#17)
* docs: add to changelog
* fix: typo in reusable workflow
* fix: another typo
* chore: bump actions/setup-python to v5
* ci: run downstream-ci for changes in src and tests
* docs: update changelog
* Update CHANGELOG.md to KeepChangelog format
* [pre-commit.ci] pre-commit autoupdate (ecmwf#25): black 24.4.2 → 24.8.0, ruff v0.4.6 → v0.6.2, pyproject-fmt 2.1.3 → 2.2.1
* Ci/changelog-release-updater (ecmwf#26)
* ci: add changelof release updater
* chore!: drop support for scaled_dot_product_attention
* feat: add softcap
* test: add softcap xfail for MultiHeadSelfAttention
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* feat: flash attention lazy import
* feat: make alibi slopes configurable
* chore(deps): add flash-attn
* feat: use scaled_dot_product as default
* feat: make alibi_slope cinfigurable in block, chunk processor
* chore(deps): remove flash-attn
* feat: get alibi_slopes
* docs: update docstrings
* fix: bias shape
* fix: softcap optional
* fix: import annotations from future
* fix: annotation error
* docs: update changelog
* fix: type annotation
* feat: catch low flash-attn version
* feat: attention wrapper
* fix: remove duplicate version check
* added flex attn wrapper
* fix: alibi_slopes unassigned
* adding causal wip
* added flex attn module
* Bump min torch version to be able to use Flex Attn
* added input parameter checks
* precommit fix
* fix: typo
* test: adjust tests
* fix: no self.use_alibi_slopes
* fix: use_alibi_slope default to false
* feat: Add sliding window support for TorchAttention via mask
* fix: set default flash_attention
* fix: pytest
* fix: tests
* docs: improve docstrings in MultiHeadSelfAttention
* fix: error instead of SystemExit
* chore: refactor SDPAAttention update_mask method
* feat: add missing pytest.ini
* chore: remove explicit float typing
* support running without window size
* test: separate test for sdpa and flex attention
* added asserts and tests for flex attn
* fix: embed_dim / num_heads >=16
* test: fix tests to account for embed_dim constraints
* fix tests
* chore: remove debugging code
* consitency change
* chore(configs): add attention_implementation
* Update models/src/anemoi/models/layers/attention.py
* fix: address comments
* chore: remove flex_attention
* test: fix merge
* fix test to address breaking change from torch 2.6
* remove flex_attention references

Co-authored-by: Jesper Dramsch <jesper.dramsch@ecmwf.int>
Co-authored-by: Helen Theissen <helen.theissen@ecmwf.int>
Co-authored-by: Rilwan (Akanni) Adewoyin <18564167+Rilwan-Adewoyin@users.noreply.github.com>
Co-authored-by: Gert Mertes <gert.mertes@ecmwf.int>
Co-authored-by: Mario Santa Cruz <48736305+JPXKQX@users.noreply.github.com>
Co-authored-by: Harrison Cook <Harrison.cook@ecmwf.int>
Co-authored-by: Cathal OBrien <cathal.obrien@ecmwf.int>
Co-authored-by: japols <jan.polster@ecmwf.int>
Co-authored-by: anaprietonem <ana.prietonemesio@ecmwf.int>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
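The implementation switch is driven from the model configuration: the assertion added in set_attention_function (see the attention.py diff below) points users at model.processor.attention_implementation. Below is a minimal sketch of that selection, assuming an OmegaConf/Hydra-style config as used elsewhere in anemoi; the surrounding keys and the placement of softcap and use_alibi_slopes are assumptions, not part of this commit.

# Illustrative only: choose the attention backend from config.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "model": {
            "processor": {
                "attention_implementation": "scaled_dot_product_attention",  # or "flash_attention"
                "softcap": None,            # only honoured by the flash-attention path
                "use_alibi_slopes": False,  # only honoured by the flash-attention path
            }
        }
    }
)
print(OmegaConf.to_yaml(cfg))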
1 parent 54e42cf commit 41fcab6

File tree

13 files changed (+326, -42 lines)

graphs/tests/test_create.py

Lines changed: 1 addition & 1 deletion
@@ -51,6 +51,6 @@ def test_generate_graph(self, config_file: tuple[Path, str], mock_grids_path: tu

         if graph_path is not None:
             assert graph_path.exists()
-            graph_saved = torch.load(graph_path)
+            graph_saved = torch.load(graph_path, weights_only=False)
             assert graph.node_types == graph_saved.node_types
             assert graph.edge_types == graph_saved.edge_types
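The weights_only argument is needed because torch 2.6 flipped the default of torch.load to weights_only=True, which refuses to unpickle arbitrary objects such as the saved graph here (the commit history notes "fix test to address breaking change from torch 2.6"). A minimal sketch of the behaviour; the path is a placeholder:

import torch

# torch >= 2.6: torch.load(path) defaults to weights_only=True and rejects
# pickled non-tensor objects. weights_only=False restores the old behaviour
# and should only be used for trusted files.
graph_saved = torch.load("graph.pt", weights_only=False)  # "graph.pt" is a placeholder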

models/CHANGELOG.md

Lines changed: 9 additions & 0 deletions
@@ -54,15 +54,24 @@ Keep it human-readable, your future self will thank you!

 ### Added

+- CI workflow to update the changelog on release
+- add configurability of flash attention (#47)
+- configurabilty of the dropout probability in the the MultiHeadSelfAttention module
 - CI workflow to update the changelog on release
 - Remapper: Preprocessor for remapping one variable to multiple ones. Includes changes to the data indices since the remapper changes the number of variables. With optional config keywords.
+- Codeowners file
+- Pygrep precommit hooks
+- Docsig precommit hooks
+- Changelog merge strategy
+

 ### Changed

 - Update CI to inherit from common infrastructue reusable workflows
 - run downstream-ci only when src and tests folders have changed
 - New error messages for wrongs graphs.
 - Feature: Change model to be instantiatable in the interface, addressing [#28](https://github.com/ecmwf/anemoi-models/issues/28) through [#45](https://github.com/ecmwf/anemoi-models/pulls/45)
+- Bugfixes for CI

 ### Removed

models/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ dependencies = [
   "anemoi-utils>=0.1.9",
   "einops>=0.6.1",
   "hydra-core>=1.3",
-  "torch>=2.2",
+  "torch>=2.5",
   "torch-geometric>=2.3,<2.5",
 ]
 optional-dependencies.all = [ ]
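The commit raises the floor from torch 2.2 to 2.5 (the history mentions the Flex Attention experiments and the torch.nn.attention.sdpa_kernel context manager used in attention.py). A runtime guard in the same style as the flash-attn version check, shown only as an illustrative sketch:

import torch
from packaging import version

# Mirror the version check used in FlashAttentionWrapper, but for torch itself.
if version.parse(torch.__version__) < version.parse("2.5"):
    raise RuntimeError("anemoi-models now requires torch>=2.5")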

models/pytest.ini

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+[pytest]
+markers =
+    data_dependent: marks tests depending on data (deselect with '-m "not data_dependent"')
+    auth: marks tests that require authentication (deselect with '-m "not auth"')
+    gpu: marks tests that require a GPU (deselect with '-m "not gpu"')
+
+tmp_path_retention_policy = none
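The marker hints in the ini file translate directly into deselection expressions. A sketch of driving that programmatically; the test directory is a placeholder:

import pytest

# Run the suite without GPU- or data-dependent tests, e.g. on a CPU-only runner.
# "models/tests" is a placeholder for the actual test directory.
raise SystemExit(pytest.main(["-m", "not gpu and not data_dependent", "models/tests"]))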

models/src/anemoi/models/layers/attention.py

Lines changed: 211 additions & 26 deletions
@@ -8,23 +8,19 @@
 # nor does it submit to any jurisdiction.


+from __future__ import annotations
+
 import logging
+import math
 from typing import Optional

 import einops
+import torch
+from packaging import version
 from torch import Tensor
 from torch import nn
 from torch.distributed.distributed_c10d import ProcessGroup

-try:
-    from flash_attn import flash_attn_func as attn_func
-except ImportError:
-    from torch.nn.functional import scaled_dot_product_attention as attn_func
-
-    _FLASH_ATTENTION_AVAILABLE = False
-else:
-    _FLASH_ATTENTION_AVAILABLE = True
-
 from anemoi.models.distributed.transformer import shard_heads
 from anemoi.models.distributed.transformer import shard_sequence
 from anemoi.utils.config import DotDict
@@ -33,7 +29,12 @@


 class MultiHeadSelfAttention(nn.Module):
-    """Multi Head Self Attention Pytorch Layer."""
+    """Multi Head Self Attention Pytorch Layer
+
+    allows for three different attention implementations:
+    - scaled dot product attention, see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+    - flash attention, see https://github.com/Dao-AILab/flash-attention
+    """

     def __init__(
         self,
@@ -44,32 +45,89 @@ def __init__(
         is_causal: bool = False,
         window_size: Optional[int] = None,
         dropout_p: float = 0.0,
+        attention_implementation: str = "flash_attention",
+        softcap: Optional[float] = None,
+        use_alibi_slopes: bool = False,
     ):
+        """Initialize MultiHeadSelfAttention.
+
+        For the flash attention implementation, two additional parameters are available: softcap, use_alibi_slopes
+
+        softcap: Softcapping prevents the logits from growing excessively large
+
+        use_alibi_slopes: Adds bias of `(-alibi_slope * |i + seqlen_k - seqlen_q - j|)` to the attention score of
+            query i and key j, where alibi_slope is calculated using get_alibi_slopes
+
+        Parameters
+        ----------
+        num_heads : int
+            number of heads
+        embed_dim : int
+            embedding dimension
+        bias : bool, optional
+            bias, by default False
+        is_causal : bool, optional
+            apply causal attention mask, by default False
+        window_size : Optional[int], optional
+            window_size, by default None
+        dropout_p : float, optional
+            dropout probability, by default 0.0
+        attention_implementation: str, optional
+            A predefined string which selects which underlying attention
+            implementation, by default "flash_attention"
+        softcap : float, optional
+            Anything > 0 activates softcapping attention, by default None
+        use_alibi_slopes : bool, optional
+            Adds bias
+        """
         super().__init__()

         assert (
             embed_dim % num_heads == 0
         ), f"Embedding dimension ({embed_dim}) must be divisible by number of heads ({num_heads})"

+        self.attention_implementation = attention_implementation
+        self.use_alibi_slopes = use_alibi_slopes
+
         self.num_heads = num_heads
         self.embed_dim = embed_dim
         self.head_dim = embed_dim // num_heads  # q k v
-        self.window_size = (window_size, window_size)  # flash attention
+        self.window_size = window_size
         self.dropout_p = dropout_p
         self.is_causal = is_causal
+        self.softcap = softcap
+
+        self.set_attention_function()
+
+        if self.use_alibi_slopes:
+            self.alibi_slopes = get_alibi_slopes(num_heads)
+            assert self.alibi_slopes.shape[0] == num_heads, "Error: Number of alibi_slopes must match number of heads"
+        else:
+            self.alibi_slopes = None

         linear = layer_kernels["Linear"]
         self.lin_qkv = linear(embed_dim, 3 * embed_dim, bias=bias)
-        self.attention = attn_func
-
-        if not _FLASH_ATTENTION_AVAILABLE:
-            LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention")

         self.projection = linear(embed_dim, embed_dim, bias=True)

+    def set_attention_function(self):
+        attn_funcs = {
+            "flash_attention": FlashAttentionWrapper,
+            "scaled_dot_product_attention": SDPAAttentionWrapper,
+        }
+        assert (
+            self.attention_implementation in attn_funcs
+        ), f"{self.attention_implementation} not supported. \
+            Please change model.processor.attention_implementation to one of: {attn_funcs.keys()}"
+        LOGGER.info(f"Using {self.attention_implementation}")
+
+        # initalise the attn func here
+        self.attention = attn_funcs[self.attention_implementation]()
+
     def forward(
         self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
     ) -> Tensor:
+
         query, key, value = self.lin_qkv(x).chunk(3, -1)

         if model_comm_group:
@@ -92,24 +150,151 @@ def forward(
             value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
             dropout_p = self.dropout_p if self.training else 0.0

-        if _FLASH_ATTENTION_AVAILABLE:
-            query, key, value = (
-                einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
+        out = self.attention(
+            query,
+            key,
+            value,
+            batch_size,
+            causal=False,
+            window_size=self.window_size,
+            dropout_p=dropout_p,
+            softcap=self.softcap,
+            alibi_slopes=self.alibi_slopes,
+        )
+
+        out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
+        out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")
+
+        out = self.projection(out)
+
+        return out
+
+
+class SDPAAttentionWrapper(nn.Module):
+    """Wrapper for Pytorch scaled dot product attention"""
+
+    def __init__(self):
+        super().__init__()
+
+        from torch.nn.functional import scaled_dot_product_attention
+
+        self.attention = scaled_dot_product_attention
+        self.mask = None
+        self.window_size = None
+
+    def update_mask(self, seq_len, window_size: int, device: str):
+
+        self.mask = (
+            torch.abs(
+                torch.arange(seq_len, device=device).unsqueeze(0) - torch.arange(seq_len, device=device).unsqueeze(1)
             )
-            out = self.attention(query, key, value, causal=False, window_size=self.window_size, dropout_p=dropout_p)
-            out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
-        else:
+            <= window_size
+        )
+
+    def forward(
+        self,
+        query,
+        key,
+        value,
+        batch_size: int,
+        causal=False,
+        window_size=None,
+        dropout_p=0.0,
+        softcap=None,
+        alibi_slopes=None,
+    ):
+        if softcap is not None:
+            NotImplementedError(
+                "Softcap not supported by Pytorchs SDPA. please switch to flash attention or disable softcap."
+            )
+        if alibi_slopes is not None:
+            NotImplementedError(
+                "Alibi slopes not supported by Pytorchs SDPA. please switch to flash attention or disable alibi slopes."
+            )
+
+        sequence_len = query.shape[-2]
+
+        if window_size is not None and (self.mask is None or tuple(self.mask.shape) != (sequence_len, sequence_len)):
+            self.update_mask(sequence_len, window_size=window_size, device=query.device)
+
+        with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.MATH]):
             out = self.attention(
                 query,
                 key,
                 value,
-                is_causal=False,
+                attn_mask=self.mask,
+                is_causal=causal,
                 dropout_p=dropout_p,
-            )  # expects (batch heads grid variable) format
+            )

-        out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
-        out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")
+        return out

-        out = self.projection(out)

+class FlashAttentionWrapper(nn.Module):
+    """Wrapper for Flash attention."""
+
+    def __init__(self):
+        super().__init__()
+        try:
+            import flash_attn
+        except ImportError:
+            raise ImportError("Error: Flash-attn not installed. Please install flash-attn to use Flash Attention")
+
+        if version.parse(flash_attn.__version__) < version.parse("2.6.0"):
+            raise RuntimeError("Error: Flash-attn version is too low. Update to 2.6.0 or higher.")
+        else:
+            self.attention = flash_attn.flash_attn_func
+
+    def forward(
+        self,
+        query,
+        key,
+        value,
+        batch_size: int,
+        causal: bool = False,
+        window_size: int = None,
+        dropout_p: float = 0.0,
+        softcap: Optional[float] = None,
+        alibi_slopes: torch.Tensor = None,
+    ):
+        query, key, value = (
+            einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
+        )
+
+        alibi_slopes = alibi_slopes.repeat(batch_size, 1).to(query.device) if alibi_slopes is not None else None
+
+        out = self.attention(
+            query,
+            key,
+            value,
+            causal=False,
+            window_size=(window_size, window_size),
+            dropout_p=dropout_p,
+            softcap=softcap,
+            alibi_slopes=alibi_slopes,
+        )
+        out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
         return out
+
+
+def get_alibi_slopes(num_heads: int) -> Tensor:
+    """Calculates linearly decreasing slopes for alibi attention.
+
+    Parameters
+    ----------
+    num_heads : int
+        number of attention heads
+
+    Returns
+    -------
+    Tensor
+        aLiBi slopes
+    """
+    n = 2 ** math.floor(math.log2(num_heads))
+    slope_0 = 2 ** (-8 / n)
+    alibi_slopes = torch.pow(slope_0, torch.arange(1, 1 + n))
+    if n < num_heads:
+        slope_hat_0 = 2 ** (-4 / n)
+        alibi_slopes_hat = torch.pow(slope_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
+        alibi_slopes = torch.cat([alibi_slopes, alibi_slopes_hat])
+    return alibi_slopes
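A minimal usage sketch of the two module-level additions whose full signatures are visible above, get_alibi_slopes and SDPAAttentionWrapper. The tensor layout follows the "batch heads grid vars" convention used in forward; the concrete sizes are arbitrary (head_dim of at least 16, per the embed_dim constraint mentioned in the commit history):

import torch

from anemoi.models.layers.attention import SDPAAttentionWrapper, get_alibi_slopes

# For 6 heads (not a power of two), the first 4 slopes are 2**-2, 2**-4, 2**-6, 2**-8
# and the remaining 2 are interpolated, giving one slope per head.
slopes = get_alibi_slopes(6)
assert slopes.shape == (6,)

attention = SDPAAttentionWrapper()
batch, heads, grid, head_dim = 1, 4, 32, 16
q, k, v = (torch.randn(batch, heads, grid, head_dim) for _ in range(3))

# Passing window_size triggers update_mask, which builds a banded |i - j| <= window_size mask.
out = attention(q, k, v, batch, window_size=8, dropout_p=0.0)
assert out.shape == (batch, heads, grid, head_dim)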

models/src/anemoi/models/layers/block.py

Lines changed: 6 additions & 0 deletions
@@ -71,6 +71,9 @@ def __init__(
         window_size: int,
         layer_kernels: DotDict,
         dropout_p: float = 0.0,
+        attention_implementation: str = "flash_attention",
+        softcap: float = None,
+        use_alibi_slopes: bool = None,
     ):
         super().__init__()

@@ -91,6 +94,9 @@ def __init__(
             is_causal=False,
             dropout_p=dropout_p,
             layer_kernels=layer_kernels,
+            attention_implementation=attention_implementation,
+            softcap=softcap,
+            use_alibi_slopes=use_alibi_slopes,
         )

         self.mlp = nn.Sequential(
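The block does nothing more than forward the three new keyword arguments to MultiHeadSelfAttention. A construction-only sketch of the layer receiving them; calling with keyword arguments and passing a plain dict in place of the layer_kernels DotDict are assumptions about call style, not something this diff shows:

import torch

from anemoi.models.layers.attention import MultiHeadSelfAttention

attention = MultiHeadSelfAttention(
    num_heads=16,
    embed_dim=1024,
    layer_kernels={"Linear": torch.nn.Linear},  # stand-in for the configured layer kernels
    window_size=512,
    dropout_p=0.0,
    attention_implementation="scaled_dot_product_attention",  # works without flash-attn installed
    softcap=None,
    use_alibi_slopes=False,
)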
