
Commit dee38ed

feat(modules): add xformers eff self-attention
1 parent c418bdf commit dee38ed

3 files changed: 52 additions (+), 19 deletions (−)


cellseg_models_pytorch/modules/self_attention/exact_attention.py

Lines changed: 41 additions & 8 deletions
@@ -1,6 +1,11 @@
 import torch
 import torch.nn as nn
 
+try:
+    from xformers.ops import memory_efficient_attention
+except ModuleNotFoundError:
+    pass
+
 
 class ExactSelfAttention(nn.Module):
     def __init__(
@@ -15,26 +20,37 @@ def __init__(
         Three variants:
         - basic self-attention implementation with torch.matmul O(N^2)
         - slice-attention - Computes the attention matrix in slices to save mem.
-        - flash attention from xformers package:
-            Citation..
+        - `xformers.ops.memory_efficient_attention` from xformers package.
 
         Parameters
         ----------
         head_dim : int
             Out dim per attention head.
         self_attention : str, default="basic"
-            One of ("basic", "flash", "sliced"). Basic is the normal O(N^2)
-            self attention. "flash" is the flash attention (by xformes library),
-            "slice" is self attention implemented with sliced matmul operation
-            to save memory.
+            One of ("basic", "flash", "sliced", "memeff").
+            "basic": the normal O(N^2) self attention.
+            "flash": the flash attention (by xformers library),
+            "slice": batch sliced attention operation to save mem.
+            "memeff" xformers.memory_efficient_attention.
         num_heads : int, optional
             Number of heads. Used only if `slice_attention = True`.
         slice_size, int, optional
             The size of the slice. Used only if `slice_attention = True`.
+
+        Raises
+        ------
+            - ValueError:
+                - If illegal self attention method is given.
+                - If `self_attention` is set to `slice` while `num_heads` | `slice_size`
+                args are not given proper integer values.
+                - If `self_attention` is set to `memeff` but cuda is not available.
+            - ModuleNotFoundError:
+                - If `self_attention` is set to `memeff` and `xformers` package is not
+                installed
         """
         super().__init__()
 
-        allowed = ("basic", "flash", "slice")
+        allowed = ("basic", "flash", "slice", "memeff")
         if self_attention not in allowed:
             raise ValueError(
                 f"Illegal exact self attention type given. Got: {self_attention}. "
@@ -59,6 +75,21 @@ def __init__(
                 f"and `num_heads`: {num_heads}."
             )
 
+        if self_attention == "memeff":
+            try:
+                import xformers  # noqa F401
+            except ModuleNotFoundError:
+                raise ModuleNotFoundError(
+                    "`self_attention` was set to `memeff`. The method requires the "
+                    "xformers package. See how to install xformers: "
+                    "https://github.com/facebookresearch/xformers"
+                )
+            if not torch.cuda.is_available():
+                raise ValueError(
+                    "`self_attention` was set to `memeff`. The method is implemented "
+                    "with `xformers.memory_efficient_attention` that requires cuda."
+                )
+
     def _attention(
         self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
     ) -> torch.Tensor:
@@ -161,7 +192,9 @@ def forward(
         torch.Tensor:
             The self-attention matrix. Same shape as inputs.
         """
-        if self.self_attention == "flash":
+        if self.self_attention == "memeff":
+            attn = memory_efficient_attention(query, key, value)
+        elif self.self_attention == "flash":
             raise NotImplementedError
         elif self.self_attention == "slice":
             attn = self._slice_attention(query, key, value)
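For context, the new "memeff" branch in `ExactSelfAttention.forward` simply hands the query/key/value tensors to `xformers.ops.memory_efficient_attention`. Below is a minimal standalone sketch of that call, assuming xformers is installed and a CUDA device is available; the shapes are illustrative only, since this diff does not show the tensor layout the module actually uses.

import torch
from xformers.ops import memory_efficient_attention

# Illustrative shapes: (batch, seq_len, num_heads, head_dim), one of the
# layouts accepted by memory_efficient_attention.
q = torch.randn(2, 4096, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 4096, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 4096, 8, 64, device="cuda", dtype=torch.float16)

# Same call as the new "memeff" branch above; the output has the same shape
# as the query tensor, without materializing the full N x N attention matrix.
attn = memory_efficient_attention(q, k, v)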

cellseg_models_pytorch/modules/self_attention_modules.py

Lines changed: 10 additions & 10 deletions
@@ -29,11 +29,11 @@ def __init__(
         query_dim : int
             The number of channels in the query. Typically: num_heads*head_dim
         self_attention : str, default="basic"
-            One of ("basic", "flash", "sliced").
-            "basic" is the normal O(N^2) self attention.
-            "flash" is the flash attention (by xformers library),
-            "slice" is self attention implemented with sliced matmul operation
-            on the batch dimension to save memory.
+            One of ("basic", "flash", "slice", "memeff").
+            "basic": the normal O(N^2) self attention.
+            "flash": the flash attention (by xformers library),
+            "slice": batch sliced attention operation to save mem.
+            "memeff" xformers.memory_efficient_attention.
         cross_attention_dim : int, optional
             Number of channels in the context tensor. Cross attention combines
             asymmetrically two separate embeddings (context and input embeddings).
@@ -204,11 +204,11 @@ def __init__(
         Parameters
         ----------
         name : str
-            One of ("basic", "flash", "sliced").
-            "basic" is the normal O(N^2) self attention.
-            "flash" is the flash attention (by xformers library),
-            "slice" is self attention implemented with sliced matmul operation
-            on the batch dimension to save memory.
+            One of ("basic", "flash", "slice", "memeff").
+            "basic": the normal O(N^2) self attention.
+            "flash": the flash attention (by xformers library),
+            "slice": batch sliced attention operation to save mem.
+            "memeff" xformers.memory_efficient_attention.
         query_dim : int
             The number of channels in the query. Typically: num_heads*head_dim
         cross_attention_dim : int, optional
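Since `self_attention="memeff"` now raises `ModuleNotFoundError` when xformers is missing and `ValueError` when CUDA is unavailable (see `exact_attention.py` above), code that picks a value for this argument can guard the choice up front. A small hypothetical helper, not part of this commit, mirroring those checks:

import torch

def pick_self_attention() -> str:
    # "memeff" requires both the xformers package and a CUDA device,
    # matching the checks added to ExactSelfAttention.__init__.
    try:
        import xformers  # noqa: F401
    except ModuleNotFoundError:
        return "basic"
    return "memeff" if torch.cuda.is_available() else "basic"

The returned string can then be passed as the `self_attention` (or `name`) argument documented above.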

cellseg_models_pytorch/modules/transformers.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def __init__(
         block_types : Tuple[str, ...], default=("basic", "basic")
             The name of the SelfAttentionBlocks in the TransformerLayer.
             Length of the tuple has to equal `n_blocks`
-            Allowed names: "basic". "slice", "flash".
+            Allowed names: "basic". "slice", "flash", "memeff".
         dropouts : Tuple[float, ...], default=(False, False)
             Dropout probabilities for the SelfAttention blocks.
         biases : bool, default=(True, True)
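For illustration, the new name slots into `block_types` next to the existing ones. A hypothetical configuration sketch follows; only `block_types`, `dropouts`, `biases` and `n_blocks` appear in this diff, so the import path and any other constructor arguments of `TransformerLayer` are assumptions.

from cellseg_models_pytorch.modules.transformers import TransformerLayer  # assumed path

# Hypothetical: a two-block layer where the first block uses the new
# memory-efficient attention and the second the basic O(N^2) attention.
layer = TransformerLayer(
    n_blocks=2,
    block_types=("memeff", "basic"),
    dropouts=(0.0, 0.0),
    biases=(True, True),
)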
