 from collections.abc import Iterable
 from functools import partial
-from itertools import islice, cycle
+from itertools import islice, cycle, product

 import torch
 from torch import nn, einsum
@@ -161,11 +161,15 @@ def __init__(
         rotary_emb = True,
         shared_attn_ids = None,
         shared_ff_ids = None,
+        use_static_masks = False,
     ):
         super().__init__()
         layers = nn.ModuleList([])
         sparse_layer = cast_tuple(sparse_attn, depth)

+        self.seq_len = seq_len
+        self.image_fmap_size = image_fmap_size
+
         attn_types = default(attn_types, ('full',))
         attn_types = cast_tuple(attn_types)
         attn_type_layer = islice(cycle(attn_types), depth)
@@ -182,9 +186,15 @@ def __init__(
             elif attn_type == 'sparse':
                 attn_class = SparseAttention
             elif attn_type == 'axial_row':
-                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
+                if use_static_masks:
+                    attn_class = partial(Attention, stable = stable, static_mask = self._get_static_mask(attn_type))
+                else:
+                    attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'axial_col':
-                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
+                if use_static_masks:
+                    attn_class = partial(Attention, stable = stable, static_mask = self._get_static_mask(attn_type))
+                else:
+                    attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'conv_like':
                 attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'mlp':
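When `use_static_masks` is enabled, the two axial attention types are no longer routed to the specialized `SparseAxialCausalAttention` modules; a regular dense `Attention` layer is built instead, restricted by a precomputed boolean `static_mask` that reproduces the axial connectivity. The sketch below shows the general technique of combining such a static mask with a causal mask before the softmax; the function name, tensor shapes, and the exact way `Attention` consumes `static_mask` internally are assumptions for illustration, not the library's actual code.

import torch

def masked_causal_attention(q, k, v, static_mask):
    # q, k, v: (batch, heads, seq, dim_head); static_mask: (seq, seq) bool,
    # where True means "this position may be attended to" (before causality)
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale

    i, j = sim.shape[-2:]
    causal = torch.ones(i, j, dtype = torch.bool, device = q.device).triu(j - i + 1)
    allowed = static_mask[:i, :j] & ~causal          # axial pattern AND causality
    sim = sim.masked_fill(~allowed, torch.finfo(sim.dtype).min)

    attn = sim.softmax(dim = -1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)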
@@ -257,3 +267,22 @@ def __init__(

     def forward(self, x, **kwargs):
         return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs)
+
+    def _get_static_mask(self, attn_type):
+        img_seq_len = self.image_fmap_size ** 2
+        text_len = self.seq_len - img_seq_len
+
+        static_mask = torch.zeros(self.seq_len, self.seq_len, dtype = torch.bool)
+        static_mask[:, :text_len] = True
+        if attn_type == 'axial_row':
+            for row in range(self.image_fmap_size):
+                begin = text_len + row * self.image_fmap_size
+                end = text_len + (row + 1) * self.image_fmap_size
+                static_mask[begin:end, begin:end] = True
+        elif attn_type == 'axial_col':
+            for col in range(self.image_fmap_size):
+                begin = text_len + col
+                static_mask[begin::self.image_fmap_size, begin::self.image_fmap_size] = True
+        else:
+            raise ValueError(f'attention type "{attn_type}" can\'t be simulated with a static mask')
+        return static_mask
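For intuition, the `axial_row` branch can be reproduced standalone. With a hypothetical 3×3 image feature map and 4 text positions (numbers chosen only for illustration), every position may attend to the text tokens, and each image token may additionally attend to the other tokens in its own image row; the causal restriction is applied separately inside the attention layer.

import torch

# hypothetical sizes, for illustration only
image_fmap_size = 3                       # 3x3 feature map -> 9 image tokens
text_len = 4                              # assumed number of text positions
seq_len = text_len + image_fmap_size ** 2

static_mask = torch.zeros(seq_len, seq_len, dtype = torch.bool)
static_mask[:, :text_len] = True          # everyone may attend to the text tokens

# axial_row: each image token may also attend within its own image row
for row in range(image_fmap_size):
    begin = text_len + row * image_fmap_size
    end = text_len + (row + 1) * image_fmap_size
    static_mask[begin:end, begin:end] = True

print(static_mask.int())
# image positions 4-6, 7-9 and 10-12 form block-diagonal groups,
# while the first four (text) columns remain visible to every position

In the constructor above, this mask is built once per axial layer via `self._get_static_mask(attn_type)` and bound to the dense `Attention` through `partial`, so no specialized sparse kernel is needed.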