
Commit f384717

refactor: self-att matrix computation to own file
1 parent 4a2aa89 commit f384717

File tree

1 file changed: +242 -0 lines changed
Lines changed: 242 additions & 0 deletions
@@ -0,0 +1,242 @@
import torch

try:
    from xformers.ops import memory_efficient_attention

    _has_xformers = True
except ModuleNotFoundError:
    _has_xformers = False


__all__ = ["multihead_attention", "slice_mha", "mha", "compute_mha"]


def multihead_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float = None,
    **kwargs,
) -> torch.Tensor:
    """Compute exact self-attention with torch. Complexity: O(N**2).

    Parameters
    ----------
    query : torch.Tensor
        Query tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
    key : torch.Tensor
        Key tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
    value : torch.Tensor
        Value tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
    scale : float, optional
        Scaling factor for Q @ K'. If None, query.shape[-1]**-0.5 will
        be used.

    Returns
    -------
    torch.Tensor:
        The self-attention output. Same shape as the inputs.
    """
    if scale is None:
        scale = query.shape[-1] ** -0.5

    scores = torch.matmul(query, key.transpose(-1, -2)) * scale
    probs = scores.softmax(dim=-1)

    # compute the attention output
    return torch.matmul(probs, value)


def mha(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    att_type: str = "basic",
    **kwargs,
) -> torch.Tensor:
    """Compute exact self-attention.

    I.e. softmax(Q @ K'/sqrt(head_dim)) @ V

    Parameters
    ----------
    query : torch.Tensor
        Query tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
    key : torch.Tensor
        Key tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
    value : torch.Tensor
        Value tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
    att_type : str, default="basic"
        The type of the self-attention computation.
        One of: ("basic", "flash", "memeff").
    **kwargs:
        Extra keyword arguments for the mha computation.

    Returns
    -------
    torch.Tensor:
        The self-attention output. Same shape as the inputs.
    """
    if att_type == "memeff":
        if _has_xformers:
            if all([query.is_cuda, key.is_cuda, value.is_cuda]):
                attn = memory_efficient_attention(query, key, value)
            else:
                raise RuntimeError(
                    "`xformers.ops.memory_efficient_attention` is only implemented "
                    "for cuda. Make sure your inputs & model devices are set to cuda."
                )
        else:
            raise ModuleNotFoundError(
                "Trying to use `memory_efficient_attention`. The method requires the "
                "xformers package. See how to install xformers: "
                "https://github.com/facebookresearch/xformers"
            )
    elif att_type == "flash":
        raise NotImplementedError
    elif att_type == "basic":
        attn = multihead_attention(query, key, value, **kwargs)
    else:
        raise ValueError(
            f"Unknown `att_type` given. Got: {att_type}. "
            f"Allowed: {('memeff', 'flash', 'basic')}"
        )

    return attn


def slice_mha(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    proj_channels: int,
    num_heads: int,
    slice_size: int = 4,
    att_type: str = "slice",
    **kwargs,
) -> torch.Tensor:
    """Compute exact attention in slices to save memory.

    NOTE: adapted from the Hugging Face diffusers package. Their implementation
    does not handle the case where B is not evenly divisible by slice_size,
    which would leave zero matrices at the final batch positions of the
    batched attention output.

    NOTE: The input is sliced in the batch dimension.

    Parameters
    ----------
    query : torch.Tensor
        Query tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
    key : torch.Tensor
        Key tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
    value : torch.Tensor
        Value tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
    proj_channels : int
        Number of out channels in the token projections.
    num_heads : int
        Number of heads in the mha.
    slice_size : int, default=4
        The size of the batch dim slice.
    att_type : str, default="slice"
        The type of the self-attention computation.
        One of: ("slice", "slice-memeff").

    Returns
    -------
    torch.Tensor:
        The self-attention output. Same shape as the inputs.
    """
    allowed = ("slice", "slice-memeff")
    if att_type not in allowed:
        raise ValueError(
            f"Illegal slice-attention given. Got: {att_type}. Allowed: {allowed}."
        )

    # parse the attention type arg
    a = att_type.split("-")
    if len(a) == 1:
        att_type = "basic"
    else:
        att_type = "memeff"

    B, seq_len = query.shape[:2]
    out = torch.zeros(
        (B, seq_len, proj_channels // num_heads),
        device=query.device,
        dtype=query.dtype,
    )

    # get the modulo if B/slice_size is not evenly divisible.
    n_slices, mod = divmod(out.shape[0], slice_size)
    if mod != 0:
        n_slices += 1

    it = list(range(n_slices))
    for i in it:
        start = i * slice_size
        end = (i + 1) * slice_size

        # shrink the last slice only when B is not evenly divisible by slice_size
        if i == it[-1] and mod != 0:
            end = start + mod

        attn_slice = mha(
            query[start:end], key[start:end], value[start:end], att_type=att_type
        )

        out[start:end] = attn_slice
        del attn_slice
        torch.cuda.empty_cache()

    return out


def compute_mha(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    how: str,
    **kwargs,
) -> torch.Tensor:
    """Wrap all the different attention matrix computation types under this.

    Parameters
    ----------
    query : torch.Tensor
        Query tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
    key : torch.Tensor
        Key tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
    value : torch.Tensor
        Value tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
    how : str
        How to compute the self-attention matrix.
        One of: ("basic", "flash", "slice", "memeff", "slice-memeff").
        "basic": the normal O(N^2) self-attention.
        "flash": the flash attention (via the xformers library).
        "slice": batch-sliced attention operation to save memory.
        "memeff": xformers.ops.memory_efficient_attention.
        "slice-memeff": combine slicing and memory_efficient_attention.
    **kwargs:
        Extra keyword args for the attention matrix computation.

    Returns
    -------
    torch.Tensor:
        The self-attention output. Same shape as the inputs.
    """
    allowed = ("basic", "flash", "slice", "memeff", "slice-memeff")
    if how not in allowed:
        raise ValueError(
            f"Illegal exact self attention type given. Got: {how}. "
            f"Allowed: {allowed}."
        )

    if how == "basic":
        attn = mha(query, key, value, att_type="basic", **kwargs)
    elif how == "memeff":
        attn = mha(query, key, value, att_type="memeff", **kwargs)
    elif how == "slice":
        attn = slice_mha(query, key, value, att_type="slice", **kwargs)
    elif how == "slice-memeff":
        attn = slice_mha(query, key, value, att_type="slice-memeff", **kwargs)
    elif how == "flash":
        raise NotImplementedError
    elif how == "slice-flash":
        raise NotImplementedError

    return attn
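
For orientation, a minimal usage sketch follows (not part of the commit). The new file's path is not shown in this view, so the import uses a hypothetical module name, and the tensor sizes are made up; the sketch only illustrates the documented (B*num_heads, H*W, proj_dim//num_heads) calling convention of compute_mha and the kwargs that slice_mha expects.

import torch

from attention_ops import compute_mha  # hypothetical module name; the diff does not show the file path

# made-up sizes for illustration
B, num_heads, proj_dim, H, W = 2, 4, 64, 16, 16
head_dim = proj_dim // num_heads

# Q/K/V in the documented shape: (B*num_heads, H*W, proj_dim//num_heads)
q = torch.randn(B * num_heads, H * W, head_dim)
k = torch.randn(B * num_heads, H * W, head_dim)
v = torch.randn(B * num_heads, H * W, head_dim)

# exact O(N^2) attention over the full batch
out = compute_mha(q, k, v, how="basic")

# batch-sliced attention; slice_mha needs proj_channels and num_heads to size its output buffer
out_sliced = compute_mha(
    q, k, v, how="slice", proj_channels=proj_dim, num_heads=num_heads, slice_size=3
)

# both paths compute the same exact attention
assert out.shape == q.shape
assert torch.allclose(out, out_sliced, atol=1e-6)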

0 commit comments