@@ -22,7 +22,7 @@ def __init__(
         block_types: Tuple[str, ...] = ("basic", "basic"),
         dropouts: Tuple[float, ...] = (0.0, 0.0),
         biases: Tuple[bool, ...] = (False, False),
-        act: str = "geglu",
+        activation: str = "star_relu",
         num_groups: int = 32,
         slice_size: int = 4,
         mlp_ratio: int = 4,
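The new default, "star_relu", refers to the StarReLU activation from the MetaFormer line of work (s · ReLU(x)² + b with a learnable scale and bias). A minimal sketch for orientation, not necessarily the implementation this repository uses:

import torch
import torch.nn as nn

class StarReLU(nn.Module):
    """s * relu(x)**2 + b with learnable scale and bias (illustrative sketch only)."""

    def __init__(self, scale: float = 1.0, bias: float = 0.0) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(scale))
        self.bias = nn.Parameter(torch.tensor(bias))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.scale * torch.relu(x) ** 2 + self.bias

act = StarReLU()
print(act(torch.randn(2, 4)).shape)  # torch.Size([2, 4])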
@@ -55,16 +55,16 @@ def __init__(
             Dropout probabilities for the SelfAttention blocks.
         biases : bool, default=(True, True)
             Include bias terms in the SelfAttention blocks.
-        act : str, default="geglu"
+        activation : str, default="star_relu"
             The activation function applied at the end of the transformer layer fc.
-            One of ("geglu", "approximate_gelu").
+            One of ("geglu", "approximate_gelu", "star_relu").
         num_groups : int, default=32
             Number of groups in the first group-norm op before the input is
             projected to be suitable for self-attention.
         slice_size : int, default=4
             Slice size for sliced self-attention. This is used only if
             `name = "slice"` for a SelfAttentionBlock.
-        fc_projection_mult : int, default=4
+        mlp_ratio : int, default=4
             Multiplier that defines the out dimension of the final fc projection
             layer.
         """
@@ -89,7 +89,7 @@ def __init__(
             block_types=block_types,
             dropouts=dropouts,
             biases=biases,
-            act=act,
+            activation=activation,
             slice_size=slice_size,
             mlp_ratio=mlp_ratio,
         )
@@ -166,9 +166,9 @@ def __init__(
             asymmetrically two separate embeddings (context and input embeddings).
             E.g. passage from transformer encoder to transformer decoder. If this is
             set to None, no cross attention is applied.
-        act : str, default="geglu"
+        activation : str, default="star_relu"
             The activation function applied at the end of the transformer layer fc.
-            One of ("geglu", "approximate_gelu").
+            One of ("gelu", "geglu", "approximate_gelu", "star_relu").
         n_blocks : int, default=2
             Number of SelfAttentionBlocks used in this layer.
         block_types : Tuple[str, ...], default=("basic", "basic")
@@ -182,9 +182,11 @@ def __init__(
         slice_size : int, default=4
             Slice size for sliced self-attention. This is used only if
             `name = "slice"` for a SelfAttentionBlock.
-        fc_projection_mult : int, default=4
+        mlp_proj : int, default=4
             Multiplier that defines the out dimension of the final fc projection
             layer.
+        **kwargs:
+            Arbitrary keyword arguments (e.g. for the activation function).

         Raises
         ------
@@ -227,6 +229,7 @@ def __init__(
             activation=activation,
             normalization="ln",
             norm_kwargs={"normalized_shape": query_dim},
+            activation_kwargs=kwargs,
         )

     def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
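The added activation_kwargs=kwargs line forwards any extra keyword arguments given to this layer on to the activation function. A minimal sketch of that pass-through pattern, assuming a hypothetical build_activation factory (not this repository's actual helper):

import torch.nn as nn

def build_activation(name: str, **activation_kwargs) -> nn.Module:
    """Hypothetical factory for illustration; not this repository's helper."""
    if name == "gelu":
        # e.g. approximate="tanh" selects the tanh-approximated GELU (PyTorch >= 1.12)
        return nn.GELU(**activation_kwargs)
    raise ValueError(f"Unsupported activation in this sketch: {name}")

# Extra keyword arguments flow through untouched, mirroring activation_kwargs=kwargs.
act = build_activation("gelu", approximate="tanh")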
@@ -251,7 +254,8 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor
             con = None
             if i == n_blocks:
                 con = context
-
+            if con is not None:
+                print("context: ", con.shape)
             x = tr_block(x, con)

         return self.mlp(x) + x
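To make the control flow of this hunk explicit: the external context embedding is handed only to the block where i == n_blocks, so earlier blocks run pure self-attention, and the output is a residual MLP projection. A toy sketch of that gating (assuming the loop counts blocks from 1, which is what i == n_blocks implies; the real block internals are omitted):

n_blocks = 2
tr_blocks = [f"block_{k}" for k in range(1, n_blocks + 1)]  # stand-ins for the real modules

for i, tr_block in enumerate(tr_blocks, start=1):  # 1-based so the last block satisfies i == n_blocks
    con = "context" if i == n_blocks else None
    print(f"{tr_block}: cross-attention context = {con}")
# block_1: cross-attention context = None
# block_2: cross-attention context = context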