
Commit 8dcfac0

refactor: base class for self-att modules
1 parent f384717 commit 8dcfac0

3 files changed: +147 -157 lines changed

3 files changed

+147
-157
lines changed
Lines changed: 15 additions & 1 deletion
@@ -1,3 +1,17 @@
+from .attention_ops import compute_mha, mha, slice_mha
 from .exact_attention import ExactSelfAttention
+from .linformer import LinformerAttention
 
-__all__ = ["ExactSelfAttention"]
+SELFATT_LOOKUP = {
+    "exact": ExactSelfAttention,
+    "linformer": LinformerAttention,
+}
+
+__all__ = [
+    "ExactSelfAttention",
+    "LinformerAttention",
+    "mha",
+    "slice_mha",
+    "compute_mha",
+    "SELFATT_LOOKUP",
+]
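The new `SELFATT_LOOKUP` registry maps short string names to the available self-attention modules. A minimal usage sketch, not part of this commit: the `head_dim`/`num_heads` values are illustrative, and the import path is inferred from the `cellseg_models_pytorch/modules/self_attention/` package shown further below.

    from cellseg_models_pytorch.modules.self_attention import SELFATT_LOOKUP

    # Look up and build the "exact" self-attention module by name.
    attention = SELFATT_LOOKUP["exact"](head_dim=64, num_heads=8, how="basic")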
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+import torch
+import torch.nn as nn
+
+__all__ = ["BaseSelfAttention"]
+
+
+class BaseSelfAttention(nn.Module):
+    def __init__(
+        self,
+        head_dim: int,
+        num_heads: int,
+        how: str = "basic",
+        slice_size: int = None,
+        **kwargs,
+    ) -> None:
+        """Initialize a base class for self-attention modules.
+
+        Four variants:
+            - basic: self-attention implementation with torch.matmul O(N^2)
+            - slice-attention: Computes the attention matrix in slices to save mem.
+            - memeff: `xformers.ops.memory_efficient_attention` from xformers package.
+            - slice-memeff-attention: Combines slice-attention and memeff.
+
+        Parameters
+        ----------
+        head_dim : int
+            Out dim per attention head.
+        num_heads : int
+            Number of heads.
+        how : str, default="basic"
+            How to compute the self-attention matrix.
+            One of ("basic", "flash", "slice", "memeff", "slice-memeff").
+            "basic": the normal O(N^2) self attention.
+            "flash": the flash attention (by xformers library),
+            "slice": batch sliced attention operation to save mem.
+            "memeff": xformers.memory_efficient_attention.
+            "slice-memeff": Combine slicing and memory_efficient_attention.
+        slice_size, int, optional
+            The size of the slice. Used only if `how in ('slice', 'slice-memeff')`.
+
+        Raises
+        ------
+        - ValueError:
+            - If illegal self attention (`how`) method is given.
+            - If `how` is set to `slice` while `num_heads` | `slice_size`
+              args are not given proper integer values.
+            - If `how` is set to `memeff` or `slice-memeff` but cuda is not
+              available.
+        - ModuleNotFoundError:
+            - If `self_attention` is set to `memeff` and `xformers` package is not
+              installed
+        """
+        super().__init__()
+
+        allowed = ("basic", "flash", "slice", "memeff", "slice-memeff")
+        if how not in allowed:
+            raise ValueError(
+                f"Illegal exact self attention type given. Got: {how}. "
+                f"Allowed: {allowed}."
+            )
+
+        self.how = how
+        self.head_dim = head_dim
+        self.num_heads = num_heads
+
+        if how == "slice":
+            if any(s is None for s in (slice_size, num_heads)):
+                raise ValueError(
+                    "If `how` is set to 'slice', `slice_size`, `num_heads`, "
+                    f"need to be given integer values. Got: `slice_size`: {slice_size} "
+                    f"and `num_heads`: {num_heads}."
+                )
+
+        if how in ("memeff", "slice-memeff"):
+            try:
+                import xformers  # noqa F401
+            except ModuleNotFoundError:
+                raise ModuleNotFoundError(
+                    "`self_attention` was set to `memeff`. The method requires the "
+                    "xformers package. See how to install xformers: "
+                    "https://github.com/facebookresearch/xformers"
+                )
+            if not torch.cuda.is_available():
+                raise ValueError(
+                    f"`how` was set to {how}. This method for computing self attention"
+                    " is implemented with `xformers.memory_efficient_attention` that "
+                    "requires cuda."
+                )
+
+        # for slice_size > 0 the attention score computation
+        # is split across the batch axis to save memory
+        self.slice_size = slice_size
+        self.proj_channels = self.head_dim * self.num_heads
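The `slice` mode described in the docstring splits the attention-score computation across the batch axis so that only `slice_size` batch items are materialized at a time (see the `# for slice_size > 0 ...` comment above). The commit moves the actual implementation into `attention_ops.py`, which is not part of this diff; below is only a minimal standalone sketch of the idea, with illustrative shapes.

    import torch

    def sliced_attention_sketch(
        query: torch.Tensor,  # (B*num_heads, H*W, head_dim)
        key: torch.Tensor,
        value: torch.Tensor,
        slice_size: int,
    ) -> torch.Tensor:
        # Scaled dot-product attention computed over batch chunks of `slice_size`
        # to cap peak memory; the output has the same shape as the inputs.
        scale = query.shape[-1] ** -0.5
        out = torch.empty_like(query)

        for start in range(0, query.shape[0], slice_size):
            end = start + slice_size
            scores = torch.matmul(
                query[start:end], key[start:end].transpose(-1, -2)
            ) * scale
            out[start:end] = torch.matmul(scores.softmax(dim=-1), value[start:end])

        return out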

cellseg_models_pytorch/modules/self_attention/exact_attention.py

Lines changed: 39 additions & 156 deletions
@@ -1,176 +1,64 @@
 import torch
-import torch.nn as nn
 
-try:
-    from xformers.ops import memory_efficient_attention
-except ModuleNotFoundError:
-    pass
+from .attention_ops import compute_mha
+from .base_attention import BaseSelfAttention
 
+__all__ = ["ExactSelfAttention"]
 
-class ExactSelfAttention(nn.Module):
+
+class ExactSelfAttention(BaseSelfAttention):
     def __init__(
         self,
         head_dim: int,
-        self_attention: str = "basic",
-        num_heads: int = None,
+        num_heads: int,
+        how: str = "basic",
         slice_size: int = None,
+        **kwargs,
     ) -> None:
         """Compute exact attention.
 
-        Three variants:
-            - basic self-attention implementation with torch.matmul O(N^2)
-            - slice-attention - Computes the attention matrix in slices to save mem.
-            - `xformers.ops.memory_efficient_attention` from xformers package.
+        Four variants:
+            - basic: self-attention implementation with torch.matmul O(N^2)
+            - slice-attention: Computes the attention matrix in slices to save mem.
+            - memeff: `xformers.ops.memory_efficient_attention` from xformers package.
+            - slice-memeff-attention: Combines slice-attention and memeff.
 
         Parameters
         ----------
         head_dim : int
             Out dim per attention head.
-        self_attention : str, default="basic"
-            One of ("basic", "flash", "sliced", "memeff").
+        num_heads : int
+            Number of heads.
+        how : str, default="basic"
+            How to compute the self-attention matrix.
+            One of ("basic", "flash", "slice", "memeff", "slice-memeff").
             "basic": the normal O(N^2) self attention.
             "flash": the flash attention (by xformers library),
             "slice": batch sliced attention operation to save mem.
-            "memeff" xformers.memory_efficient_attention.
-        num_heads : int, optional
-            Number of heads. Used only if `slice_attention = True`.
+            "memeff": xformers.memory_efficient_attention.
+            "slice-memeff": Combine slicing and memory_efficient_attention.
         slice_size, int, optional
-            The size of the slice. Used only if `slice_attention = True`.
+            The size of the slice. Used only if `how in ('slice', 'slice-memeff')`.
 
         Raises
         ------
         - ValueError:
-            - If illegal self attention method is given.
-            - If `self_attention` is set to `slice` while `num_heads` | `slice_size`
+            - If illegal self attention (`how`) method is given.
+            - If `how` is set to `slice` while `num_heads` | `slice_size`
              args are not given proper integer values.
-            - If `self_attention` is set to `memeff` but cuda is not available.
+            - If `how` is set to `memeff` or `slice-memeff` but cuda is not
+              available.
         - ModuleNotFoundError:
            - If `self_attention` is set to `memeff` and `xformers` package is not
              installed
         """
-        super().__init__()
-
-        allowed = ("basic", "flash", "slice", "memeff")
-        if self_attention not in allowed:
-            raise ValueError(
-                f"Illegal exact self attention type given. Got: {self_attention}. "
-                f"Allowed: {allowed}."
-            )
-
-        self.self_attention = self_attention
-        self.head_dim = head_dim
-        self.num_heads = num_heads
-        self.scale = head_dim**-0.5
-
-        # These are only used for slice attention.
-        if self_attention == "slice":
-            # for slice_size > 0 the attention score computation
-            # is split across the batch axis to save memory
-            self.slice_size = slice_size
-            self.proj_channels = self.head_dim * self.num_heads
-            if any(s is None for s in (slice_size, num_heads)):
-                raise ValueError(
-                    "If `slice_attention` is set to True, `slice_size`, `num_heads`, "
-                    f"need to be given integer values. Got: `slice_size`: {slice_size} "
-                    f"and `num_heads`: {num_heads}."
-                )
-
-        if self_attention == "memeff":
-            try:
-                import xformers  # noqa F401
-            except ModuleNotFoundError:
-                raise ModuleNotFoundError(
-                    "`self_attention` was set to `memeff`. The method requires the "
-                    "xformers package. See how to install xformers: "
-                    "https://github.com/facebookresearch/xformers"
-                )
-            if not torch.cuda.is_available():
-                raise ValueError(
-                    "`self_attention` was set to `memeff`. The method is implemented "
-                    "with `xformers.memory_efficient_attention` that requires cuda."
-                )
-
-    def _attention(
-        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-    ) -> torch.Tensor:
-        """Compute exact self attention with torch. Complexity: O(N**2).
-
-        Parameters
-        ----------
-        query : torch.Tensor
-            Query tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
-        key : torch.Tensor
-            Key tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
-        value : torch.Tensor
-            Value tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
-
-        Returns
-        -------
-        torch.Tensor:
-            The self-attention matrix. Same shape as inputs.
-        """
-        scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
-        probs = scores.softmax(dim=-1)
-
-        # compute attention output
-        return torch.matmul(probs, value)
-
-    def _slice_attention(
-        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-    ) -> torch.Tensor:
-        """Compute exact attention in slices to save memory.
-
-        NOTE: adapted from hugginface diffusers package. Their implementation
-        just dont handle the case where B // slize_size doesn't divide evenly.
-        This would end up in zero-matrices at the final batch dimensions of
-        the batched attention matrix.
-
-        NOTE: The input is sliced in the batch dimension.
-
-        Parameters
-        ----------
-        query : torch.Tensor
-            Query tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
-        key : torch.Tensor
-            Key tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
-        value : torch.Tensor
-            Value tensor. Shape: (B*num_heads, H*W, proj_dim//num_heads).
-
-        Returns
-        -------
-        torch.Tensor:
-            The self-attention matrix. Same shape as inputs.
-        """
-        B, seq_len = query.shape[:2]
-        out = torch.zeros(
-            (B, seq_len, self.proj_channels // self.num_heads),
-            device=query.device,
-            dtype=query.dtype,
+        super().__init__(
+            head_dim=head_dim,
+            num_heads=num_heads,
+            how=how,
+            slice_size=slice_size,
         )
 
-        # get the modulo if B/slice_size is not evenly divisible.
-        n_slices, mod = divmod(out.shape[0], self.slice_size)
-        if mod != 0:
-            n_slices += 1
-
-        it = list(range(n_slices))
-        for i in it:
-            start = i * self.slice_size
-            end = (i + 1) * self.slice_size
-
-            if i == it[-1]:
-                end = start + mod
-
-            attn_slice = torch.matmul(query[start:end], key[start:end].transpose(1, 2))
-            attn_slice *= self.scale
-
-            attn_slice = attn_slice.softmax(dim=-1)
-            attn_slice = torch.matmul(attn_slice, value[start:end])
-
-            out[start:end] = attn_slice
-
-        return out
-
     def forward(
         self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
     ) -> torch.Tensor:
@@ -192,20 +80,15 @@ def forward(
         torch.Tensor:
            The self-attention matrix. Same shape as inputs.
         """
-        if self.self_attention == "memeff":
-            if all([query.is_cuda, key.is_cuda, value.is_cuda]):
-                attn = memory_efficient_attention(query, key, value)
-            else:
-                raise RuntimeError(
-                    "`xformers.ops.memory_efficient_attention` is only implemented "
-                    "for cuda. Make sure your inputs & model devices are set to cuda."
-                )
-        elif self.self_attention == "flash":
-            raise NotImplementedError
-        elif self.self_attention == "slice":
-            attn = self._slice_attention(query, key, value)
-        else:
-            attn = self._attention(query, key, value)
+        attn = compute_mha(
+            query,
+            key,
+            value,
+            self.how,
+            slice_size=self.slice_size,  # used only for slice-att
+            num_heads=self.num_heads,  # used only for slice-att
+            proj_channels=self.proj_channels,  # used only for slice-att
+        )
 
         return attn
 
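After the refactor, `ExactSelfAttention.forward` delegates to `compute_mha` from `attention_ops.py` (not shown in this diff). A hedged usage sketch with the `(B*num_heads, H*W, head_dim)` layout documented in the class; the tensor sizes are illustrative and assume `compute_mha` behaves as the docstring describes.

    import torch

    from cellseg_models_pytorch.modules.self_attention import ExactSelfAttention

    att = ExactSelfAttention(head_dim=32, num_heads=4, how="basic")

    # Query/key/value flattened to (B*num_heads, H*W, head_dim).
    q = k = v = torch.randn(2 * 4, 16 * 16, 32)
    out = att(q, k, v)  # same shape as the inputs: (8, 256, 32)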
