Commit 1fd45ca

Fix mask in attention
1 parent c333ea7 commit 1fd45ca

1 file changed: +5 -5 lines changed

dalle_pytorch/attention.py

Lines changed: 5 additions & 5 deletions
@@ -90,13 +90,13 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
         if exists(cache):
             cache[qkv_key] = (k, v)
 
-        # mask_value = max_neg_value(q)
         dots = q @ k.swapaxes(-1, -2)
+        mask_value = max_neg_value(dots)
 
-        # if exists(mask): # TODO:
-        #     mask = rearrange(mask, 'b j -> b () () j')
-        #     dots.masked_fill_(~mask, mask_value)
-        #     del mask
+        if exists(mask):
+            mask = rearrange(mask, 'b j -> b () () j')
+            dots.masked_fill_(~mask, mask_value)
+            del mask
 
         # if self.causal: # TODO:
         #     i, j = dots.shape[-2:]
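The change does two things: it re-enables the key-padding mask that was previously commented out, and it derives the fill value from dots (the attention logits) instead of q, presumably so the fill value tracks the dtype of the tensor actually being masked. Below is a minimal, self-contained sketch of the masking pattern this commit enables, assuming PyTorch and einops; the toy shapes, tensor names, and the local max_neg_value stand-in (matching the usual "most negative finite value for the dtype" convention) are illustrative, not code taken from the repository.

import torch
from einops import rearrange

def max_neg_value(t):
    # Most negative finite value for t's dtype, used as the pre-softmax
    # fill so masked positions receive ~zero attention weight.
    # (Illustrative stand-in for the repository's helper.)
    return -torch.finfo(t.dtype).max

b, h, i, j, d = 2, 8, 4, 4, 64             # batch, heads, query len, key len, head dim
q = torch.randn(b, h, i, d)
k = torch.randn(b, h, j, d)
mask = torch.ones(b, j, dtype=torch.bool)  # True = keep, False = padding
mask[:, -1] = False                        # pretend the last key position is padding

dots = q @ k.swapaxes(-1, -2)              # (b, h, i, j) attention logits
mask_value = max_neg_value(dots)           # derived from dots, not q, so the fill
                                           # matches the logits' dtype
mask = rearrange(mask, 'b j -> b () () j') # broadcast over heads and query positions
dots.masked_fill_(~mask, mask_value)       # in-place fill of padded key columns

attn = dots.softmax(dim=-1)                # padded keys now get ~0 attention weight
assert torch.allclose(attn[..., -1], torch.zeros(b, h, i))

The rearrange to 'b () () j' is the key design choice: a per-sequence padding mask of shape (batch, key_len) is expanded with two unit axes so one mask broadcasts across every head and every query position when filling the (b, h, i, j) logits.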
