
Commit 60fb351

style: minor variable name fixes
1 parent cacaabe commit 60fb351

3 files changed: +24 -11 lines changed

cellseg_models_pytorch/modules/mlp.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -17,7 +17,7 @@ def __init__(
         dropout: float = 0.0,
         bias: bool = False,
         out_channels: int = None,
-        **kwargs
+        **act_kwargs
     ) -> None:
         """MLP token mixer.
 
@@ -41,13 +41,15 @@ def __init__(
             Flag whether to use bias terms in the nn.Linear modules.
         out_channels : int, optional
             Number of out channels. If None `out_channels = in_channels`
+        **act_kwargs:
+            Arbitrary key-word arguments for the activation function.
         """
         super().__init__()
         self.out_channels = in_channels if out_channels is None else out_channels
         hidden_channels = int(mlp_ratio * in_channels)
 
         self.fc1 = nn.Linear(in_channels, hidden_channels, bias=bias)
-        self.act = Activation(activation)
+        self.act = Activation(activation, **act_kwargs)
         self.drop1 = nn.Dropout(dropout)
         self.fc2 = nn.Linear(hidden_channels, self.out_channels, bias=bias)
         self.drop2 = nn.Dropout(dropout)
@@ -69,6 +71,7 @@ def __init__(
         in_channels: int,
         mlp_ratio: int = 4,
         activation: str = "star_relu",
+        activation_kwargs: Dict[str, Any] = None,
         dropout: float = 0.0,
         bias: bool = False,
         normalization: str = "ln",
@@ -105,6 +108,7 @@ def __init__(
             activation=activation,
             dropout=dropout,
             bias=bias,
+            **activation_kwargs
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
```
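The change above threads arbitrary keyword arguments from the `MLP` constructor through to its activation function. The following is a minimal sketch of that pattern only, with a plain `nn.ReLU` standing in for the library's `Activation` registry; `TinyMLP` and its defaults are illustrative assumptions, not the package's actual class.

```python
# Sketch of forwarding **act_kwargs from an MLP token mixer to its activation.
# TinyMLP is a made-up stand-in, not cellseg_models_pytorch's MLP.
import torch
import torch.nn as nn


class TinyMLP(nn.Module):
    def __init__(
        self,
        in_channels: int,
        mlp_ratio: int = 4,
        dropout: float = 0.0,
        bias: bool = False,
        out_channels: int = None,
        **act_kwargs,
    ) -> None:
        super().__init__()
        self.out_channels = in_channels if out_channels is None else out_channels
        hidden_channels = int(mlp_ratio * in_channels)

        self.fc1 = nn.Linear(in_channels, hidden_channels, bias=bias)
        # The extra kwargs land in the activation, e.g. `inplace=True` for ReLU.
        self.act = nn.ReLU(**act_kwargs)
        self.drop1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_channels, self.out_channels, bias=bias)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.drop2(self.fc2(self.drop1(self.act(self.fc1(x)))))


x = torch.randn(1, 196, 64)                    # (batch, tokens, channels)
mlp = TinyMLP(64, mlp_ratio=4, inplace=True)   # `inplace` is captured by **act_kwargs
print(mlp(x).shape)                            # torch.Size([1, 196, 64])
```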

cellseg_models_pytorch/modules/transformers.py

Lines changed: 12 additions & 9 deletions
```diff
@@ -22,7 +22,7 @@ def __init__(
         block_types: Tuple[str, ...] = ("basic", "basic"),
         dropouts: Tuple[float, ...] = (0.0, 0.0),
         biases: Tuple[bool, ...] = (False, False),
-        act: str = "geglu",
+        activation: str = "star_relu",
         num_groups: int = 32,
         slice_size: int = 4,
         mlp_ratio: int = 4,
@@ -55,16 +55,16 @@ def __init__(
             Dropout probabilities for the SelfAttention blocks.
         biases : bool, default=(True, True)
             Include bias terms in the SelfAttention blocks.
-        act : str, default="geglu"
+        activation : str, default="star_relu"
             The activation function applied at the end of the transformer layer fc.
-            One of ("geglu", "approximate_gelu").
+            One of ("geglu", "approximate_gelu", "star_relu").
         num_groups : int, default=32
             Number of groups in the first group-norm op before the input is
             projected to be suitable for self-attention.
         slice_size : int, default=4
             Slice size for sliced self-attention. This is used only if
             `name = "slice"` for a SelfAttentionBlock.
-        fc_projection_mult : int, default=4
+        mlp_ratio : int, default=4
             Multiplier that defines the out dimension of the final fc projection
             layer.
         """
@@ -89,7 +89,7 @@ def __init__(
             block_types=block_types,
             dropouts=dropouts,
             biases=biases,
-            act=act,
+            activation=activation,
             slice_size=slice_size,
             mlp_ratio=mlp_ratio,
         )
@@ -166,9 +166,9 @@ def __init__(
             asymmetrically two separate embeddings (context and input embeddings).
             E.g. passage from transformer encoder to transformer decoder. If this is
             set to None, no cross attention is applied.
-        act : str, default="geglu"
+        activation : str, default="star_relu"
             The activation function applied at the end of the transformer layer fc.
-            One of ("geglu", "approximate_gelu").
+            One of ("gelu", "geglu", "approximate_gelu", "star_relu").
         n_blocks : int, default=2
             Number of SelfAttentionBlocks used in this layer.
         block_types : Tuple[str, ...], default=("basic", "basic")
@@ -182,9 +182,11 @@ def __init__(
         slice_size : int, default=4
             Slice size for sliced self-attention. This is used only if
             `name = "slice"` for a SelfAttentionBlock.
-        fc_projection_mult : int, default=4
+        mlp_proj : int, default=4
             Multiplier that defines the out dimension of the final fc projection
             layer.
+        **kwargs:
+            Arbitrary key-word arguments (e.g. for the activation function).
 
         Raises
         ------
@@ -227,6 +229,7 @@ def __init__(
             activation=activation,
             normalization="ln",
             norm_kwargs={"normalized_shape": query_dim},
+            activation_kwargs=kwargs,
         )
 
     def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
@@ -251,7 +254,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor
             con = None
             if i == n_blocks:
                 con = context
-
+            print("context: ", con.shape)
             x = tr_block(x, con)
 
         return self.mlp(x) + x
```
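The `activation_kwargs: Dict[str, Any] = None` parameter added in mlp.py is what the transformer layer above fills with its leftover `**kwargs` (`activation_kwargs=kwargs`), and the MLP then unpacks it into the activation. Below is a standalone sketch of that plumbing; `TinyMlpBlock`, the `nn.ReLU` stand-in, and the `or {}` guard are illustrative assumptions, not the library's code. Note that unpacking a default of `None` with `**` would raise a `TypeError`, so the sketch normalizes it to an empty dict first.

```python
# Sketch of passing an activation-kwargs dict through a residual MLP block.
from typing import Any, Dict

import torch
import torch.nn as nn


class TinyMlpBlock(nn.Module):
    """Residual MLP block that forwards activation kwargs to its activation."""

    def __init__(
        self,
        in_channels: int,
        mlp_ratio: int = 4,
        activation_kwargs: Dict[str, Any] = None,
    ) -> None:
        super().__init__()
        # Guard against the `None` default before unpacking with `**`.
        activation_kwargs = activation_kwargs or {}
        hidden = mlp_ratio * in_channels
        self.norm = nn.LayerNorm(in_channels)
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.ReLU(**activation_kwargs),  # stand-in for the library's Activation
            nn.Linear(hidden, in_channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection over the MLP, mirroring `return self.mlp(x) + x` above.
        return self.mlp(self.norm(x)) + x


block = TinyMlpBlock(64, activation_kwargs={"inplace": True})
print(block(torch.randn(1, 196, 64)).shape)  # torch.Size([1, 196, 64])
```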
Lines changed: 6 additions & 0 deletions
```diff
@@ -0,0 +1,6 @@
+## Features
+
+- Add transformers modules
+- Add exact, slice, and memory efficient (xformers) self attention modules
+- Add transformers modules to `Decoder` modules
+- Add common transformer mlp activation functions.
```
