import torch
- import torch.nn as nn

- try:
-     from xformers.ops import memory_efficient_attention
- except ModuleNotFoundError:
-     pass
+ from .attention_ops import compute_mha
+ from .base_attention import BaseSelfAttention

+ __all__ = ["ExactSelfAttention"]

- class ExactSelfAttention(nn.Module):
+
+ class ExactSelfAttention(BaseSelfAttention):
    def __init__(
        self,
        head_dim: int,
-         self_attention: str = "basic",
-         num_heads: int = None,
+         num_heads: int,
+         how: str = "basic",
        slice_size: int = None,
+         **kwargs,
    ) -> None:
        """Compute exact attention.

-         Three variants:
-         - basic self-attention implementation with torch.matmul O(N^2)
-         - slice-attention - Computes the attention matrix in slices to save mem.
-         - `xformers.ops.memory_efficient_attention` from xformers package.
+         Four variants:
+         - basic: self-attention implementation with torch.matmul O(N^2)
+         - slice-attention: Computes the attention matrix in slices to save mem.
+         - memeff: `xformers.ops.memory_efficient_attention` from xformers package.
+         - slice-memeff-attention: Combines slice-attention and memeff

        Parameters
        ----------
            head_dim : int
                Out dim per attention head.
-             self_attention : str, default="basic"
-                 One of ("basic", "flash", "sliced", "memeff").
+             num_heads : int
+                 Number of heads.
+             how : str, default="basic"
+                 How to compute the self-attention matrix.
+                 One of ("basic", "flash", "slice", "memeff", "slice-memeff").
                "basic": the normal O(N^2) self attention.
                "flash": the flash attention (by xformers library),
                "slice": batch sliced attention operation to save mem.
-                 "memeff" xformers.memory_efficient_attention.
-             num_heads : int, optional
-                 Number of heads. Used only if `slice_attention = True`.
+                 "memeff": xformers.memory_efficient_attention.
+                 "slice-memeff": Combine slicing and memory_efficient_attention.
            slice_size : int, optional
-                 The size of the slice. Used only if `slice_attention = True`.
+                 The size of the slice. Used only if `how in ('slice', 'slice-memeff')`.

        Raises
        ------
            ValueError:
-                 - If illegal self attention method is given.
-                 - If `self_attention` is set to `slice` while `num_heads` | `slice_size`
+                 - If illegal self attention (`how`) method is given.
+                 - If `how` is set to `slice` while `num_heads` | `slice_size`
                args are not given proper integer values.
-                 - If `self_attention` is set to `memeff` but cuda is not available.
+                 - If `how` is set to `memeff` or `slice-memeff` but cuda is not
+                 available.
            ModuleNotFoundError:
                - If `self_attention` is set to `memeff` and `xformers` package is not
                installed
        """
-         super().__init__()
-
-         allowed = ("basic", "flash", "slice", "memeff")
-         if self_attention not in allowed:
-             raise ValueError(
-                 f"Illegal exact self attention type given. Got: {self_attention}. "
-                 f"Allowed: {allowed}."
-             )
-
-         self.self_attention = self_attention
-         self.head_dim = head_dim
-         self.num_heads = num_heads
-         self.scale = head_dim**-0.5
-
-         # These are only used for slice attention.
-         if self_attention == "slice":
-             # for slice_size > 0 the attention score computation
-             # is split across the batch axis to save memory
-             self.slice_size = slice_size
-             self.proj_channels = self.head_dim * self.num_heads
-             if any(s is None for s in (slice_size, num_heads)):
-                 raise ValueError(
-                     "If `slice_attention` is set to True, `slice_size`, `num_heads`, "
-                     f"need to be given integer values. Got: `slice_size`: {slice_size} "
-                     f"and `num_heads`: {num_heads}."
-                 )
-
-         if self_attention == "memeff":
-             try:
-                 import xformers  # noqa F401
-             except ModuleNotFoundError:
-                 raise ModuleNotFoundError(
-                     "`self_attention` was set to `memeff`. The method requires the "
-                     "xformers package. See how to install xformers: "
-                     "https://github.com/facebookresearch/xformers"
-                 )
-             if not torch.cuda.is_available():
-                 raise ValueError(
-                     "`self_attention` was set to `memeff`. The method is implemented "
-                     "with `xformers.memory_efficient_attention` that requires cuda."
-                 )
-
-     def _attention(
-         self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-     ) -> torch.Tensor:
-         """Compute exact self attention with torch. Complexity: O(N**2).
-
-         Parameters
-         ----------
-             query : torch.Tensor
-                 Query tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
-             key : torch.Tensor
-                 Key tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
-             value : torch.Tensor
-                 Value tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
-
-         Returns
-         -------
-             torch.Tensor:
-                 The self-attention matrix. Same shape as inputs.
-         """
-         scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
-         probs = scores.softmax(dim=-1)
-
-         # compute attention output
-         return torch.matmul(probs, value)
-
-     def _slice_attention(
-         self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-     ) -> torch.Tensor:
-         """Compute exact attention in slices to save memory.
-
-         NOTE: adapted from the huggingface diffusers package. Their implementation
-         just doesn't handle the case where B // slice_size doesn't divide evenly.
-         This would end up in zero-matrices at the final batch dimensions of
-         the batched attention matrix.
-
-         NOTE: The input is sliced in the batch dimension.
-
-         Parameters
-         ----------
-             query : torch.Tensor
-                 Query tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
-             key : torch.Tensor
-                 Key tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
-             value : torch.Tensor
-                 Value tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
-
-         Returns
-         -------
-             torch.Tensor:
-                 The self-attention matrix. Same shape as inputs.
-         """
-         B, seq_len = query.shape[:2]
-         out = torch.zeros(
-             (B, seq_len, self.proj_channels // self.num_heads),
-             device=query.device,
-             dtype=query.dtype,
+         super().__init__(
+             head_dim=head_dim,
+             num_heads=num_heads,
+             how=how,
+             slice_size=slice_size,
        )

-         # get the modulo if B/slice_size is not evenly divisible.
-         n_slices, mod = divmod(out.shape[0], self.slice_size)
-         if mod != 0:
-             n_slices += 1
-
-         it = list(range(n_slices))
-         for i in it:
-             start = i * self.slice_size
-             end = (i + 1) * self.slice_size
-
-             if i == it[-1]:
-                 end = start + mod
-
-             attn_slice = torch.matmul(query[start:end], key[start:end].transpose(1, 2))
-             attn_slice *= self.scale
-
-             attn_slice = attn_slice.softmax(dim=-1)
-             attn_slice = torch.matmul(attn_slice, value[start:end])
-
-             out[start:end] = attn_slice
-
-         return out
-
    def forward(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
@@ -192,20 +80,15 @@ def forward(
            torch.Tensor:
                The self-attention matrix. Same shape as inputs.
        """
-         if self.self_attention == "memeff":
-             if all([query.is_cuda, key.is_cuda, value.is_cuda]):
-                 attn = memory_efficient_attention(query, key, value)
-             else:
-                 raise RuntimeError(
-                     "`xformers.ops.memory_efficient_attention` is only implemented "
-                     "for cuda. Make sure your inputs & model devices are set to cuda."
-                 )
-         elif self.self_attention == "flash":
-             raise NotImplementedError
-         elif self.self_attention == "slice":
-             attn = self._slice_attention(query, key, value)
-         else:
-             attn = self._attention(query, key, value)
+         attn = compute_mha(
+             query,
+             key,
+             value,
+             self.how,
+             slice_size=self.slice_size,  # used only for slice-att
+             num_heads=self.num_heads,  # used only for slice-att
+             proj_channels=self.proj_channels,  # used only for slice-att
+         )

        return attn
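For reference, a minimal usage sketch of the refactored module. This is a sketch only: the import path is hypothetical, and it assumes `BaseSelfAttention.__init__` stores `how`, `slice_size`, `num_heads`, and `proj_channels` the way the new `forward` expects; tensor shapes follow the docstrings above.

import torch

# Hypothetical import path; in the package the class sits next to
# `attention_ops.py` and `base_attention.py`.
from exact_attention import ExactSelfAttention

B, num_heads, head_dim, seq_len = 2, 4, 64, 1024

# "basic" O(N^2) attention; how="slice" would additionally require `slice_size`.
self_attn = ExactSelfAttention(head_dim=head_dim, num_heads=num_heads, how="basic")

# Shapes follow the docstrings: (B*num_heads, H*W, proj_dim//num_heads).
query = torch.randn(B * num_heads, seq_len, head_dim)
key = torch.randn(B * num_heads, seq_len, head_dim)
value = torch.randn(B * num_heads, seq_len, head_dim)

out = self_attn(query, key, value)  # same shape as the inputs
print(out.shape)  # torch.Size([8, 1024, 64])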