@@ -1,6 +1,11 @@
import torch
import torch.nn as nn

+try:
+    from xformers.ops import memory_efficient_attention
+except ModuleNotFoundError:
+    pass
+

class ExactSelfAttention(nn.Module):
    def __init__(
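A side note on the guarded import added above: because the failure case is a bare `pass`, the name `memory_efficient_attention` is simply undefined when xformers is missing. A common alternative (a sketch only, not part of this diff) records availability in a flag so later code can check it explicitly:

```python
# Sketch of an optional-dependency import with an availability flag.
# `_HAS_XFORMERS` is a hypothetical name, not used by this repository.
try:
    from xformers.ops import memory_efficient_attention
    _HAS_XFORMERS = True
except ModuleNotFoundError:
    memory_efficient_attention = None
    _HAS_XFORMERS = False
```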
@@ -15,26 +20,37 @@ def __init__(
        Three variants:
        - basic self-attention implementation with torch.matmul O(N^2)
        - slice-attention - Computes the attention matrix in slices to save mem.
-        - flash attention from xformers package:
-            Citation..
+        - `xformers.ops.memory_efficient_attention` from the xformers package.

        Parameters
        ----------
        head_dim : int
            Out dim per attention head.
        self_attention : str, default="basic"
-            One of ("basic", "flash", "sliced"). Basic is the normal O(N^2)
-            self attention. "flash" is the flash attention (by xformes library),
-            "slice" is self attention implemented with sliced matmul operation
-            to save memory.
+            One of ("basic", "flash", "slice", "memeff").
+            "basic": the normal O(N^2) self-attention.
+            "flash": flash attention (from the xformers library).
+            "slice": batch-sliced attention operation to save memory.
+            "memeff": `xformers.ops.memory_efficient_attention`.
        num_heads : int, optional
            Number of heads. Used only if `self_attention = "slice"`.
        slice_size : int, optional
            The size of the slice. Used only if `self_attention = "slice"`.
+
+        Raises
+        ------
+        - ValueError:
+            - If an illegal self-attention method is given.
+            - If `self_attention` is set to `slice` but `num_heads` or `slice_size`
+              are not given proper integer values.
+            - If `self_attention` is set to `memeff` but CUDA is not available.
+        - ModuleNotFoundError:
+            - If `self_attention` is set to `memeff` and the `xformers` package
+              is not installed.
        """
        super().__init__()

-        allowed = ("basic", "flash", "slice")
+        allowed = ("basic", "flash", "slice", "memeff")
        if self_attention not in allowed:
            raise ValueError(
                f"Illegal exact self attention type given. Got: {self_attention}. "
@@ -59,6 +75,21 @@ def __init__(
                f"and `num_heads`: {num_heads}."
            )

+        if self_attention == "memeff":
+            try:
+                import xformers  # noqa: F401
+            except ModuleNotFoundError:
+                raise ModuleNotFoundError(
+                    "`self_attention` was set to `memeff`. The method requires the "
+                    "xformers package. See how to install xformers: "
+                    "https://github.com/facebookresearch/xformers"
+                )
+            if not torch.cuda.is_available():
+                raise ValueError(
+                    "`self_attention` was set to `memeff`. The method is implemented "
+                    "with `xformers.memory_efficient_attention`, which requires CUDA."
+                )
+
    def _attention(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
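As context for the "basic" O(N^2) variant named in the docstring, here is a minimal sketch of plain scaled dot-product attention built from `torch.matmul`. It illustrates the general formula only; the repository's actual `_attention` body is not shown in this diff:

```python
import torch
import torch.nn.functional as F

def naive_self_attention(query, key, value):
    # query/key/value: (B, N, head_dim); the (N, N) score matrix
    # is what makes this O(N^2) in time and memory.
    scale = query.shape[-1] ** -0.5
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    return torch.matmul(F.softmax(scores, dim=-1), value)
```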
@@ -161,7 +192,9 @@ def forward(
        torch.Tensor:
            The self-attention matrix. Same shape as inputs.
        """
-        if self.self_attention == "flash":
+        if self.self_attention == "memeff":
+            attn = memory_efficient_attention(query, key, value)
+        elif self.self_attention == "flash":
            raise NotImplementedError
        elif self.self_attention == "slice":
            attn = self._slice_attention(query, key, value)
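Finally, a minimal sketch of calling `xformers.ops.memory_efficient_attention` directly, as the new `memeff` branch does. Shapes, device, and dtype here are assumptions for illustration: xformers accepts batched (B, M, H, K) or (B, M, K) query/key/value tensors, and the fused kernels are generally run on a CUDA device, with half precision the most widely supported dtype:

```python
import torch
from xformers.ops import memory_efficient_attention

B, M, H, K = 2, 1024, 8, 64  # batch, sequence length, heads, head dim (illustrative)
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)

out = memory_efficient_attention(q, k, v)  # same shape as q: (B, M, H, K)
```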