Commit b76b78e

Use static masks to simulate axial attn
1 parent 059fe1b commit b76b78e

3 files changed: +48 −6 lines

dalle_pytorch/attention.py (14 additions, 3 deletions)

```diff
@@ -46,7 +46,8 @@ def apply_pos_emb(pos_emb, qkv):
 # classes
 
 class Attention(nn.Module):
-    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
+    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False,
+                 static_mask = None):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
@@ -55,6 +56,7 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
 
         self.stable = stable
         self.causal = causal
+        self.register_buffer('static_mask', static_mask, persistent=False)
 
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Sequential(
@@ -95,6 +97,9 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
             mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
             dots.masked_fill_(mask, mask_value)
 
+        if exists(self.static_mask):
+            dots.masked_fill_(~self.static_mask[offset:offset + n, :offset + n], mask_value)
+
         attn = softmax(dots, dim=-1)
 
         out = attn @ v
@@ -126,7 +131,13 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, mask = None, rotary_pos_emb = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
+        n0 = x.shape[1]
+        if exists(cache):
+            if cache_key in cache:
+                x = torch.cat([cache[cache_key], x], dim=-2)
+            cache[cache_key] = x
+
         b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
 
@@ -221,7 +232,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
 
         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
         out = self.to_out(out)
-        return out[:, :n]
+        return out[:, n - n0:n]
 
 # sparse axial causal attention
 
```
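For reference, here is a minimal standalone sketch (not from this commit) of how the new static_mask interacts with the existing causal mask inside Attention.forward: the cache offset slices the static mask down to the current query window, and disallowed logits get the usual large negative fill. Only the slicing pattern mirrors the diff; all sizes, names, and data below are made up for illustration.

```python
import torch

# toy sizes; 3 cached tokens, 5 new queries (illustrative only)
heads, dim_head, seq_len = 2, 4, 8
offset, n = 3, 5
mask_value = -torch.finfo(torch.float32).max

q = torch.randn(1, heads, n, dim_head)            # queries for the new tokens only
k = torch.randn(1, heads, offset + n, dim_head)   # keys for cached + new tokens
v = torch.randn(1, heads, offset + n, dim_head)

dots = torch.einsum('bhid,bhjd->bhij', q, k) * dim_head ** -0.5

# causal mask, as in the existing code path
i, j = dots.shape[-2:]
causal = torch.ones(i, j).triu_(j - i + 1).bool()
dots.masked_fill_(causal, mask_value)

# static mask sliced to the current window, as added by this commit;
# here the first 3 positions play the role of always-visible text tokens
static_mask = torch.zeros(seq_len, seq_len, dtype = torch.bool)
static_mask[:, :3] = True
dots.masked_fill_(~static_mask[offset:offset + n, :offset + n], mask_value)

out = torch.einsum('bhij,bhjd->bhid', dots.softmax(dim = -1), v)
print(out.shape)  # torch.Size([1, 2, 5, 4])
```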

dalle_pytorch/dalle_pytorch.py (2 additions, 0 deletions)

```diff
@@ -344,6 +344,7 @@ def __init__(
         shared_attn_ids = None,
         shared_ff_ids = None,
         share_input_output_emb = False,
+        use_static_masks = False,
     ):
         super().__init__()
         assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE'
@@ -391,6 +392,7 @@ def __init__(
             rotary_emb = rotary_emb,
             shared_attn_ids = shared_attn_ids,
             shared_ff_ids = shared_ff_ids,
+            use_static_masks = use_static_masks,
        )
 
        self.stable = stable
```
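To exercise the new flag end to end, a hedged usage sketch follows. Only use_static_masks and the 'axial_row'/'axial_col' attention types come from this commit; the remaining DALLE and DiscreteVAE constructor arguments are assumed from the library's README and may differ across versions.

```python
import torch
from dalle_pytorch import DiscreteVAE, DALLE

# constructor arguments other than use_static_masks are assumed, not part of this commit
vae = DiscreteVAE(
    image_size = 64,
    num_layers = 3,
    num_tokens = 1024,
    codebook_dim = 256,
    hidden_dim = 64,
)

dalle = DALLE(
    dim = 256,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 128,
    depth = 4,
    heads = 8,
    dim_head = 32,
    attn_types = ('axial_row', 'axial_col'),  # attention types this commit targets
    use_static_masks = True,                  # new flag added by this commit
)

text = torch.randint(0, 10000, (1, 128))
images = torch.randn(1, 3, 64, 64)

loss = dalle(text, images, return_loss = True)
loss.backward()
```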

dalle_pytorch/transformer.py (32 additions, 3 deletions)

```diff
@@ -1,6 +1,6 @@
 from collections.abc import Iterable
 from functools import partial
-from itertools import islice, cycle
+from itertools import islice, cycle, product
 
 import torch
 from torch import nn, einsum
@@ -161,11 +161,15 @@ def __init__(
         rotary_emb = True,
         shared_attn_ids = None,
         shared_ff_ids = None,
+        use_static_masks = False,
     ):
         super().__init__()
         layers = nn.ModuleList([])
         sparse_layer = cast_tuple(sparse_attn, depth)
 
+        self.seq_len = seq_len
+        self.image_fmap_size = image_fmap_size
+
         attn_types = default(attn_types, ('full',))
         attn_types = cast_tuple(attn_types)
         attn_type_layer = islice(cycle(attn_types), depth)
@@ -182,9 +186,15 @@ def __init__(
             elif attn_type == 'sparse':
                 attn_class = SparseAttention
             elif attn_type == 'axial_row':
-                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
+                if use_static_masks:
+                    attn_class = partial(Attention, stable = stable, static_mask = self._get_static_mask(attn_type))
+                else:
+                    attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'axial_col':
-                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
+                if use_static_masks:
+                    attn_class = partial(Attention, stable = stable, static_mask = self._get_static_mask(attn_type))
+                else:
+                    attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'conv_like':
                 attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'mlp':
@@ -257,3 +267,22 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
 
     def forward(self, x, **kwargs):
         return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs)
+
+    def _get_static_mask(self, attn_type):
+        img_seq_len = self.image_fmap_size ** 2
+        text_len = self.seq_len - img_seq_len
+
+        static_mask = torch.zeros(self.seq_len, self.seq_len, dtype=torch.bool)
+        static_mask[:, :text_len] = True
+        if attn_type == 'axial_row':
+            for row in range(self.image_fmap_size):
+                begin = text_len + row * self.image_fmap_size
+                end = text_len + (row + 1) * self.image_fmap_size
+                static_mask[begin:end, begin:end] = True
+        elif attn_type == 'axial_col':
+            for col in range(self.image_fmap_size):
+                begin = text_len + col
+                static_mask[begin::self.image_fmap_size, begin::self.image_fmap_size] = True
+        else:
+            raise ValueError(f'attention type "{attn_type}" can\'t be simulated with a static mask')
+        return static_mask
```
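The masks built by _get_static_mask are easiest to see on a toy example. The sketch below mirrors the logic above for a 4x4 image feature map and a made-up text length, so the block-diagonal (axial_row) and strided (axial_col) patterns can be printed and inspected; it is standalone and not part of the commit.

```python
import torch

# toy dimensions, chosen only for illustration
image_fmap_size = 4
img_seq_len = image_fmap_size ** 2   # 16 image tokens
text_len = 3                         # pretend text portion
seq_len = text_len + img_seq_len

def toy_static_mask(attn_type):
    mask = torch.zeros(seq_len, seq_len, dtype = torch.bool)
    mask[:, :text_len] = True                     # text tokens are visible to every position
    if attn_type == 'axial_row':
        for row in range(image_fmap_size):
            begin = text_len + row * image_fmap_size
            end = begin + image_fmap_size
            mask[begin:end, begin:end] = True     # block-diagonal: each token sees its own row
    elif attn_type == 'axial_col':
        for col in range(image_fmap_size):
            begin = text_len + col
            mask[begin::image_fmap_size, begin::image_fmap_size] = True  # strided: same column
    return mask

print(toy_static_mask('axial_row').int())
print(toy_static_mask('axial_col').int())
```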
