
Commit 0d8c05e

feat(modules): add cross-att long-skip module
1 parent 2b40ec8 commit 0d8c05e

File tree

3 files changed: +145 -1 lines changed

cellseg_models_pytorch/decoders/long_skips/cross_attn_skip.py

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn

from ...modules import Transformer2D
from ...modules.patch_embeddings import ContiguousEmbed

__all__ = ["CrossAttentionSkip"]


class CrossAttentionSkip(nn.Module):
    def __init__(
        self,
        stage_ix: int,
        in_channels: int,
        skip_channels: Tuple[int, ...] = None,
        num_heads: int = 8,
        head_dim: int = 64,
        n_blocks: int = 1,
        block_types: Tuple[str, ...] = ("exact",),
        computation_types: Tuple[str, ...] = ("basic",),
        dropouts: Tuple[float, ...] = (0.0,),
        biases: Tuple[bool, ...] = (False,),
        layer_scales: Tuple[bool, ...] = (False,),
        activation: str = "star_relu",
        mlp_ratio: int = 2,
        slice_size: int = 4,
        patch_embed_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> None:
        """Skip connection (U-Net-like) via cross-attention.

        Applies the long skip connection through a cross-attention transformer
        rather than merging or summing the skip features with the upsampled
        decoder feature map.

        Parameters
        ----------
        stage_ix : int
            Index number signalling the current decoder stage.
        in_channels : int
            The number of channels in the input tensor.
        skip_channels : Tuple[int, ...]
            Tuple of the number of channels in the encoder stages.
            Order is bottom up. This list does not include the final
            bottleneck stage out channels.
        num_heads : int, default=8
            Number of heads in multi-head attention.
        head_dim : int, default=64
            The out dim of the heads.
        n_blocks : int, default=1
            Number of SelfAttentionBlocks used in this layer.
        block_types : Tuple[str, ...], default=("exact", )
            The names of the SelfAttentionBlocks in the TransformerLayer.
            Length of the tuple has to equal `n_blocks`.
            Allowed names: "basic", "slice", "flash".
        computation_types : Tuple[str, ...], default=("basic", )
            The way of computing the attention matrices in the SelfAttentionBlocks
            in the TransformerLayer. Length of the tuple has to equal `n_blocks`.
            Allowed styles: "basic", "slice", "flash", "memeff", "slice_memeff".
        dropouts : Tuple[float, ...], default=(0.0, )
            Dropout probabilities for the SelfAttention blocks.
        biases : Tuple[bool, ...], default=(False, )
            Include bias terms in the SelfAttention blocks.
        layer_scales : Tuple[bool, ...], default=(False, )
            Learnable layer weights for the self-attention matrix.
        activation : str, default="star_relu"
            The activation function applied at the end of the transformer layer fc.
            One of ("geglu", "approximate_gelu", "star_relu").
        mlp_ratio : int, default=2
            Multiplier that defines the out dimension of the final fc projection
            layer.
        slice_size : int, default=4
            Slice size for sliced self-attention. This is used only if
            `name = "slice"` for a SelfAttentionBlock.
        patch_embed_kwargs : Dict[str, Any], optional
            Extra keyword arguments for the patch embedding module. See the
            `ContiguousEmbed` module for more info.
        """
        super().__init__()
        self.in_channels = in_channels
        self.stage_ix = stage_ix

        if stage_ix < len(skip_channels):
            context_channels = skip_channels[stage_ix]

            self.context_patch_embed = ContiguousEmbed(
                in_channels=context_channels,
                patch_size=1,
                num_heads=num_heads,
                head_dim=head_dim,
                normalization="gn",
                norm_kwargs={"num_features": context_channels},
                **(patch_embed_kwargs if patch_embed_kwargs is not None else {}),
            )

            self.transformer = Transformer2D(
                in_channels=in_channels,
                cross_attentions_dims=self.context_patch_embed.proj_dim,
                num_heads=num_heads,
                head_dim=head_dim,
                n_blocks=n_blocks,
                block_types=block_types,
                computation_types=computation_types,
                dropouts=dropouts,
                biases=biases,
                layer_scales=layer_scales,
                activation=activation,
                slice_size=slice_size,
                mlp_ratio=mlp_ratio,
                patch_embed_kwargs=patch_embed_kwargs,
                **kwargs,
            )

    @property
    def out_channels(self) -> int:
        """Out channels."""
        return self.in_channels

    def forward(
        self, x: torch.Tensor, skips: Tuple[torch.Tensor, ...], **kwargs
    ) -> torch.Tensor:
        """Forward pass of the skip connection."""
        if self.stage_ix < len(skips):
            context = skips[self.stage_ix]  # (B, C, H, W)

            # embed context for the cross-attn transformer: (B, H'*W', num_heads*head_dim)
            context = self.context_patch_embed(context)

            x = self.transformer(x, context=context)  # (B, C, H, W)

        return x
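For reference, a minimal usage sketch of the new module. The import path is inferred from the `from .cross_attn_skip import CrossAttentionSkip` line in `longskip_module.py` below, and the tensor shapes are purely illustrative; whether this runs as-is depends on the package's `Transformer2D` and `ContiguousEmbed` modules.

import torch

# Inferred import path; adjust if the module lives elsewhere in the package.
from cellseg_models_pytorch.decoders.long_skips.cross_attn_skip import CrossAttentionSkip

# Illustrative shapes: decoder feature map at stage 0 and the matching encoder skip.
x = torch.rand(2, 256, 32, 32)          # (B, C, H, W) upsampled decoder features
skips = (torch.rand(2, 128, 32, 32),)   # encoder skip features, ordered bottom-up

skip_block = CrossAttentionSkip(
    stage_ix=0,
    in_channels=256,
    skip_channels=(128,),  # channels of the encoder stages (bottleneck excluded)
)

out = skip_block(x, skips)  # -> (2, 256, 32, 32); out_channels == in_channels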

cellseg_models_pytorch/decoders/long_skips/longskip_module.py

Lines changed: 9 additions & 1 deletion
@@ -4,6 +4,7 @@
 import torch.nn as nn

 from ...modules import Identity
+from .cross_attn_skip import CrossAttentionSkip
 from .unet import UnetSkip
 from .unet3p import Unet3pSkip
 from .unetpp import UnetppSkip
@@ -16,6 +17,7 @@
     "unet3p": Unet3pSkip,
     "unet3p-lite": Unet3pSkip,
     "unetpp": UnetppSkip,
+    "cross-attn": CrossAttentionSkip,
 }


@@ -44,7 +46,13 @@ def __init__(self, name: str, **kwargs) -> None:
         if name is not None:
             if name == "unet3p-lite":
                 kwargs["lite_version"] = True
-            self.skip = LONGSKIP_LOOKUP[name](**kwargs)
+            try:
+                self.skip = LONGSKIP_LOOKUP[name](**kwargs)
+            except Exception as e:
+                raise Exception(
+                    "Encountered an error when trying to init long-skip module: "
+                    f"LongSkip(name='{name}'): {e.__class__.__name__}: {e}"
+                )
         else:
             self.skip = Identity()
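The same module can also be built through the wrapper in this file, which resolves the "cross-attn" key via `LONGSKIP_LOOKUP` and forwards the keyword arguments. A short sketch; the class name `LongSkip` and the import path are assumptions based on the error message and file path shown above.

# Import path assumed from cellseg_models_pytorch/decoders/long_skips/longskip_module.py.
from cellseg_models_pytorch.decoders.long_skips.longskip_module import LongSkip

# Same keyword arguments as in the direct CrossAttentionSkip example above.
skip = LongSkip(
    name="cross-attn",
    stage_ix=0,
    in_channels=256,
    skip_channels=(128,),
)

# If the kwargs are invalid, the new try/except re-raises with a message naming
# the failing module, e.g. "LongSkip(name='cross-attn'): TypeError: ...".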

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
## Features

- Add support for model initialization from yaml-file in `MultiTaskUnet`.
- Add a new cross-attention long-skip module. Works with `long_skip='cross-attn'`.
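As a quick sanity check of the changelog entry: the string passed as `long_skip='cross-attn'` is simply the key registered in `LONGSKIP_LOOKUP` above. A sketch; the import paths are assumptions based on the file paths shown in this commit.

from cellseg_models_pytorch.decoders.long_skips.cross_attn_skip import CrossAttentionSkip
from cellseg_models_pytorch.decoders.long_skips.longskip_module import LONGSKIP_LOOKUP

# 'cross-attn' resolves to the new cross-attention long-skip module.
assert LONGSKIP_LOOKUP["cross-attn"] is CrossAttentionSkip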
