
Commit 5619963

feat(modules): Sep att-mat-comp type & att-method
1 parent bb50359 commit 5619963

5 files changed: +101 -24 lines changed

cellseg_models_pytorch/modules/base_modules.py

Lines changed: 43 additions & 1 deletion
@@ -4,9 +4,10 @@
 from .act import ACT_LOOKUP
 from .conv import CONV_LOOKUP
 from .norm import NORM_LOOKUP
+from .self_attention import SELFATT_LOOKUP
 from .upsample import UP_LOOKUP
 
-__all__ = ["Activation", "Norm", "Up", "Conv", "Identity"]
+__all__ = ["Activation", "Norm", "Up", "Conv", "Identity", "MultiHeadSelfAttention"]
 
 
 class Identity(nn.Module):
@@ -176,3 +177,44 @@ def __init__(self, name: str, **kwargs) -> None:
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass for the convolution function."""
         return self.conv(x)
+
+
+class MultiHeadSelfAttention(nn.Module):
+    def __init__(self, name: str, **kwargs) -> None:
+        """Multi-head self-attention wrapper class.
+
+        Parameters
+        ----------
+        name : str
+            Name of the mhsa method.
+
+        Raises
+        ------
+        ValueError: if the mhsa method name is illegal.
+        """
+        super().__init__()
+
+        allowed = list(SELFATT_LOOKUP.keys())
+        if name not in allowed:
+            raise ValueError(
+                "Illegal multi-head attention method given. "
+                f"Allowed: {allowed}. Got: '{name}'"
+            )
+
+        try:
+            self.att = SELFATT_LOOKUP[name](**kwargs)
+        except Exception as e:
+            raise Exception(
+                "Encountered an error when trying to init self-attention function: "
+                f"MultiHeadSelfAttention(name='{name}'): {e.__class__.__name__}: {e}"
+            )
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        **kwargs,
+    ) -> torch.Tensor:
+        """Forward pass for the self-attention function."""
+        return self.att(query, key, value, **kwargs)
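For orientation, a minimal usage sketch of the new wrapper. The keyword arguments mirror the call site in `SelfAttention` below (head_dim, num_heads, how, slice_size); treat them as assumptions about what `SELFATT_LOOKUP["exact"]` accepts rather than a documented API.

from cellseg_models_pytorch.modules.base_modules import MultiHeadSelfAttention

# Wrap the "exact" attention method; all extra kwargs are forwarded as-is to
# SELFATT_LOOKUP["exact"] (kwargs copied from the SelfAttention call site in this commit).
mhsa = MultiHeadSelfAttention(
    name="exact", how="basic", head_dim=64, num_heads=8, slice_size=4
)

# An unknown method name fails fast and lists the allowed lookup keys.
try:
    MultiHeadSelfAttention(name="not-a-method")
except ValueError as err:
    print(err)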

cellseg_models_pytorch/modules/self_attention_modules.py

Lines changed: 28 additions & 14 deletions
@@ -3,7 +3,7 @@
 import torch
 import torch.nn as nn
 
-from .self_attention import ExactSelfAttention
+from .base_modules import MultiHeadSelfAttention
 
 __all__ = ["SelfAttention", "SelfAttentionBlock"]
 
@@ -12,7 +12,8 @@ class SelfAttention(nn.Module):
     def __init__(
         self,
         query_dim: int,
-        self_attention: str = "basic",
+        name: str = "exact",
+        how: str = "basic",
         cross_attention_dim: int = None,
         num_heads: int = 8,
         head_dim: int = 64,
@@ -29,12 +30,16 @@ def __init__(
         ----------
         query_dim : int
             The number of channels in the query. Typically: num_heads*head_dim
-        self_attention : str, default="basic"
-            One of ("basic", "flash", "slice", "memeff").
+        name : str
+            Name of the attention method. One of ("exact", "linformer").
+        how : str, default="basic"
+            How to compute the self-attention matrix.
+            One of ("basic", "flash", "slice", "memeff", "slice-memeff").
             "basic": the normal O(N^2) self attention.
             "flash": the flash attention (by xformers library),
             "slice": batch sliced attention operation to save mem.
-            "memeff" xformers.memory_efficient_attention.
+            "memeff": xformers.memory_efficient_attention.
+            "slice-memeff": Combine slicing and memory_efficient_attention.
         cross_attention_dim : int, optional
             Number of channels in the context tensor. Cross attention combines
             asymmetrically two separate embeddings (context and input embeddings).
@@ -51,27 +56,29 @@ def __init__(
         slice_size : int, default=4
             Slice size for sliced self-attention. This is used only if
             `self_attention = "slice"`.
+        **kwargs:
+            Extra key-word arguments for the MHSA-module.
         """
         super().__init__()
         self.out_channels = query_dim
+        self.num_heads = num_heads
         proj_channels = head_dim * num_heads
 
         # cross attention dim
         if cross_attention_dim is None:
             cross_attention_dim = query_dim
 
-        self.scale = head_dim**-0.5
-        self.num_heads = num_heads
-
         self.to_q = nn.Linear(query_dim, proj_channels, bias=bias)
         self.to_k = nn.Linear(cross_attention_dim, proj_channels, bias=bias)
         self.to_v = nn.Linear(cross_attention_dim, proj_channels, bias=bias)
 
-        self.self_attn = ExactSelfAttention(
+        self.self_attn = MultiHeadSelfAttention(
+            name=name,
             head_dim=head_dim,
-            self_attention=self_attention,
             num_heads=self.num_heads,
+            how=how,
             slice_size=slice_size,
+            **kwargs,
         )
 
         self.to_out = nn.Linear(proj_channels, query_dim)
@@ -187,8 +194,9 @@ def forward(
 class SelfAttentionBlock(nn.Module):
     def __init__(
         self,
-        name: str,
+        how: str,
         query_dim: int,
+        name: str = "exact",
         cross_attention_dim: int = None,
         num_heads: int = 8,
         head_dim: int = 64,
@@ -206,11 +214,15 @@ def __init__(
         Parameters
         ----------
         name : str
-            One of ("basic", "flash", "slice", "memeff").
+            Name of the attention method. One of ("exact", "linformer").
+        how : str, default="basic"
+            How to compute the self-attention matrix.
+            One of ("basic", "flash", "slice", "memeff", "slice-memeff").
             "basic": the normal O(N^2) self attention.
             "flash": the flash attention (by xformers library),
             "slice": batch sliced attention operation to save mem.
-            "memeff" xformers.memory_efficient_attention.
+            "memeff": xformers.memory_efficient_attention.
+            "slice-memeff": Combine slicing and memory_efficient_attention.
         query_dim : int
             The number of channels in the query. Typically: num_heads*head_dim
         cross_attention_dim : int, optional
@@ -234,14 +246,16 @@ def __init__(
 
         self.norm = nn.LayerNorm(query_dim)
         self.att = SelfAttention(
-            self_attention=name,
+            name=name,
+            how=how,
            query_dim=query_dim,
            cross_attention_dim=cross_attention_dim,
            head_dim=head_dim,
            num_heads=num_heads,
            dropout=dropout,
            bias=bias,
            slice_size=slice_size,
+            **kwargs,
        )
 
     def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
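A rough sketch of the new split between attention method (`name`) and computation style (`how`), using the options documented above. The module path follows the file path in this diff, and the input shape is an illustrative assumption about the token layout.

import torch

from cellseg_models_pytorch.modules.self_attention_modules import SelfAttentionBlock

# "exact" picks the attention method, "basic" the way the attention matrix is
# computed (alternatives: "flash", "slice", "memeff", "slice-memeff"; the
# xformers-based styles require the xformers library).
block = SelfAttentionBlock(
    how="basic",
    query_dim=512,   # num_heads * head_dim
    name="exact",
    num_heads=8,
    head_dim=64,
)

x = torch.rand(2, 32 * 32, 512)  # (batch, tokens, query_dim) -- assumed layout
out = block(x)                   # self-attention over the tokens, same shape as x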

cellseg_models_pytorch/modules/transformers.py

Lines changed: 23 additions & 9 deletions
@@ -19,7 +19,8 @@ def __init__(
         head_dim: int = 64,
         cross_attention_dim: int = None,
         n_blocks: int = 2,
-        block_types: Tuple[str, ...] = ("basic", "basic"),
+        block_types: Tuple[str, ...] = ("exact", "exact"),
+        computation_types: Tuple[str, ...] = ("basic", "basic"),
         dropouts: Tuple[float, ...] = (0.0, 0.0),
         biases: Tuple[bool, ...] = (False, False),
         activation: str = "star_relu",
@@ -47,10 +48,14 @@ def __init__(
             set to None, no cross attention is applied.
         n_blocks : int, default=2
             Number of Multihead attention blocks in the transformer.
-        block_types : Tuple[str, ...], default=("basic", "basic")
+        block_types : Tuple[str, ...], default=("exact", "exact")
             The name of the SelfAttentionBlocks in the TransformerLayer.
-            Length of the tuple has to equal `n_blocks`
-            Allowed names: "basic". "slice", "flash", "memeff".
+            Length of the tuple has to equal `n_blocks`.
+            Allowed names: ("exact", "linformer").
+        computation_types : Tuple[str, ...], default=("basic", "basic")
+            The way of computing the attention matrices in the SelfAttentionBlocks
+            in the TransformerLayer. Length of the tuple has to equal `n_blocks`.
+            Allowed styles: "basic", "slice", "flash", "memeff", "slice_memeff".
         dropouts : Tuple[float, ...], default=(False, False)
             Dropout probabilities for the SelfAttention blocks.
         biases : bool, default=(True, True)
@@ -87,11 +92,13 @@ def __init__(
             cross_attention_dim=cross_attention_dim,
             n_blocks=n_blocks,
             block_types=block_types,
+            computation_types=computation_types,
             dropouts=dropouts,
             biases=biases,
             activation=activation,
             slice_size=slice_size,
             mlp_ratio=mlp_ratio,
+            **kwargs,
         )
 
         self.proj_out = nn.Conv2d(
@@ -140,7 +147,8 @@ def __init__(
         cross_attention_dim: int = None,
         activation: str = "star_relu",
         n_blocks: int = 2,
-        block_types: Tuple[str, ...] = ("basic", "basic"),
+        block_types: Tuple[str, ...] = ("exact", "exact"),
+        computation_types: Tuple[str, ...] = ("basic", "basic"),
         dropouts: Tuple[float, ...] = (0.0, 0.0),
         biases: Tuple[bool, ...] = (False, False),
         slice_size: int = 4,
@@ -171,10 +179,14 @@ def __init__(
             One of ("gelu", "geglu", "approximate_gelu", "star_relu").
         n_blocks : int, default=2
             Number of SelfAttentionBlocks used in this layer.
-        block_types : Tuple[str, ...], default=("basic", "basic")
+        block_types : Tuple[str, ...], default=("exact", "exact")
+            The name of the SelfAttentionBlocks in the TransformerLayer.
+            Length of the tuple has to equal `n_blocks`.
+            Allowed names: ("exact", "linformer").
+        computation_types : Tuple[str, ...], default=("basic", "basic")
             The name of the SelfAttentionBlocks in the TransformerLayer.
             Length of the tuple has to equal `n_blocks`
-            Allowed names: "basic". "slice", "flash".
+            Allowed styles: "basic", "slice", "flash", "memeff", "slice_memeff".
         dropouts : Tuple[float, ...], default=(False, False)
             Dropout probabilities for the SelfAttention blocks.
         biases : bool, default=(True, True)
@@ -186,7 +198,7 @@ def __init__(
             Multiplier that defines the out dimension of the final fc projection
             layer.
         **kwargs:
-            Arbitrary key-word arguments (e.g. for activation function.).
+            Arbitrary key-word arguments.
 
         Raises
         ------
@@ -213,13 +225,15 @@ def __init__(
 
             att_block = SelfAttentionBlock(
                 name=block_types[i],
+                how=computation_types[i],
                 query_dim=query_dim,
                 num_heads=num_heads,
                 head_dim=head_dim,
                 cross_attention_dim=cross_dim,
                 dropout=dropouts[i],
                 biases=biases[i],
                 slice_size=slice_size,
+                **kwargs,
             )
             self.tr_blocks[f"transformer_{block_types[i]}_{i + 1}"] = att_block
 
@@ -254,7 +268,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor
             con = None
             if i == n_blocks:
                 con = context
-                print("context: ", con.shape)
+
             x = tr_block(x, con)
 
         return self.mlp(x) + x
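As a sketch of how the two new tuples pair up per block: block_types[i] selects the attention method of block i and computation_types[i] selects how its attention matrix is computed. The class name TransformerLayer and the query_dim/num_heads/head_dim parameters are inferred from the docstrings and the loop body in this diff, so treat the exact signature as an assumption.

import torch

from cellseg_models_pytorch.modules.transformers import TransformerLayer

# Two attention blocks: both "exact" (could be "linformer"), the first computed
# the basic O(N^2) way, the second with batch-sliced attention to save memory.
layer = TransformerLayer(
    query_dim=512,                    # assumed: num_heads * head_dim
    num_heads=8,                      # assumed parameter name
    head_dim=64,                      # assumed parameter name
    n_blocks=2,
    block_types=("exact", "exact"),
    computation_types=("basic", "slice"),
    dropouts=(0.0, 0.0),
    biases=(False, False),
    slice_size=4,
)

x = torch.rand(2, 32 * 32, 512)  # (batch, tokens, query_dim) -- assumed layout
out = layer(x)                   # context is only passed to the last block

Separating the two tuples is what the commit title refers to: the attention-matrix computation style is now configured independently of the attention method itself.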
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+## Refactor
+
+- Added more verbose error messages for the abstract wrapper-modules in `modules.base_modules`
+- Added more verbose error catching for `xformers.ops.memory_efficient_attention`.
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+## Features
+
+- Add Linformer self-attention mechanism.
